Coder Social home page Coder Social logo

cs744-mpi's Introduction

Marius Script

Usage

./mpic <pyfiles> -l <log level>

Example

Input

"""
See https://docs.dgl.ai/tutorials/blitz/3_message_passing.html
"""


class BasicLayer(mpi.Module):
    # Message-passing layer written in the mpi DSL; mpic compiles it into the
    # C++ BasicLayer shown under "Outputs". Computes
    #   linear(cat(h, mean of neighbor h)).
    def __init__(self, input_dim: int, output_dim: int):
        # NOTE(review): under regular Python semantics super(mpi.Module, self)
        # starts the MRO lookup *after* mpi.Module, so mpi.Module.__init__ is
        # skipped (super(BasicLayer, self) would call it). This matters only if
        # the class is ever executed as Python rather than compiled by mpic —
        # confirm which form mpic expects before changing it.
        super(mpi.Module, self).__init__(input_dim, output_dim)
        # Input width is doubled because forward() feeds the linear layer the
        # concatenation of h and the aggregated neighbor features
        # (becomes _mpic_linear_input_dim_ = input_dim_ * 2 in the C++ output).
        self.linear = mpi.Linear(input_dim * 2, output_dim)
        self.reset_parameters()

    def reset_parameters(self):
        # Re-initialize the weights of the linear sub-module.
        self.linear.reset_parameters()

    def forward(self, graph: mpi.DENSEGraph, h: mpi.Tensor) -> mpi.Tensor:
        # DGL-style local_scope (see the tutorial linked above): node data set
        # inside this block is presumably discarded on exit.
        with graph.local_scope():
            graph.ndata["h"] = h
            # copy_u sends "h" as message "m"; mean reduces the messages into
            # "h_N". Compiled to update_all<MeanFunc>(graph, h).
            graph.update_all(
                message_func=mpi.copy_u("h", "m"), reduce_func=mpi.mean("m", "h_N")
            )
            h_N = graph.ndata["h_N"]
            # dim=1 — presumably the feature dimension (compiled to
            # torch::cat({h, h_N}, 1)).
            h_total = mpi.cat(h, h_N, dim=1)
            return self.linear(h_total)

Outputs

basic_layer.h

//
// Autogenerated file!
//

#pragma once

#include "common/datatypes.h"
#include "configuration/options.h"
#include "configuration/config.h"
#include "data/graph.h"
#include "nn/initialization.h"
#include "gnn_layer.h"
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#include "torch/torch.h"
#pragma GCC diagnostic pop

// Options type for BasicLayer. The generated layer declares no fields of its
// own; everything comes from GNNLayerOptions.
struct BasicLayerOptions : GNNLayerOptions {
};

// GNN layer generated by mpic from the Python `BasicLayer` module: concatenates
// each node's features with the mean of its neighbors' features and multiplies
// the result by the weight matrix linear_.
class BasicLayer : public GNNLayer {
   public:
    // Layer-specific options, obtained in the constructor with
    // std::dynamic_pointer_cast from config_->options (nullptr if that cast fails).
    shared_ptr<BasicLayerOptions> options_;
    // Weight of the generated `linear` sub-module, created in reset() with
    // shape {_mpic_linear_output_dim_, _mpic_linear_input_dim_}.
    torch::Tensor linear_;
    // Input width of `linear`: input_dim_ * 2, because forward() feeds it the
    // concatenation of the node features and the aggregated neighbor features.
    int _mpic_linear_input_dim_;
    // Output width of `linear`; set to output_dim_ in the constructor.
    int _mpic_linear_output_dim_;

    // Copies dimensions/options out of layer_config, stores device, then calls reset().
    BasicLayer(shared_ptr<LayerConfig> layer_config, torch::Device device);

    // (Re)initializes linear_, and the bias via init_bias() when config_->bias is set.
    void reset() override;

    // Returns linear_ applied to cat(inputs, mean-aggregated neighbor features
    // of inputs over dense_graph). The `train` flag is not used by this layer.
    torch::Tensor forward(torch::Tensor inputs, DENSEGraph dense_graph, bool train = true) override;
};

