set_transformer's Issues

Inputs of the SetTransformer

Hi,

Could you please explain the meaning of the following inputs to SetTransformer:

dim_input, num_outputs, dim_output, num_inds=32, dim_hidden=128, num_heads=4, ln=False
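
For context, here is a minimal usage sketch with my current guesses (assuming the repo's models.py exposes SetTransformer; the shape comments are just my reading of the code):

import torch
from models import SetTransformer

# e.g. each element is a 3-D point and we want one 40-dimensional vector per set
model = SetTransformer(
    dim_input=3,      # feature dimension of each set element
    num_outputs=1,    # number of seed vectors produced by the PMA pooling block
    dim_output=40,    # dimension of each output vector
    num_inds=32,      # number of inducing points in each ISAB block
    dim_hidden=128,   # hidden dimension used throughout the encoder/decoder
    num_heads=4,      # number of attention heads
    ln=False,         # whether the MAB blocks apply LayerNorm
)

X = torch.randn(8, 100, 3)   # batch of 8 sets, each with 100 elements of dimension 3
out = model(X)               # shape (8, num_outputs, dim_output) = (8, 1, 40)?

Are these interpretations correct?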

Thanks.

A little puzzle about the implementation details.

Hi juho-lee!
I have two small questions about your paper. In Section 1 (Introduction) you write: "A model for set-input problems should satisfy two critical requirements. First, it should be permutation invariant: the output of the model should not change under any permutation of the elements in the input set. Second, such a model should be able to process input sets of any size."
But after reading the whole paper, I still don't see how you address these two requirements.
For the first, I guess you remove the positional embedding from the original Transformer?
As for the second, I have no idea how it is achieved.
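
For what it is worth, a small check like the following (my own sketch, assuming models.py exposes SetTransformer) seems to confirm both properties empirically, but I would like to understand the mechanism:

import torch
from models import SetTransformer

model = SetTransformer(dim_input=3, num_outputs=1, dim_output=8).eval()

X = torch.randn(2, 50, 3)                     # 2 sets of 50 elements each
perm = torch.randperm(50)
out1 = model(X)
out2 = model(X[:, perm, :])                   # permute the elements within each set
print(torch.allclose(out1, out2, atol=1e-5))  # True: output is permutation invariant

Y = torch.randn(2, 200, 3)                    # a different set size, same model
print(model(Y).shape)                         # (2, 1, 8): the model is size-agnostic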
Thank you!

Question about ISAB

Not sure if I understand the Induced Set Attention Block correctly.

So basically SAB is a Transformer block without positional encoding (and without dropout?). In the paper you say that SAB is "too expensive for large sets", but the set size here corresponds to the maximum sequence length in a Transformer, which is typically around 512. Why not just use SAB in the SetTransformer? Is there any reason, other than efficiency, to use ISAB instead?
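
For reference, my current understanding of ISAB, paraphrasing modules.py (so details may differ): m learned inducing points attend to the n inputs, and the inputs then attend back to that induced summary, giving O(n*m) attention cost instead of SAB's O(n^2).

import torch
import torch.nn as nn
from modules import MAB  # assuming the repo's MAB block

class ISABSketch(nn.Module):
    def __init__(self, dim_in, dim_out, num_heads, num_inds, ln=False):
        super().__init__()
        self.I = nn.Parameter(torch.randn(1, num_inds, dim_out))     # m inducing points
        self.mab0 = MAB(dim_out, dim_in, dim_out, num_heads, ln=ln)  # inducing points attend to X
        self.mab1 = MAB(dim_in, dim_out, dim_out, num_heads, ln=ln)  # X attends to the summary H

    def forward(self, X):                                 # X: (batch, n, dim_in)
        H = self.mab0(self.I.repeat(X.size(0), 1, 1), X)  # (batch, m, dim_out)
        return self.mab1(X, H)                            # (batch, n, dim_out)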

Why does LayerNorm default to False?

Not an issue, but a question: why does the LayerNorm option default to False? In particular, LayerNorm is not used in the point cloud example.

Can you comment on the importance of enabling the nested LayerNorm in the model? The paper does not discuss the effect of using LayerNorm versus not using it.
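
For reference, my reading of the flag (which may be wrong): ln=True makes every MAB apply a LayerNorm after the attention residual (ln0) and after the rFF residual (ln1), while ln=False skips both, e.g.:

from models import SetTransformer

model_no_ln   = SetTransformer(dim_input=3, num_outputs=1, dim_output=40, ln=False)
model_with_ln = SetTransformer(dim_input=3, num_outputs=1, dim_output=40, ln=True)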

Thanks!

4-D equivalent?

What if I have a set of matrices instead of a set of vectors? Is it possible to extend the Set Transformer framework to cover that scenario?

I played around with it a little (including some small tweaks) but got stuck on the .bmm call in the MAB module:

RuntimeError: Expected 3-dimensional tensor, but got 4-dimensional tensor for argument #1 'batch1' (while checking arguments for bmm)
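
The workaround I am currently using (my own sketch, not an official extension) is to flatten each matrix into a vector before feeding the set in, since the bmm calls in MAB expect 3-D (batch, n, dim) tensors:

import torch
from models import SetTransformer

B, N, H, W = 4, 30, 8, 16             # a batch of 4 sets, each with 30 matrices of shape 8x16
X = torch.randn(B, N, H, W)

model = SetTransformer(dim_input=H * W, num_outputs=1, dim_output=10)
out = model(X.reshape(B, N, H * W))   # flatten each matrix into a 128-d vector
print(out.shape)                      # (4, 1, 10)

Is there a better way to keep the matrix structure of the elements?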

LayerNorm

Dear Juho,
Thanks for making the code public!
One quick question: if I read the code correctly, LayerNorm is never used in any of the three examples you open-sourced in this repo. Is that correct?
If so, is it because it gives slightly worse performance? And have you tried moving the LayerNorm inside the skip connections, rather than before/after them, as done in several more recent papers, so that there is a direct connection from the output to the input?
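
To make the second question concrete, here is an illustrative pre-norm block in the spirit of those papers (my own sketch using torch's built-in attention, not your MAB):

import torch
import torch.nn as nn
import torch.nn.functional as F

class PreNormMAB(nn.Module):
    # Illustrative only: LayerNorm sits inside the residual branches,
    # so the input reaches the output through an identity path.
    def __init__(self, dim, num_heads=4):
        super().__init__()
        self.ln_q, self.ln_kv, self.ln_ff = nn.LayerNorm(dim), nn.LayerNorm(dim), nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.ff = nn.Linear(dim, dim)

    def forward(self, X, Y):
        H = X + self.attn(self.ln_q(X), self.ln_kv(Y), self.ln_kv(Y))[0]
        return H + F.relu(self.ff(self.ln_ff(H)))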
Thanks in advance and looking forward to your reply!

MAB Implementation diverges from Paper

Dear Juho,

is it possible that the implementation of the MAB diverges from the paper?

In more detail: The paper states

Multihead(Q, K, V; λ, ω) = concat(O_1, ..., O_h) W_O
H = LayerNorm(X + Multihead(X, Y, Y; ω))
MAB(X, Y) = LayerNorm(H + rFF(H))

but the code does

A = torch.softmax(Q_.bmm(K_.transpose(1,2))/math.sqrt(self.dim_V), 2)  # attention weights, heads stacked along the batch dim
O = torch.cat((Q_ + A.bmm(V_)).split(Q.size(0), 0), 2)  # residual on Q_, then concat heads back: this is the output of multihead
O = O if getattr(self, 'ln0', None) is None else self.ln0(O)  # optional LayerNorm ln0
O = O + F.relu(self.fc_o(O))  # rFF with residual
O = O if getattr(self, 'ln1', None) is None else self.ln1(O)  # optional LayerNorm ln1
  • It seems that the matrix W_O is not used in the code at all to mix the outputs of the different heads? (See the sketch after this list.)

  • The skip connection Q_ + A.bmm(V_) also diverges from what is stated in the paper, since Q_ is derived from Q, which is linearly transformed via Q = self.fc_q(Q) in the first line of forward() and is therefore no longer the original query. (On second thought, this may be necessary, since the output of the MAB has a different shape than the input; in that case the paper is simply imprecise.)
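
For concreteness, here is a hedged sketch (my own code, not the repo's) of an MAB variant with an explicit W_O projection, named fc_mix here, that mixes the concatenated heads, following Multihead(Q, K, V) = concat(O_1, ..., O_h) W_O more literally; the rest mirrors the quoted forward(), and the residual is taken on the projected query since the raw query may have a different dimensionality than dim_V:

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class MABWithWO(nn.Module):
    def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False):
        super().__init__()
        self.dim_V, self.num_heads = dim_V, num_heads
        self.fc_q = nn.Linear(dim_Q, dim_V)
        self.fc_k = nn.Linear(dim_K, dim_V)
        self.fc_v = nn.Linear(dim_K, dim_V)
        self.fc_mix = nn.Linear(dim_V, dim_V)  # W_O: mixes the concatenated heads
        self.fc_o = nn.Linear(dim_V, dim_V)    # rFF
        if ln:
            self.ln0, self.ln1 = nn.LayerNorm(dim_V), nn.LayerNorm(dim_V)

    def forward(self, Q, K):
        Q, K, V = self.fc_q(Q), self.fc_k(K), self.fc_v(K)
        split = self.dim_V // self.num_heads
        Q_ = torch.cat(Q.split(split, 2), 0)
        K_ = torch.cat(K.split(split, 2), 0)
        V_ = torch.cat(V.split(split, 2), 0)
        A = torch.softmax(Q_.bmm(K_.transpose(1, 2)) / math.sqrt(self.dim_V), 2)
        multihead = self.fc_mix(torch.cat(A.bmm(V_).split(Q.size(0), 0), 2))  # concat(O_1..O_h) W_O
        H = Q + multihead                                                     # H = X + Multihead(X, Y, Y), with X projected
        H = H if getattr(self, 'ln0', None) is None else self.ln0(H)
        O = H + F.relu(self.fc_o(H))                                          # MAB = LN(H + rFF(H))
        return O if getattr(self, 'ln1', None) is None else self.ln1(O)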

Thanks a lot and best wishes
Jannik

Question about model's input

Hi juho-lee,
I have many sets, each with a different size, and I want to feed several of them to the Set Transformer as a mini-batch. However, every set in a mini-batch must have the same size. Have you ever faced this problem? How did you deal with it: padding or some other method?
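
What I am currently doing (my own sketch, not from the repo) is padding each mini-batch to the size of its largest set. Note that, as far as I can tell, MAB does not take an attention mask, so the padded rows still participate in attention; a correct treatment would also need a key-padding mask inside the softmax.

import torch
from torch.nn.utils.rnn import pad_sequence
from models import SetTransformer

sets = [torch.randn(n, 3) for n in (17, 42, 8)]     # three sets of different sizes
X = pad_sequence(sets, batch_first=True)            # (3, 42, 3), zero-padded
lengths = torch.tensor([s.size(0) for s in sets])   # true sizes, needed if masking is added

model = SetTransformer(dim_input=3, num_outputs=1, dim_output=10)
out = model(X)                                      # (3, 1, 10), but the padding is not masked

Is this reasonable, or is there a better approach?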
Thank you!

Question about Deep Sets Implementation

Hi @juho-lee,

First of all, thanks for making this code publicly available. It's very useful.

One question, though. I am looking at your implementation of the Zaheer et al. network ("Deep Sets"). In that paper, the model has the form rho(sum(phi(x))), where the sum runs over the elements of the set (I believe you call this a set pooling method in your paper).

In your DeepSet class, there is a succession of Linear -> ReLU -> Linear -> ReLU layers that operates on the entire set, with pooling applied at the end.

Could you explain a little about why these are equivalent?
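
My own attempt at an answer, as a sketch (not your exact code): nn.Linear acts on the last dimension, so applying a stack of Linear/ReLU layers to a (batch, n, d) tensor is the same as applying phi independently to every element; pooling afterwards and running more layers on the pooled vector then gives rho(pool_x(phi(x))). Is this the right way to see it?

import torch
import torch.nn as nn

phi = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 128), nn.ReLU())
rho = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 10))

X = torch.randn(4, 50, 3)          # 4 sets of 50 elements
per_element = phi(X)               # (4, 50, 128): phi applied element-wise
pooled = per_element.mean(dim=1)   # (4, 128): permutation-invariant pooling (sum also works)
out = rho(pooled)                  # (4, 10): rho on the pooled representation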

Question on dim_split in MAB

Hello,

Could you please explain why dim_split is needed in MAB?
For example, with a batch of shape 2x387x768, I see that the A tensor has shape 24x387x387 because the code uses Q_ instead of Q.
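
Here is how I currently picture it, as a standalone sketch of the split-and-stack trick (assuming num_heads=12 in my configuration, so 768/12 = 64 and 2*12 = 24):

import math
import torch

batch, n, dim_V, num_heads = 2, 387, 768, 12
Q = torch.randn(batch, n, dim_V)
dim_split = dim_V // num_heads                                       # 64, the per-head feature size

Q_ = torch.cat(Q.split(dim_split, 2), 0)                             # (batch * num_heads, n, dim_split) = (24, 387, 64)
A = torch.softmax(Q_.bmm(Q_.transpose(1, 2)) / math.sqrt(dim_V), 2)  # (24, 387, 387): one map per head per example
print(Q_.shape, A.shape)

(Q_ stands in for both query and key here, just to show the shapes.) Is the point simply that one bmm computes all heads at once?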

Would appreciate your response!

Thank you!
Sharmi

Question about the network architecture for Set Transformer

Hi, @yoonholee ,

Thanks a lot for adding the code for the point cloud experiment. Looking at the network, it seems that no SAB modules are included in the decoder. Is that because of the increased time complexity of appending SAB modules to enhance the expressiveness of the representations? It seems the classification accuracy could be increased by doing so. Have you performed the related experiments?
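
For concreteness, the variant I have in mind looks roughly like this (my own sketch built from modules.py, not your point cloud code), appending a SAB over the pooled seed vectors in the decoder:

import torch.nn as nn
from modules import ISAB, PMA, SAB  # assuming the repo's modules.py

dim_hidden, num_heads, num_inds, num_outputs, num_classes = 128, 4, 32, 1, 40

enc = nn.Sequential(
    ISAB(3, dim_hidden, num_heads, num_inds),
    ISAB(dim_hidden, dim_hidden, num_heads, num_inds),
)
dec = nn.Sequential(
    PMA(dim_hidden, num_heads, num_outputs),
    SAB(dim_hidden, dim_hidden, num_heads),   # extra SAB over the num_outputs pooled vectors
    nn.Linear(dim_hidden, num_classes),
)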

THX!

Question about the normalization in the attention weight calculation

Hi!

I would like to ask about the 1/sqrt(self.dim_V) normalization inside the softmax in MAB. Usually the attention scaling uses the reciprocal of the square root of the key dimensionality, and since dim_V is split into num_heads equal parts here, the size of each key vector is dim_V // num_heads.

Is this intentional, or a "bug"? Calling it a bug is an exaggeration, since it only introduces an extra 1/sqrt(num_heads) factor.

If it is unintentional, I'm happy to make a pull request (it is only a tiny change); if it is intentional, could you explain the idea behind it?
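
To make the difference concrete, a standalone sketch with tensors shaped like those inside MAB:

import math
import torch

dim_V, num_heads = 128, 4
dim_split = dim_V // num_heads                            # per-head key/query size = 32
Q_ = torch.randn(8, 10, dim_split)                        # (batch * num_heads, n, dim_split)
K_ = torch.randn(8, 10, dim_split)

scores = Q_.bmm(K_.transpose(1, 2))
A_repo = torch.softmax(scores / math.sqrt(dim_V), 2)      # scaling as in the current MAB
A_conv = torch.softmax(scores / math.sqrt(dim_split), 2)  # conventional per-head scaling

The two differ only by a constant 1/sqrt(num_heads) factor inside the softmax.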

Thanks!

License

Hello,

Just read your paper and was very happy to see that you've made this implementation available. Would you be willing to add a license to this repo (MIT, for instance), so that others can build on this code?

PMA implementation missing rFF?

Dear Juho,

First of all, thank you for the implementation! It has been very helpful to my understanding of the architecture.

I ran into an alleged discrepancy between code and paper, and I was wondering if you could help clear this up. In particular, it seems to me that the PMA implementation is missing the row-wise feed-forward layer that is mentioned in the paper:

PMA(S, Z) = MAB(S, rFF(Z))

The PMA code:

class PMA(nn.Module):
    def __init__(self, dim, num_heads, num_seeds, ln=False):
        super(PMA, self).__init__()
        self.S = nn.Parameter(torch.Tensor(1, num_seeds, dim))  # learnable seed vectors S
        nn.init.xavier_uniform_(self.S)
        self.mab = MAB(dim, dim, dim, num_heads, ln=ln)

    def forward(self, X):
        return self.mab(self.S.repeat(X.size(0), 1, 1), X)  # attends from the seeds to X directly

To me this reads PMA(S, X) = MAB(S, X), rather than the MAB(S, rFF(X)) of the paper.
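
For reference, the variant I would have expected, as a hedged sketch built on your MAB (my code, not the repo's):

import torch
import torch.nn as nn
from modules import MAB  # assuming the repo's modules.py

class PMAWithRFF(nn.Module):
    def __init__(self, dim, num_heads, num_seeds, ln=False):
        super().__init__()
        self.S = nn.Parameter(torch.Tensor(1, num_seeds, dim))
        nn.init.xavier_uniform_(self.S)
        self.rff = nn.Sequential(nn.Linear(dim, dim), nn.ReLU())  # row-wise feed-forward on Z
        self.mab = MAB(dim, dim, dim, num_heads, ln=ln)

    def forward(self, X):
        return self.mab(self.S.repeat(X.size(0), 1, 1), self.rff(X))  # MAB(S, rFF(X))

Is the rFF omitted intentionally (e.g. because it makes little difference in practice), or is this an oversight?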

Thanks!

Tim
