juho-lee / set_transformer
PyTorch implementation of Set Transformer
License: MIT License
Hi! As described in the paper, you also experimented with Amortized Clustering on CIFAR-10 with the SetTransformer. However, I did not find that code in the repo; could you make that part of the code available as well? Thanks!
Hi,
Could you please explain the meanings of the inputs of SetTransformer:
dim_input, num_outputs, dim_output, num_inds=32, dim_hidden=128, num_heads=4, ln=False
Thanks.
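For readers with the same question, a rough usage sketch with the parameter meanings as they appear in the paper and in the repo's models.py (the toy task and shapes below are made up):

import torch
from models import SetTransformer  # assuming the repo's models.py is on the path

# Hypothetical toy task: map each set of 2-D points to a single 3-D output vector.
model = SetTransformer(
    dim_input=2,     # feature dimension of each element of the input set
    num_outputs=1,   # number of seed vectors in the PMA decoder, i.e. output vectors per set
    dim_output=3,    # dimension of each output vector
    num_inds=32,     # number of inducing points m used by each ISAB in the encoder
    dim_hidden=128,  # width of the hidden representations
    num_heads=4,     # attention heads in each attention block
    ln=False,        # whether LayerNorm is applied inside the blocks
)

X = torch.randn(8, 100, 2)   # (batch, set size, dim_input)
out = model(X)               # expected shape: (8, num_outputs, dim_output) = (8, 1, 3)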
Hi juho-lee!
I have two little puzzles about your paper. In Section 1 (Introduction), you say: "A model for set-input problems should satisfy two critical requirements. First, it should be permutation invariant: the output of the model should not change under any permutation of the elements in the input set. Second, such a model should be able to process input sets of any size."
But after reading the whole paper, I still don't see how you address these two requirements.
For the first, I guess you remove the position embeddings from the original Transformer?
As for the second, I have no idea how you achieved it.
Thank you!
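For context, both properties follow from using attention without positional encodings and pooling the set with PMA; a quick check one could run (a sketch, assuming the repo's models.py is importable):

import torch
from models import SetTransformer  # assuming the repo's models.py

model = SetTransformer(dim_input=2, num_outputs=1, dim_output=3).eval()

X = torch.randn(1, 50, 2)                       # a set of 50 two-dimensional elements
out_a = model(X)
out_b = model(X[:, torch.randperm(50), :])      # same set, elements shuffled
print(torch.allclose(out_a, out_b, atol=1e-5))  # expected: True (permutation invariance)

# Variable set sizes: attention has no length-dependent weights, so other sizes work too.
out_c = model(torch.randn(1, 17, 2))            # a set of 17 elements, same model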
Not sure if I understand the Induced Set Attention Block correctly.
So basically SAB is a Transformer block without positional encoding (and dropout?). In the paper, you say that SAB is "too expensive for large sets". But the set size here is analogous to the maximum sequence length in a Transformer, which is usually 512. Why not just use SAB for the SetTransformer? Is there any reason other than efficiency to use ISAB instead?
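For context, a back-of-the-envelope comparison of the attention cost of the two blocks (not repo code; the sizes are hypothetical):

# SAB computes self-attention among all n elements; ISAB routes it through m inducing points.
n, m, d = 100_000, 32, 128   # hypothetical set size, number of inducing points, hidden dim

sab_cost = n * n * d         # SAB: O(n^2) pairwise attention
isab_cost = 2 * n * m * d    # ISAB: two n-by-m attentions, O(nm)
print(sab_cost / isab_cost)  # roughly 1560x fewer multiply-adds at this set size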
Not an issue, but a question: why is the LayerNorm option set to False by default? In particular, LayerNorm is not used in the point cloud example.
Can you comment on the importance of activating the nested LayerNorm for the model? The paper does not discuss using LayerNorm versus not.
Thanks!
What if I have a set of matrices instead of a set of vectors? Is it possible to extend the Set Transformer framework to cover that scenario?
I played around with it a little (including making some small tweaks) but got bogged down with the .bmm call in the MAB module:
RuntimeError: Expected 3-dimensional tensor, but got 4-dimensional tensor for argument #1 'batch1' (while checking arguments for bmm)
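One possible workaround (not from the repo): the MAB works on 3-D (batch, set size, feature) tensors, so a set of matrices can be handled by flattening each matrix into a vector first:

import torch
from models import SetTransformer  # assuming the repo's models.py

batch, set_size, rows, cols = 8, 20, 5, 7      # hypothetical sizes
X = torch.randn(batch, set_size, rows, cols)   # each example is a set of 5x7 matrices
X_flat = X.view(batch, set_size, rows * cols)  # (8, 20, 35): now a set of vectors

model = SetTransformer(dim_input=rows * cols, num_outputs=1, dim_output=10)
out = model(X_flat)                            # (8, 1, 10)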
Dear Juho,
Thanks for making the code public!
One quick question: if I read the code correctly, LayerNorm was never used in any of the three examples you open-sourced in this repo; is that correct?
If so, is it because it gives slightly inferior performance? And have you tried moving the LayerNorm layer inside the skip connections, instead of before/after them, as done in several more recent papers, so that there is a connection directly from output to input?
Thanks in advance and looking forward to your reply!
Dear Juho,
is it possible that the implementation of the MAB diverges from the paper?
In more detail: The paper states
Multihead(Q, K, V; λ, ω) = concat(O_1, ..., O_h) W_O
H = LayerNorm(X + Multihead(X, Y, Y; ω))
MAB(X, Y) = LayerNorm(H + rFF(H))
but the code does
A = torch.softmax(Q_.bmm(K_.transpose(1,2))/math.sqrt(self.dim_V), 2)
O = torch.cat((Q_ + A.bmm(V_)).split(Q.size(0), 0), 2) # This is output of multihead
O = O if getattr(self, 'ln0', None) is None else self.ln0(O)
O = O + F.relu(self.fc_o(O))
O = O if getattr(self, 'ln1', None) is None else self.ln1(O)
It seems that the matrix W_O is not used in the code at all to mix the outputs of the different heads?
The skip connection Q_ + A.bmm(V_) also diverges from what is stated in the paper, given that Q_ is derived from Q, which is linearly transformed via Q = self.fc_q(Q) in the first line of forward() and is therefore no longer equal to the original query. (On second thought, this may be a necessary requirement, since the output of the MAB has a different shape than the input. In that case, the paper is imprecise.)
Thanks a lot and best wishes
Jannik
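For comparison, a minimal sketch of an MAB written to follow the paper's equations literally, including the W_O head-mixing matrix (this is not the repo's code; it assumes dim_Q == dim_V so that the residual X + Multihead(X, Y, Y) is well defined, and uses the conventional per-head scaling):

import math
import torch
import torch.nn as nn

class PaperMAB(nn.Module):
    """Sketch only: MAB following the paper's equations verbatim."""
    def __init__(self, dim, num_heads, ln=False):
        super().__init__()
        self.dim, self.num_heads = dim, num_heads
        self.fc_q = nn.Linear(dim, dim)
        self.fc_k = nn.Linear(dim, dim)
        self.fc_v = nn.Linear(dim, dim)
        self.fc_o = nn.Linear(dim, dim)  # W_O: mixes the concatenated head outputs
        self.rff = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim))
        self.ln0 = nn.LayerNorm(dim) if ln else nn.Identity()
        self.ln1 = nn.LayerNorm(dim) if ln else nn.Identity()

    def forward(self, X, Y):
        h = self.dim // self.num_heads
        Q_ = torch.cat(self.fc_q(X).split(h, 2), 0)      # fold the heads into the batch dimension
        K_ = torch.cat(self.fc_k(Y).split(h, 2), 0)
        V_ = torch.cat(self.fc_v(Y).split(h, 2), 0)
        A = torch.softmax(Q_.bmm(K_.transpose(1, 2)) / math.sqrt(h), 2)
        O = torch.cat(A.bmm(V_).split(X.size(0), 0), 2)  # concat(O_1, ..., O_h)
        H = self.ln0(X + self.fc_o(O))                   # H = LayerNorm(X + Multihead(X, Y, Y))
        return self.ln1(H + self.rff(H))                 # MAB(X, Y) = LayerNorm(H + rFF(H))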
I was wondering, do you still have the code for Section 5.2 (counting unique characters)? It would be really helpful. Thanks!
Hi juho-lee,
I have many sets, each of which has a different size. I want to take several sets as a mini-batch for the set-transformer model, but I find that every set in a mini-batch must have the same size. Have you ever faced this problem? How did you deal with it? Padding, or other methods?
thank you!
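A common workaround is to pad each mini-batch to its longest set and carry a Boolean mask; a sketch below (note that the repo's MAB does not take a mask argument, so the mask would have to be threaded into the attention, or sets of equal size bucketed into the same batch):

import torch

def pad_sets(sets, pad_value=0.0):
    # sets: a list of (n_i, d) tensors with varying n_i
    max_len = max(s.size(0) for s in sets)
    d = sets[0].size(1)
    batch = torch.full((len(sets), max_len, d), pad_value)
    mask = torch.zeros(len(sets), max_len, dtype=torch.bool)
    for i, s in enumerate(sets):
        batch[i, :s.size(0)] = s
        mask[i, :s.size(0)] = True
    return batch, mask

# Usage with hypothetical sizes:
sets = [torch.randn(5, 3), torch.randn(9, 3), torch.randn(7, 3)]
X, mask = pad_sets(sets)   # X: (3, 9, 3); mask marks the real (non-padding) elements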
Hi, could you provide the code for the Set Anomaly Detection experiment in your paper? Thanks!
Hi @juho-lee,
First of all, thanks for making this code publicly available. It's very useful.
One question, though. I am looking at your implementation of the Zaheer et al. network ("Deep Sets"). In that paper, we have something like rho(sum(phi(x))), where we sum over each element of the set (I believe you call this a set pooling method in your paper).
In your DeepSet class, we have a succession of Linear -> ReLU -> Linear -> ReLU layers that operate on the entire data set and are then pooled at the end.
Could you explain a little about why these are equivalent?
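For context, the equivalence rests on nn.Linear acting only on the last dimension, i.e. independently on every set element; a sketch of the rho(pool(phi(x))) form that mirrors the structure of the repo's DeepSet class (the sizes are hypothetical):

import torch
import torch.nn as nn

phi = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 128), nn.ReLU())  # per-element encoder
rho = nn.Sequential(nn.Linear(128, 128), nn.ReLU(), nn.Linear(128, 4))             # decoder on the pooled vector

X = torch.randn(8, 50, 3)   # (batch, set size, feature dim)
Z = phi(X)                  # phi applied element-wise: (8, 50, 128)
pooled = Z.mean(dim=1)      # permutation-invariant pooling (mean here rather than sum)
out = rho(pooled)           # (8, 4)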
Hello,
Would you please explain the necessity of using dim_split in MAB?
For example, if I have a batch of shape 2x387x768, I see that the A tensor has shape 24x387x387 because it uses Q_ instead of Q.
Would appreciate your response!
Thank you!
Sharmi
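For context, a sketch of what the split/cat trick in MAB does with those sizes (assuming num_heads=12, which would explain 24 = 2 x 12): the heads are folded into the batch dimension so a single bmm computes attention for every head at once:

import torch

batch, n, dim_V, num_heads = 2, 387, 768, 12
dim_split = dim_V // num_heads            # 64 features per head

Q = torch.randn(batch, n, dim_V)
K = torch.randn(batch, n, dim_V)
Q_ = torch.cat(Q.split(dim_split, 2), 0)  # (batch * num_heads, n, dim_split) = (24, 387, 64)
K_ = torch.cat(K.split(dim_split, 2), 0)

A = torch.softmax(Q_.bmm(K_.transpose(1, 2)) / dim_V ** 0.5, 2)
print(A.shape)                            # torch.Size([24, 387, 387]): one attention map per head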
Hi, @yoonholee ,
Thanks a lot for adding the code for the point cloud part. After looking at the network, it seems that SAB modules are not included in the decoder part? Is that because of the increased time complexity of appending SAB modules to enhance the expressiveness of the representations? It seems that classification accuracy would increase by doing so. Have you performed related experiments?
THX!
Hi, @juho-lee ,
Thanks for releasing this package. Where can I find the code for the ModelNet40 shape classification mentioned in the paper?
THX!
Hi
Just want to know if you have plans to extend the functionality of the code itself instead of using PyTorch's MAB block?
Thank you!
Hi!
I would like to ask about the 1/sqrt(self.dim_V) normalization inside the softmax in the MAB. Usually the attention scaling uses the reciprocal of the square root of the key dimensionality, and since here dim_V is split into num_heads equal parts, the size of the key vectors is dim_V // num_heads.
Is this intentional or a "bug"? Although calling it a bug is an exaggeration, since it only introduces an extra 1/sqrt(num_heads) scale.
If this is unintentional I'm happy to make a pull request (although it's only a one-word change), or if it was intentional, could you explain the idea behind it?
Thanks!
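For reference, a tiny numeric illustration of the difference between the two scalings (the sizes are hypothetical):

import math

dim_V, num_heads = 128, 4           # hypothetical sizes
dim_split = dim_V // num_heads      # the actual per-head key dimension

repo_scale = 1 / math.sqrt(dim_V)       # what the MAB code currently uses
usual_scale = 1 / math.sqrt(dim_split)  # the conventional 1/sqrt(d_k) scaling
print(usual_scale / repo_scale)         # sqrt(num_heads) = 2.0: the two differ only by this constant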
Hello,
Just read your paper and was very happy to see that you've made this implementation available. Would you be willing to add a license to this repo (MIT, for instance), so that others can build on this code?
Dear Juho,
First of all, thank you for the implementation! It has been very helpful to my understanding of the architecture.
I ran into an apparent discrepancy between the code and the paper, and I was wondering if you could help clear it up. In particular, it seems to me that the PMA implementation is missing the row-wise feed-forward layer mentioned in the paper:
PMA(S, Z) = MAB(S, rFF(Z))
The PMA code:
class PMA(nn.Module):
    def __init__(self, dim, num_heads, num_seeds, ln=False):
        super(PMA, self).__init__()
        self.S = nn.Parameter(torch.Tensor(1, num_seeds, dim))
        nn.init.xavier_uniform_(self.S)
        self.mab = MAB(dim, dim, dim, num_heads, ln=ln)

    def forward(self, X):
        return self.mab(self.S.repeat(X.size(0), 1, 1), X)
To me this reads PMA(S, X) = MAB(S, X), rather than the MAB(S, rFF(X)) of the paper.
Thanks!
Tim
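For comparison, a minimal sketch of a PMA with the rFF(Z) from the paper inserted before the MAB (not the repo's code; it assumes the repo's MAB class is importable from models.py):

import torch
import torch.nn as nn
from models import MAB  # assuming the repo's models.py

class PMAWithRFF(nn.Module):
    """Sketch only: PMA(S, Z) = MAB(S, rFF(Z)), as written in the paper."""
    def __init__(self, dim, num_heads, num_seeds, ln=False):
        super().__init__()
        self.S = nn.Parameter(torch.Tensor(1, num_seeds, dim))
        nn.init.xavier_uniform_(self.S)
        self.rff = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim))
        self.mab = MAB(dim, dim, dim, num_heads, ln=ln)

    def forward(self, X):
        return self.mab(self.S.repeat(X.size(0), 1, 1), self.rff(X))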