basic_layer.cpp

//
// Autogenerated file!
//

#include "nn/layers/gnn/basic_layer.h"

#include "nn/layers/gnn/layer_helpers.h"
#include "reporting/logger.h"
#include "nn/initialization.h"

// XXX: GCC warns on always_inline
#pragma GCC diagnostic ignored "-Wattributes"

// XXX: always_inline to mirror original code (unsure if more efficient)
#define FUNCGEN __attribute__((always_inline))

namespace {

// Reduction policies consumed by update_all(). Each policy provides:
//   useNumNbrs       - whether the result must be normalized by the neighbor count
//   segmented_reduce - reduces gathered neighbor embeddings per segment, with
//                      segments delimited by `offsets`
//   combine          - merges the partial results computed separately for the
//                      outgoing and the incoming neighbor directions
//   applyNumNbrs     - (only when useNumNbrs) final normalization step

// Sum aggregation.
struct SumFunc {
    static constexpr bool useNumNbrs = false;

    FUNCGEN static torch::Tensor segmented_reduce(torch::Tensor const& embeds, Indices const& offsets) {
        return segmented_sum_with_offsets(embeds, offsets);
    }

    // The sum over both directions is the sum of the per-direction sums.
    FUNCGEN static torch::Tensor combine(torch::Tensor const& lhs, torch::Tensor const& rhs) {
        return lhs + rhs;
    }
};

// Elementwise-max aggregation.
struct MaxFunc {
    static constexpr bool useNumNbrs = false;

    FUNCGEN static torch::Tensor segmented_reduce(torch::Tensor const& embeds, Indices const& offsets) {
        return segmented_max_with_offsets(embeds, offsets);
    }

    // The max over both directions is the elementwise max of the per-direction
    // maxima; summing them would not be a max.
    // NOTE(review): assumes segmented_max_with_offsets yields a value for
    // empty segments that is neutral under max — confirm.
    FUNCGEN static torch::Tensor combine(torch::Tensor const& lhs, torch::Tensor const& rhs) {
        return torch::maximum(lhs, rhs);
    }
};

// Mean aggregation: sums per direction, then divides by the total neighbor count.
struct MeanFunc {
    static constexpr bool useNumNbrs = true;

    FUNCGEN static torch::Tensor segmented_reduce(torch::Tensor const& embeds, Indices const& offsets) {
        return segmented_sum_with_offsets(embeds, offsets);
    }

    // Partial sums add up; the division happens once, in applyNumNbrs.
    FUNCGEN static torch::Tensor combine(torch::Tensor const& lhs, torch::Tensor const& rhs) {
        return lhs + rhs;
    }

