
Is it working on batched data? · pygat · 11 comments · closed

diego999 commented on August 24, 2024
Is it working on batched data?

from pygat.

Comments (11)

Cartus commented on August 24, 2024

Hi Diego,

Thanks for your great work! May I ask whether I can implement it this way to support batch-wise training?
```python
def forward(self, input, adj):
    # input: (batch_size, N, in_features), adj: (batch_size, N, N)
    batch_size = input.size(0)
    h = torch.bmm(input, self.W.expand(batch_size, self.in_features, self.out_features))

    f_1 = torch.bmm(h, self.a1.expand(batch_size, self.out_features, 1))
    f_2 = torch.bmm(h, self.a2.expand(batch_size, self.out_features, 1))
    e = self.leakyrelu(f_1 + f_2.transpose(2, 1))

    zero_vec = -9e15 * torch.ones_like(e)
    attention = torch.where(adj > 0, e, zero_vec)
    attention = F.softmax(attention, dim=1)
    attention = F.dropout(attention, self.dropout, training=self.training)
    h_prime = torch.bmm(attention, h)

    if self.concat:
        return F.elu(h_prime)
    else:
        return h_prime
```

Thank you!


Diego999 commented on August 24, 2024

Hi LeeJunHyun,

I didn't implement it to handle batched graphs. However, it can be done with a block-diagonal adjacency matrix, as proposed in tkipf/gcn#4.

From here it should be simple to adapt the code ;-)
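The block-diagonal trick from tkipf/gcn#4 can be sketched like this (a minimal illustration with made-up tensor sizes, not code from either repo): several graphs are stacked into one big disconnected graph, so the unbatched layer can process them in a single pass.

```python
import torch

# Two small graphs with dense adjacency matrices (self-loops included)
adj1 = torch.ones(3, 3)      # graph 1: 3 nodes
adj2 = torch.ones(2, 2)      # graph 2: 2 nodes
x1 = torch.randn(3, 8)       # node features of graph 1
x2 = torch.randn(2, 8)       # node features of graph 2

# Block-diagonal adjacency: zero blocks mean no edges across graphs,
# so attention can never leak between them.
big_adj = torch.block_diag(adj1, adj2)    # shape (5, 5)
big_x = torch.cat([x1, x2], dim=0)        # shape (5, 8)

# big_x and big_adj can now be fed through the unbatched GAT layer as-is.
```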


LeeJunHyun commented on August 24, 2024

Oh, thanks.
There is a mention of this issue in the paper.
But you gave me a kind of solution! I will try it.

"""
<subsection 2.2>
the tensor manipulation framework we used only supports sparse matrix multiplication for rank-2 tensors, which limits the batching capabilities of the layer as it is currently implemented (especially for datasets with multiple graphs).
"""


Diego999 commented on August 24, 2024

Hi LeeJunHyun,

It seems to be a problem only for sparse matrix multiplication.


Diego999 commented on August 24, 2024

Hi @Cartus,

I didn't run your code, but it seems fine. You should run it (even on CPU) just to check that it works without any problems.
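A self-contained smoke test along those lines might look as follows. The class name `BatchedGraphAttentionLayer` and the parameter shapes are assumptions pieced together from this thread, not code from the repo; note it uses `dim=2` in the softmax so each node's neighbour weights sum to 1 (the snippet above uses `dim=1`, which is questioned later in the thread).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class BatchedGraphAttentionLayer(nn.Module):
    # Hypothetical batched layer assembled from the snippet in this thread.
    def __init__(self, in_features, out_features, dropout=0.6, alpha=0.2, concat=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.dropout = dropout
        self.concat = concat
        self.W = nn.Parameter(torch.empty(in_features, out_features))
        self.a1 = nn.Parameter(torch.empty(out_features, 1))
        self.a2 = nn.Parameter(torch.empty(out_features, 1))
        nn.init.xavier_uniform_(self.W)
        nn.init.xavier_uniform_(self.a1)
        nn.init.xavier_uniform_(self.a2)
        self.leakyrelu = nn.LeakyReLU(alpha)

    def forward(self, input, adj):
        # input: (B, N, in_features), adj: (B, N, N)
        batch_size = input.size(0)
        h = torch.bmm(input, self.W.expand(batch_size, self.in_features, self.out_features))
        f_1 = torch.bmm(h, self.a1.expand(batch_size, self.out_features, 1))  # (B, N, 1)
        f_2 = torch.bmm(h, self.a2.expand(batch_size, self.out_features, 1))  # (B, N, 1)
        e = self.leakyrelu(f_1 + f_2.transpose(2, 1))                         # (B, N, N)
        zero_vec = -9e15 * torch.ones_like(e)
        attention = torch.where(adj > 0, e, zero_vec)
        attention = F.softmax(attention, dim=2)  # normalize over each node's neighbours
        attention = F.dropout(attention, self.dropout, training=self.training)
        h_prime = torch.bmm(attention, h)
        return F.elu(h_prime) if self.concat else h_prime

# CPU smoke test: batch of 4 fully connected graphs, 5 nodes, 16 -> 8 features
layer = BatchedGraphAttentionLayer(16, 8)
x = torch.randn(4, 5, 16)
adj = torch.ones(4, 5, 5)
out = layer(x, adj)
```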


tbright17 commented on August 24, 2024

Batching will make
`a_input = torch.cat([h.repeat(1, N).view(N * N, -1), h.repeat(N, 1)], dim=1).view(N, -1, 2 * self.out_features)`
pretty large, since it becomes batch_size x N^2 x (2*feat_dim). So if you have a graph with hundreds of nodes, each with hundreds of feature dimensions, this will probably take all your memory.

To avoid this, I made a simple change to the code: I replaced the calculation of the similarity matrix with a simple node-wise dot product. On the Cora dataset it produces almost the same accuracy (around 84%) as the current implementation in 100 epochs, but I'm not sure whether this degrades performance on other datasets. Just FYI.
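The exact change is not shown in the thread, but one plausible form of the dot-product variant looks like this (sizes are made up for illustration; `-9e15` masking mirrors the snippet above):

```python
import torch
import torch.nn.functional as F

N, F_out = 4, 8
h = torch.randn(N, F_out)      # transformed node features W @ x
adj = torch.ones(N, N)         # dense adjacency with self-loops

# Original pairwise-concatenation tensor: O(N^2 * 2*F_out) memory
a_input = torch.cat([h.repeat(1, N).view(N * N, -1), h.repeat(N, 1)],
                    dim=1).view(N, -1, 2 * F_out)   # (N, N, 2*F_out)

# Dot-product alternative: a single (N, N) similarity matrix
e = h @ h.t()                  # e[i, j] = <h_i, h_j>, O(N^2) memory

attention = torch.where(adj > 0, e, torch.full_like(e, -9e15))
attention = F.softmax(attention, dim=1)   # each row sums to 1 on a 2-D tensor
h_prime = attention @ h                   # aggregated features, (N, F_out)
```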


wxy920801 commented on August 24, 2024

> Batching will make `a_input = torch.cat([h.repeat(1, N).view(N * N, -1), h.repeat(N, 1)], dim=1).view(N, -1, 2 * self.out_features)` pretty large, since it becomes batch_size x N^2 x (2*feat_dim). So if you have a graph with hundreds of nodes, each with hundreds of feature dimensions, this will probably take all your memory.
>
> To avoid this, I made a simple change to the code: I replaced the calculation of the similarity matrix with a simple node-wise dot product. On the Cora dataset it produces almost the same accuracy (around 84%) as the current implementation in 100 epochs, but I'm not sure whether this degrades performance on other datasets. Just FYI.

Could you share your code, please? This bothers me too.


swg209 commented on August 24, 2024

@Cartus Could you please share the whole GraphAttentionLayer code? By the way, what are a1 and a2? Thanks.


lizhenstat commented on August 24, 2024

@swg209 a1 and a2 are defined in this issue: #4 (comment)


lizhenstat commented on August 24, 2024

@Cartus Hi, thanks for your code for batched data in GAT. I have one question: in `attention = F.softmax(attention, dim=1)`, why `dim=1` here rather than `dim=2`, given that the batch is now the first dimension?

Thanks a lot.
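The two choices can be compared directly. On a 2-D `(N, N)` tensor, `dim=1` normalizes each row (each node's neighbour weights sum to 1), so on a batched `(B, N, N)` tensor the equivalent axis is `dim=2`; using `dim=1` there would normalize down columns instead. A small illustration:

```python
import torch
import torch.nn.functional as F

att = torch.randn(2, 3, 3)            # (batch, N, N) attention scores

row_norm = F.softmax(att, dim=2)      # each row att[b, i, :] sums to 1
col_norm = F.softmax(att, dim=1)      # each column att[b, :, j] sums to 1

# row_norm matches the unbatched layer's softmax(dim=1) on an (N, N) tensor,
# since torch.bmm(attention, h) aggregates over the last axis.
```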


Nancyuru commented on August 24, 2024

> Hi Diego,
>
> Thanks for your great work! May I ask whether I can implement it this way to support batch-wise training?

```python
def forward(self, input, adj):

    batch_size = input.size(0)
    h = torch.bmm(input, self.W.expand(batch_size, self.in_features, self.out_features))

    f_1 = torch.bmm(h, self.a1.expand(batch_size, self.out_features, 1))
    f_2 = torch.bmm(h, self.a2.expand(batch_size, self.out_features, 1))
    e = self.leakyrelu(f_1 + f_2.transpose(2,1))

    zero_vec = -9e15*torch.ones_like(e)
    attention = torch.where(adj > 0, e, zero_vec)
    attention = F.softmax(attention, dim=1)
    attention = F.dropout(attention, self.dropout, training=self.training)
    h_prime = torch.bmm(attention, h)

    if self.concat:
        return F.elu(h_prime)
    else:
        return h_prime
```

> Thank you!

If I want to apply your code to process data in batches, where should I add it? Or could you share the complete GraphAttentionLayer code? Thank you very much!


