
Comments (2)

Elnifio commented on June 25, 2024

To reproduce the issue, we run the following code snippet with GLT:

import graphlearn_torch as glt
import torch
import torch_geometric as pyg

# We consider a simplified version of GATConv here.
# GATConv computes each destination node's embedding from both its own embedding
# and its source nodes' embeddings; we simplify it here so that only an addition
# is performed between the destination node's embedding and each source node's embedding.
class DummyLayer(pyg.nn.MessagePassing):
    def __init__(self):
        super().__init__(aggr="add")
        
    def forward(self, x, edge_index):
        return self.propagate(edge_index, x=x)
    
    def message(self, x_i, x_j):
        # x_i: features of the destination nodes; x_j: features of the source nodes
        return x_i + x_j

# To simplify the discussion, we build two binary trees (rooted at nodes 12 and 13)
# for generating the samples.
u = torch.tensor([0, 1, 2, 3,  4,  5,  6,  7,  8,  9, 10, 11])
v = torch.tensor([8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13])
edge_index = torch.stack([u, v], dim=0)

# and, to trace which nodes are involved in the computation,
# node i's feature is set to 2^i
x = torch.stack([2**torch.tensor(range(14)) for _ in range(4)], dim=1).to(torch.float32)

# Data loading, following GLT's IGBH example code
dataset = glt.data.Dataset(edge_dir="in")
dataset.init_graph(edge_index, graph_mode="CPU")
dataset.init_node_features(node_feature_data=x, with_gpu=False)
loader = glt.loader.NeighborLoader(
    dataset,
    [1, 1],                  # sample 1 neighbor per node at each of the 2 hops
    torch.tensor([12, 13]),  # seed nodes
    batch_size=2
)
data = next(iter(loader))

x = data.x
edge_index = data.edge_index
# At the start, one example sampled edge index could be:
# 2 -> 9 -> 12
# 5 -> 10 -> 13
# and the corresponding features are (float32; shown as integers for readability):
# tensor([[4096, 4096, 4096, 4096],  # node 12's feature
#         [8192, 8192, 8192, 8192],  # node 13's feature
#         [ 512,  512,  512,  512],  # node 9's feature
#         [1024, 1024, 1024, 1024],  # node 10's feature
#         [   4,    4,    4,    4],  # node 2's feature
#         [  32,   32,   32,   32]]) # node 5's feature

h1 = DummyLayer()(x, edge_index)
# after we apply the first transformation, the features are updated as:
# tensor([[4608, 4608, 4608, 4608],  # node 12, h_1(12) = h_0(12) + h_0(9)
#         [9216, 9216, 9216, 9216],  # node 13, h_1(13) = h_0(13) + h_0(10)
#         [ 516,  516,  516,  516],  # node 9,  h_1(9)  = h_0(9)  + h_0(2)
#         [1056, 1056, 1056, 1056],  # node 10, h_1(10) = h_0(10) + h_0(5)
#         [   0,    0,    0,    0],  # node 2, zeroed out: no incoming edges
#         [   0,    0,    0,    0]]) # node 5, zeroed out: no incoming edges

h2 = DummyLayer()(h1, edge_index)
# after we apply the second transformation, the features are updated as:
# tensor([[ 5124,  5124,  5124,  5124],  # node 12, h_2(12) = h_1(12) + h_1(9)
#         [10272, 10272, 10272, 10272],  # node 13, h_2(13) = h_1(13) + h_1(10)
#         [  516,   516,   516,   516],  # node 9
#         [ 1056,  1056,  1056,  1056],  # node 10
#         [    0,     0,     0,     0],  # node 2
#         [    0,     0,     0,     0]]) # node 5

# for node 12, the message passing paths are:
# 9 -> 12
# 2 -> 9 (node 9's value already updated with node 2's feature) -> 12
# thus, node 9's raw feature enters node 12's result twice.

# this is different from DGL's behavior:
# in DGL, for the path 2 -> 9 -> 12, node 12's value would be (2**12 + 2**9 + 2**2),
# whereas GLT double-counts node 9's feature; verified:
# (2**12 + 2**9 + (2**9 + 2**2)) == 5124
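
To make the double count explicit, here is a quick plain-Python check of the arithmetic above (assuming the sampled path 2 -> 9 -> 12 shown in the comments; nothing here depends on GLT):

# GLT reuses the full sampled edge_index at every layer, so node 9's
# layer-1 value (which already contains node 2's feature) is added to
# node 12 on top of node 9's raw feature:
glt_node12 = (2**12 + 2**9) + (2**9 + 2**2)   # h_1(12) + h_1(9)
assert glt_node12 == 5124
# if every node on the path contributed exactly once, node 12 would instead be:
assert 2**12 + 2**9 + 2**2 == 4612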

And we run the following code for DGL:

import dgl
import torch

class DummyLayer(torch.nn.Module):
    # Takes source and destination node features, adds them along each edge,
    # and sum-aggregates the messages at the destination nodes.
    def __init__(self):
        super().__init__()

    def forward(self, graph, hidden):
        graph = graph.local_var()
        graph.srcdata['temp_src'] = hidden
        # in a DGL block, the first number_of_dst_nodes() rows of the source
        # features correspond to the destination nodes
        graph.dstdata['temp_dst'] = hidden[:graph.number_of_dst_nodes()]
        graph.update_all(
            dgl.function.u_add_v('temp_src', 'temp_dst', 'temp'),
            dgl.function.sum('temp', 'result')
        )
        new_hidden = graph.dstdata['result']
        return new_hidden

u = torch.tensor([0, 1, 2, 3,  4,  5,  6,  7,  8,  9, 10, 11])
v = torch.tensor([8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13])
g = dgl.graph((u, v))

# as before, node i's feature is 2^i, replicated across 4 channels
g.ndata['feat'] = torch.stack(
    [2**g.nodes(), 2**g.nodes(), 2**g.nodes(), 2**g.nodes()],
    dim=1).to(torch.float32)

sampler = dgl.dataloading.NeighborSampler([1, 1])  # 1 neighbor per node at each of the 2 hops
loader = dgl.dataloading.DataLoader(g, torch.tensor([12, 13]), sampler, batch_size=2)
input_nodes, destination_nodes, blocks = next(iter(loader))

x = blocks[0].srcdata['feat']
# here, if we examine the edges in block 0, an example result could be:
# 8 -> 12
# 10 -> 13
# 2 -> 9
# 4 -> 10
# only these destination nodes' features are updated.
# (edge 10 -> 13 appears in block 0 because seed 13 is itself re-expanded
# at the second hop; this is discussed below for node 12.)
# examining the features (float32; shown as integers for readability):
# tensor([[4096, 4096, 4096, 4096],  # node 12
#         [8192, 8192, 8192, 8192],  # node 13
#         [ 512,  512,  512,  512],  # node 9
#         [1024, 1024, 1024, 1024],  # node 10
#         [ 256,  256,  256,  256],  # node 8
#         [   4,    4,    4,    4],  # node 2
#         [  16,   16,   16,   16]]) # node 4

h1 = DummyLayer()(blocks[0], x)
# After applying the transformation along the edges in block 0, the features are updated as:
# tensor([[4352, 4352, 4352, 4352],  # node 12, the sum of nodes 8 and 12
#         [9216, 9216, 9216, 9216],  # node 13, the sum of nodes 10 and 13
#         [ 516,  516,  516,  516],  # node 9,  the sum of nodes 2 and 9
#         [1040, 1040, 1040, 1040]]) # node 10, the sum of nodes 4 and 10

# Now, if we examine block 1, we should see edges: 
# 9->12
# 10->13

h2 = DummyLayer()(blocks[1], h1)
# After applying the transformation along the edges in block 1, the features are updated as:
# tensor([[ 4868,  4868,  4868,  4868],   # final result for node 12
#         [10256, 10256, 10256, 10256]])  # final result for node 13

# for node 12, the sampling proceeds as follows:
# we start with [12]
# at the first hop, we add 9: [12, 9]
# at the second hop, we add:
#     8, since 12 is in the collection of seed nodes (this behavior is discussed in the DGL thread)
#     2, since 9 is in the collection of seed nodes
# so we have [12, 9, 2, 8]
# and the message passing paths that DGL takes are:
# 2 -> 9 -> 12 <- 8
# so the result for node 12 aggregates the information of nodes 2, 8, 9, and 12 exactly once each:

# verified
# (2**9 + 2**12 + 2**8 + 2**2) == 4868
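
Putting the two runs side by side for node 12 (under the example samplings above), the gap between the two final values is exactly one extra copy of node 9's feature minus node 8's feature, which GLT never samples:

glt_node12 = 2**12 + 2*2**9 + 2**2        # 5124: node 9 counted twice, node 8 never sampled
dgl_node12 = 2**12 + 2**9 + 2**8 + 2**2   # 4868: every sampled node counted exactly once
assert glt_node12 - dgl_node12 == 2**9 - 2**8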


LiSu commented on June 25, 2024

Hi @Elnifio

Yes, this difference is caused by the different approaches of sampling and message passing in PyG and DGL.
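
For anyone comparing the two schemes, here is a minimal self-contained sketch (plain Python dictionaries, no GLT or DGL required; it assumes the example samplings from the snippets above) that reproduces both final values for node 12:

# node i's feature is 2**i, as in the examples above
h0 = {n: 2 ** n for n in (2, 8, 9, 12)}

def layer(h, edges):
    # each destination v starts from its own value h[v] and adds h[u]
    # for every incoming edge (u, v); with fanout 1 this matches the
    # u_add_v + sum behavior used in both snippets above
    out = {}
    for u, v in edges:
        out[v] = out.get(v, h[v]) + h[u]
    return out

# PyG/GLT style: the same sampled edge set is reused by every layer, so node 9's
# layer-1 value (which already contains node 2) feeds node 12 again at layer 2.
glt_edges = [(2, 9), (9, 12)]
g1 = layer(h0, glt_edges)            # 9 -> 516, 12 -> 4608
g2 = layer({**h0, **g1}, glt_edges)  # 12 -> 4608 + 516 = 5124
# (GLT actually zeroes nodes without incoming edges after each layer;
#  that does not affect node 12's value here.)

# DGL style: layer k only sees block k, and seeds are re-expanded at every hop,
# which is how node 8 enters block 0.
block0 = [(2, 9), (8, 12)]
block1 = [(9, 12)]
d1 = layer(h0, block0)               # 9 -> 516, 12 -> 4352
d2 = layer({**h0, **d1}, block1)     # 12 -> 4352 + 516 = 4868

assert g2[12] == 5124 and d2[12] == 4868

Both assertions match the full GLT and DGL runs above.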