    // Divides each row of a_i by its neighbor count. Counts of 0 are replaced
    // by 1 to avoid division by zero; unsqueeze(-1) broadcasts the per-row
    // count across the remaining dimension (assumes a_i is 2-D — confirm).
    FUNCGEN static torch::Tensor applyNumNbrs(torch::Tensor const& a_i, torch::Tensor const& num_nbrs) {
        torch::Tensor denominator = torch::where(torch::not_equal(
                    num_nbrs, 0), num_nbrs, 1).to(a_i.dtype()).unsqueeze(-1);
        return a_i / denominator;
    }
};

// Aggregates the rows of `u` selected by dense_graph's neighbor lists using
// ReduceFunc. Outgoing neighbors are used when out_neighbors_mapping_ is
// defined, incoming neighbors when in_neighbors_mapping_ is defined; when both
// are defined the two partial results are merged with ReduceFunc::combine.
//
// Throws c10::Error (TORCH_CHECK) if neither mapping is defined, since there
// is then nothing to aggregate.
template <class ReduceFunc>
FUNCGEN torch::Tensor update_all(DENSEGraph& dense_graph, torch::Tensor const& u) {
    constexpr bool useNumNbrs = ReduceFunc::useNumNbrs;
    torch::Tensor a_i;
    [[maybe_unused]] torch::Tensor total_num_neighbors;

    if (dense_graph.out_neighbors_mapping_.defined()) {
        Indices outgoing_neighbors = dense_graph.getNeighborIDs(false, false);
        Indices outgoing_neighbor_offsets = dense_graph.getNeighborOffsets(false);
        torch::Tensor outgoing_num = dense_graph.getNumNeighbors(false);

        torch::Tensor outgoing_embeddings = u.index_select(0, outgoing_neighbors);
        a_i = ReduceFunc::segmented_reduce(outgoing_embeddings, outgoing_neighbor_offsets);

        // only normalizing aggregations (e.g. mean) need the neighbor count
        if constexpr (useNumNbrs) {
            total_num_neighbors = outgoing_num;
        }
    }

    if (dense_graph.in_neighbors_mapping_.defined()) {
        Indices incoming_neighbors = dense_graph.getNeighborIDs(true, false);
        Indices incoming_neighbor_offsets = dense_graph.getNeighborOffsets(true);
        torch::Tensor incoming_num = dense_graph.getNumNeighbors(true);

        torch::Tensor incoming_embeddings = u.index_select(0, incoming_neighbors);
        // Reduce with the same policy as the outgoing direction, then merge.
        torch::Tensor incoming_reduced = ReduceFunc::segmented_reduce(incoming_embeddings, incoming_neighbor_offsets);

        if (a_i.defined()) {
            a_i = ReduceFunc::combine(a_i, incoming_reduced);
        } else {
            a_i = incoming_reduced;
        }

        // only normalizing aggregations (e.g. mean) need the neighbor count
        if constexpr (useNumNbrs) {
            if (total_num_neighbors.defined()) {
                total_num_neighbors = total_num_neighbors + incoming_num;
            } else {
                total_num_neighbors = incoming_num;
            }
        }
    }

    TORCH_CHECK(a_i.defined(),
                "update_all: DENSEGraph has neither out_neighbors_mapping_ nor in_neighbors_mapping_ defined");

    if constexpr (useNumNbrs) {
        return ReduceFunc::applyNumNbrs(a_i, total_num_neighbors);
    } else {
        return a_i;
    }
}

} // anonymous namespace

BasicLayer::BasicLayer(shared_ptr<LayerConfig> layer_config, torch::Device device) {
    // Base-class state shared by all layers.
    config_ = layer_config;
    device_ = device;
    input_dim_ = layer_config->input_dim;
    output_dim_ = layer_config->output_dim;

    // Layer-specific options; nullptr if config options are not BasicLayerOptions.
    options_ = std::dynamic_pointer_cast<BasicLayerOptions>(layer_config->options);

    // `linear` consumes cat(h, h_N), hence twice the layer's input width.
    _mpic_linear_output_dim_ = output_dim_;
    _mpic_linear_input_dim_ = 2 * input_dim_;

    // Allocate and initialize the parameters now that all dimensions are known.
    reset();
}

void BasicLayer::reset() {
    [[maybe_unused]] auto tensor_options = torch::TensorOptions().dtype(torch::kFloat32).device(device_);

    linear_ = initialize_tensor(config_->init, {_mpic_linear_output_dim_, _mpic_linear_input_dim_}, tensor_options).set_requires_grad(true);;

    if (config_->bias) {
        init_bias();
    }
}

torch::Tensor BasicLayer::forward(torch::Tensor h, DENSEGraph graph, bool train) {
    // Generated from BasicLayer.forward: mean-aggregate neighbor features,
    // concatenate them with the node's own features, then apply linear_.
    // The `train` flag does not affect this layer.

    // graph.update_all(copy_u("h", "m"), mean("m", "h_N"))
    torch::Tensor neighbor_mean = update_all<MeanFunc>(graph, h);

    // mpi.cat(h, h_N, dim=1)
    torch::Tensor features = torch::cat({h, neighbor_mean}, 1);

    // self.linear(h_total): linear_ is {out, in}, so multiply it with the
    // transposed features and transpose the product back.
    torch::Tensor projected = torch::matmul(linear_, features.transpose(0, -1));
    return projected.transpose(0, -1);
}

cs744-mpi's People

Contributors

pao214 avatar

Watchers

 avatar

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Something interesting about the web. A new door to the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Something interesting about games, to make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.