Comments (11)
Hi Diego,
Thanks for your great work! May I ask whether I can implement it this way to support batch-wise training?
```python
def forward(self, input, adj):
    # input: (batch_size, N, in_features), adj: (batch_size, N, N)
    batch_size = input.size(0)
    # expand broadcasts the shared parameters across the batch for bmm
    h = torch.bmm(input, self.W.expand(batch_size, self.in_features, self.out_features))
    f_1 = torch.bmm(h, self.a1.expand(batch_size, self.out_features, 1))
    f_2 = torch.bmm(h, self.a2.expand(batch_size, self.out_features, 1))
    e = self.leakyrelu(f_1 + f_2.transpose(2, 1))
    zero_vec = -9e15 * torch.ones_like(e)
    attention = torch.where(adj > 0, e, zero_vec)
    attention = F.softmax(attention, dim=1)
    attention = F.dropout(attention, self.dropout, training=self.training)
    h_prime = torch.bmm(attention, h)
    if self.concat:
        return F.elu(h_prime)
    else:
        return h_prime
```
Thank you!
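For reference, a standalone toy check of the expand + bmm pattern used above (the sizes are made up):

```python
import torch

B, N, Fin, Fout = 4, 10, 16, 8
inp = torch.randn(B, N, Fin)
W = torch.randn(Fin, Fout)                  # stands in for self.W
a1 = torch.randn(Fout, 1)                   # stands in for self.a1
h = torch.bmm(inp, W.expand(B, Fin, Fout))  # (B, N, Fout)
f_1 = torch.bmm(h, a1.expand(B, Fout, 1))   # (B, N, 1)
e = f_1 + f_1.transpose(2, 1)               # (B, N, N) via broadcasting
print(h.shape, f_1.shape, e.shape)
```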
Hi LeeJunHyun,
I didn't implement it to handle batched graphs. However, it can be done by using a block-diagonal adjacency matrix, as proposed in tkipf/gcn#4.
From here it should be simple to adapt the code ;-)
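A rough sketch of the idea (the helper below is illustrative, not code from tkipf/gcn):

```python
import torch

def block_diag_adj(adjs):
    """Stack per-graph adjacency matrices into one block-diagonal matrix,
    so a batch of graphs becomes a single large graph with no edges
    between the individual graphs."""
    n_total = sum(a.size(0) for a in adjs)
    out = adjs[0].new_zeros(n_total, n_total)
    offset = 0
    for a in adjs:
        n = a.size(0)
        out[offset:offset + n, offset:offset + n] = a
        offset += n
    return out

# Node features are concatenated along the node dimension to match:
# x = torch.cat([x1, x2], dim=0); adj = block_diag_adj([adj1, adj2])
```

Recent PyTorch versions also ship torch.block_diag, which does the same for dense matrices.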
Oh, thanks.
Here is a mention of this issue in the paper.
But you've given me a kind of solution! I will try it.
"""
<subsection 2.2>
the tensor manipulation framework we used only supports sparse matrix multiplication for rank-2 tensors, which limits the batching capabilities of the layer as it is currently implemented (especially for datasets with multiple graphs).
"""
Hi LeeJunHyun,
It seems to be a problem only for sparse matrix multiplication.
Hi @Cartus,
I didn't run your code, but it seems fine. You should run it (even on CPU) just to check that it runs without any problems.
Batching will make
a_input = torch.cat([h.repeat(1, N).view(N * N, -1), h.repeat(N, 1)], dim=1).view(N, -1, 2 * self.out_features)
pretty large, since it becomes batch_size x N^2 x (2 * feat_dim). So if you have a graph with hundreds of nodes, each with hundreds of feature dimensions, this will probably take all your memory.
To avoid this, I made a simple change to the code: I replaced the computation of the similarity matrix with a simple node-wise dot product. On the Cora dataset it reaches almost the same accuracy (around 84%) as the current implementation within 100 epochs, but I'm not sure whether this degrades performance on other datasets. Just FYI.
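The exact change isn't shown in this thread; a minimal sketch of what the dot-product scoring could look like (single-graph form, my own reconstruction):

```python
def forward(self, input, adj):
    h = torch.mm(input, self.W)             # (N, out_features)
    # Score pairs with <h_i, h_j> instead of a^T [h_i || h_j]; this
    # avoids materializing the (N*N, 2*out_features) a_input tensor.
    e = self.leakyrelu(torch.mm(h, h.t()))  # (N, N) pairwise dot products
    zero_vec = -9e15 * torch.ones_like(e)
    attention = torch.where(adj > 0, e, zero_vec)
    attention = F.softmax(attention, dim=1)
    attention = F.dropout(attention, self.dropout, training=self.training)
    h_prime = torch.matmul(attention, h)
    return F.elu(h_prime) if self.concat else h_prime
```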
Could you share your code, please? This bothers me too.
@Cartus Could you please share the whole GraphAttentionLayer code? By the way, what are a1 and a2? Thanks.
@swg209 a1 and a2 are defined in this issue: #4 (comment)
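In short (a sketch from memory; initialization details may differ): the attention vector a is split into two halves a1 and a2, so that a^T [Wh_i || Wh_j] = a1^T Wh_i + a2^T Wh_j can be computed by broadcasting two (N, 1) tensors instead of building the (N*N, 2*out_features) tensor:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GraphAttentionLayer(nn.Module):
    """GAT layer with the attention vector a split into a1 and a2,
    so e_ij = LeakyReLU(a1^T Wh_i + a2^T Wh_j)."""

    def __init__(self, in_features, out_features, dropout=0.6, alpha=0.2, concat=True):
        super().__init__()
        self.dropout = dropout
        self.concat = concat
        self.W = nn.Parameter(torch.empty(in_features, out_features))
        self.a1 = nn.Parameter(torch.empty(out_features, 1))  # first half of a
        self.a2 = nn.Parameter(torch.empty(out_features, 1))  # second half of a
        for p in (self.W, self.a1, self.a2):
            nn.init.xavier_uniform_(p.data, gain=1.414)
        self.leakyrelu = nn.LeakyReLU(alpha)

    def forward(self, input, adj):
        h = torch.mm(input, self.W)        # (N, out_features)
        f_1 = torch.mm(h, self.a1)         # (N, 1): a1^T Wh_i
        f_2 = torch.mm(h, self.a2)         # (N, 1): a2^T Wh_j
        e = self.leakyrelu(f_1 + f_2.t())  # (N, N) by broadcasting
        zero_vec = -9e15 * torch.ones_like(e)
        attention = torch.where(adj > 0, e, zero_vec)
        attention = F.softmax(attention, dim=1)
        attention = F.dropout(attention, self.dropout, training=self.training)
        h_prime = torch.matmul(attention, h)
        return F.elu(h_prime) if self.concat else h_prime
```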
@Cartus Hi, thanks for your code for batched GAT. I have one question: why attention = F.softmax(attention, dim=1)? Why dim=1 here rather than attention = F.softmax(attention, dim=2), given that the batch is added as the first dimension?
Thanks a lot.
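A toy check of what each dim normalizes (made-up tensors):

```python
import torch
import torch.nn.functional as F

att = torch.randn(2, 5, 5)  # (batch, N, N); entry [b, i, j] scores neighbor j of node i
# h_prime = bmm(attention, h) sums over j, i.e. over the last dim,
# so each node's weights should sum to 1 along dim=2:
print(F.softmax(att, dim=2).sum(dim=2))  # all ones
print(F.softmax(att, dim=1).sum(dim=2))  # generally not ones
```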
If I want to apply your code to process data in batches, where should I add it? Or could you share the complete GraphAttentionLayer code? Thank you very much!