
davidmrau / mixture-of-experts

848 stars · 3 watchers · 91 forks · 75 KB

PyTorch Re-Implementation of "The Sparsely-Gated Mixture-of-Experts Layer" by Noam Shazeer et al. https://arxiv.org/abs/1701.06538

License: GNU General Public License v3.0

Python 100.00%
moe mixture-of-experts sparsely-gated-mixture-of-experts pytorch re-implementation

mixture-of-experts's People

Contributors

davidmrau · elias-ramzi · inisis · panmianzhi

mixture-of-experts's Issues

Why logsoftmax in the expert's output?

Hi, thanks for the repo!

Could you kindly clarify why you use nn.LogSoftmax() in the last layer of the MLP experts? Thanks!

I'm guessing the reason is that you are using NLLLoss in one of the examples. However, the CIFAR example uses CrossEntropyLoss.
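For context, a minimal sketch (shapes and values are made up) of why LogSoftmax pairs with NLLLoss, while CrossEntropyLoss expects raw logits because it applies LogSoftmax internally:

    import torch
    import torch.nn as nn

    logits = torch.randn(4, 10)            # raw outputs (hypothetical shapes)
    targets = torch.randint(0, 10, (4,))

    # CrossEntropyLoss on raw logits ...
    ce = nn.CrossEntropyLoss()(logits, targets)
    # ... equals NLLLoss applied to LogSoftmax outputs
    nll = nn.NLLLoss()(nn.LogSoftmax(dim=1)(logits), targets)

    print(torch.allclose(ce, nll))  # True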

Why are there prob_if_in/out in the MoE load loss?

Can you explain why there are both "prob_if_in" and "prob_if_out" in "MoE._prob_in_top_k"?
I am a little confused because the original paper doesn't mention that L_load needs two kinds of probability.
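For what it's worth, here is a hypothetical sketch of the load estimator as I read the paper (not the repo's exact code): the smooth probability that expert i is in the top k after its noise is resampled uses a different threshold depending on whether expert i is currently in the top k ("if_in": it must beat the (k+1)-th largest noisy logit to stay in) or currently out ("if_out": it must beat the k-th largest to get in).

    import torch
    from torch.distributions.normal import Normal

    def prob_in_top_k(clean_logits, noisy_logits, noise_std, top_logits, k):
        # clean_logits, noisy_logits, noise_std: [batch, num_experts]
        # top_logits: the k+1 largest noisy logits per example, [batch, k+1]
        normal = Normal(0.0, 1.0)
        # threshold if expert i is currently in the top k: the (k+1)-th largest noisy logit
        threshold_if_in = top_logits[:, k].unsqueeze(1)
        # threshold if expert i is currently out: the k-th largest noisy logit
        threshold_if_out = top_logits[:, k - 1].unsqueeze(1)
        is_in = noisy_logits > threshold_if_in
        prob_if_in = normal.cdf((clean_logits - threshold_if_in) / noise_std)
        prob_if_out = normal.cdf((clean_logits - threshold_if_out) / noise_std)
        return torch.where(is_in, prob_if_in, prob_if_out)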

Why apply exp() and log() to the expert outputs in the combine() function of the SparseDispatcher class?

    def combine(self, expert_out, multiply_by_gates=True):
        """Sum together the expert output, weighted by the gates.
        The slice corresponding to a particular batch element `b` is computed
        as the sum over all experts `i` of the expert output, weighted by the
        corresponding gate values.  If `multiply_by_gates` is set to False, the
        gate values are ignored.
        Args:
          expert_out: a list of `num_experts` `Tensor`s, each with shape
            `[expert_batch_size_i, <extra_output_dims>]`.
          multiply_by_gates: a boolean
        Returns:
          a `Tensor` with shape `[batch_size, <extra_output_dims>]`.
        """
        # apply exp to expert outputs, so we are no longer in log space
        stitched = torch.cat(expert_out, 0).exp()

        if multiply_by_gates:
            stitched = stitched.mul(self._nonzero_gates)
        zeros = torch.zeros(self._gates.size(0), expert_out[-1].size(1), requires_grad=True, device=stitched.device)
        # combine samples that have been processed by the same k experts
        combined = zeros.index_add(0, self._batch_index, stitched.float())
        # add eps to all zero values in order to avoid nans when going back to log space
        combined[combined == 0] = np.finfo(float).eps
        # back to log space
        return combined.log()

I'm confused about the exp() and log() applied in the code above.

If I just want to predict a single data item at inference time, do I still need to apply these functions?

requires_grad=True not required for a variable in the combine() method?

Inside the combine() under SparseDispatcher, there is a line:

zeros = torch.zeros(self._gates.size(0), expert_out[-1].size(1), requires_grad=True, device=stitched.device)

It seems to me that 'requires_grad=True' is not required, as 'zeros' is not a weight or parameter to be learned. Is there a specific reason to set it to True?
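For context, a small, hypothetical sketch suggesting that gradients flow through index_add via the source tensor regardless of whether the zeros tensor requires grad:

    import torch

    stitched = torch.randn(3, 2, requires_grad=True)  # stand-in for the expert outputs
    zeros = torch.zeros(4, 2)                          # no requires_grad here
    batch_index = torch.tensor([0, 2, 2])

    combined = zeros.index_add(0, batch_index, stitched)  # out-of-place index_add
    combined.sum().backward()
    print(stitched.grad)  # gradients reach the expert outputs either way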

Why is weighted sum calculated in the logarithmic space?

In the implementation of combine in the SparseDispatcher class, the code first applies exp(), then computes the weighted sum, and finally goes back to log space. Why do that? I don't think the result is the same as in the original paper.
In the paper, we have
y = sum(G(x) * E(x))
but in your code, I think you compute
y = log(sum(G(x) * exp(E(x))))
These are not the same.
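One possible reading (my interpretation, not confirmed by the author): since the experts end in LogSoftmax, E(x) are log-probabilities, so the exp/sum/log sequence computes the log of the mixture distribution rather than a weighted sum of log-probabilities. A small sketch of the difference, with made-up gate values and probabilities:

    import torch

    g = torch.tensor([0.7, 0.3])                   # gate values for two experts
    log_p = torch.log(torch.tensor([[0.1, 0.9],    # expert 0: log-probabilities
                                    [0.6, 0.4]]))  # expert 1: log-probabilities

    # paper-style combination applied directly to the expert outputs (here log-probs)
    y_paper = (g.unsqueeze(1) * log_p).sum(0)

    # repo-style combination: exp -> weighted sum -> log
    y_code = (g.unsqueeze(1) * log_p.exp()).sum(0).log()

    print(y_paper)  # weighted sum of log-probs, not a valid log-distribution
    print(y_code)   # log(0.7*[0.1, 0.9] + 0.3*[0.6, 0.4]), a proper mixture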

Question about the noisy top-k gating

Hi! Thanks for your implementation of MoE! I am confused about the derivatives of w_gate and w_noise. It seems the computation of the logits:

https://github.com/davidmrau/mixture-of-experts/blob/master/moe.py#L239

is not differentiable because of the top-k operation, so w_gate and w_noise cannot be updated from the NLL loss. I am not sure this is the appropriate way to train the MoE.
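For what it's worth, gradients do flow to the gating weights: topk's index selection is treated as a constant, but the selected values themselves are differentiable. A minimal, hypothetical sketch (not the repo's exact gating code):

    import torch

    torch.manual_seed(0)
    x = torch.randn(4, 8)                            # hypothetical batch
    w_gate = torch.randn(8, 16, requires_grad=True)  # hypothetical gate weights

    logits = x @ w_gate
    top_logits, top_idx = logits.topk(2, dim=1)   # index selection is non-differentiable,
    top_gates = torch.softmax(top_logits, dim=1)  # but the gathered values are
    gates = torch.zeros_like(logits).scatter(1, top_idx, top_gates)

    # any loss that depends on the gate values sends gradients back to w_gate
    loss = (gates * torch.randn_like(gates)).sum()
    loss.backward()
    print(w_gate.grad.abs().sum() > 0)  # tensor(True)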

MoE for transformers

Hi,

I want to use the MoE inside a transformer model. So instead of the current input of shape [batch_size, input_size], I have an input of shape [batch_size, sequence_length, input_size]. Do you know how I can make this work?

Best,
Elias
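One common workaround is to flatten the batch and sequence dimensions before the MoE layer and restore them afterwards. A sketch under the assumption that the layer accepts inputs of shape [num_tokens, input_size]; the constructor arguments and the (output, aux_loss) return value follow my reading of the repo's README example and may need adjusting:

    import torch
    from moe import MoE  # the repo's MoE layer

    batch_size, seq_len, input_size, output_size = 8, 128, 64, 64  # hypothetical sizes
    # constructor arguments assumed to follow the repo's README example
    moe_layer = MoE(input_size=input_size, output_size=output_size,
                    num_experts=10, hidden_size=256, k=4)

    x = torch.randn(batch_size, seq_len, input_size)
    x_flat = x.reshape(-1, input_size)             # [batch*seq, input_size]
    y_flat, aux_loss = moe_layer(x_flat)           # assuming forward returns (output, aux_loss)
    y = y_flat.reshape(batch_size, seq_len, -1)    # restore the sequence dimension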

Issue with gates parameters

Hi,

Thank you for this repo!

I think there is an issue with the gates parameters defined in the MoE as:

self.w_gate = nn.Parameter(torch.zeros(input_size, num_experts), requires_grad=True).to(self.device)
self.w_noise = nn.Parameter(torch.zeros(input_size, num_experts), requires_grad=True).to(self.device)

In my version of torch (1.11.0), applying .to(self.device) to an nn.Parameter returns a plain Tensor, so the weights are not learned during training.
A simple fix is to just remove the .to(self.device), since the gate weights are registered by torch as model parameters and can be moved to the correct device outside of __init__.
Again, I don't know if this is specific to my torch version, but it prevented the MoE from learning for me.

Hope this helps!
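For reference, a minimal sketch of the suggested fix (the class name and sizes are hypothetical): register the weights as parameters inside __init__ without any .to(device), and move the whole module afterwards:

    import torch
    import torch.nn as nn

    class Gating(nn.Module):
        def __init__(self, input_size, num_experts):
            super().__init__()
            # no .to(device) here: converting an nn.Parameter can return a plain
            # tensor that is no longer registered with the module
            self.w_gate = nn.Parameter(torch.zeros(input_size, num_experts))
            self.w_noise = nn.Parameter(torch.zeros(input_size, num_experts))

    gating = Gating(512, 8)                    # hypothetical sizes
    print(list(gating.state_dict().keys()))    # ['w_gate', 'w_noise']
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    gating = gating.to(device)                 # moves the registered parameters too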

Zero Grad of w_gate

I implemented MoE in a Transformer layer, and I discard the load-balance loss l_aux during training. However, I found that the gradient of w_gate is always zero and the parameter does not update. Is this because I don't consider l_aux?

some questions about the code

Hello, your code has provided me with great inspiration for my work. I also have some questions. First, why do we need to compute the self.cv_squared loss for both importance and load? Second, in _prob_in_top_k, can we replace prob_if_out with 1 - prob_if_in?

cv_squared

Hi David,

Thanks for your great code!
I think there is a small mistake in your cv_squared function.
It seems it should be: return x.float().var() / (x.float().mean()**2 + eps)

Thanks
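For reference, a small sketch of the suggested form, i.e. the squared coefficient of variation Var(x) / Mean(x)^2 used as the balancing loss in the paper; the eps keeps the ratio finite when the mean is near zero:

    import torch

    def cv_squared(x, eps=1e-10):
        # squared coefficient of variation: Var(x) / Mean(x)^2
        x = x.float()
        return x.var() / (x.mean() ** 2 + eps)

    print(cv_squared(torch.tensor([1.0, 1.0, 1.0])))  # 0.0  -> perfectly balanced load
    print(cv_squared(torch.tensor([3.0, 0.0, 0.0])))  # ~3.0 -> heavily imbalanced load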

about aux_loss

Hello, I would like to ask what the overall trend of aux_loss should be during training. After I added MoE to my model, the main loss decreased, but the aux_loss kept fluctuating.

A question for changing input size of moe

Hi, thanks for the update!
I noticed that you updated this project recently, and I am wondering: if I want to change the input size to torch.Size([1, seq, 64]), what code should I change in the moe.py file?

Why not GPU?

The README says the code runs on CPU, but I want to know whether the code can run on a GPU, such as an A100.
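In general, a PyTorch module can be moved to a GPU with .to(device). A minimal, hypothetical sketch using a stand-in model (not the repo's MoE) just to show the pattern:

    import torch
    import torch.nn as nn

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = nn.Linear(64, 64)               # stand-in for the MoE-based model
    model = model.to(device)                # move the parameters to the GPU, if available
    x = torch.randn(8, 64, device=device)   # inputs must live on the same device
    y = model(x)
    print(y.device)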

How to use this layer in a sequence setting?

Hi, I am trying to use the MoE class in the decoder portion of a transformer architecture, where I want to replace the feed-forward step with a mixture of experts. The input to the class has shape [batch, input_size]. The sequence length at each step is variable, which leads to a variable input size. How can I use this class in that case?

Wrong Implementation in SparseDispatcher

Code line 57, self._batch_index = sorted_experts[index_sorted_experts[:, 1], 0],
should be changed to self._batch_index = torch.nonzero(gates)[index_sorted_experts[:, 1], 0].

multiply_by_gates after exp

Hi,

Thanks for your PyTorch implementation of sparse MoE! In the combine method of SparseDispatcher, stitched is transformed by exp and then multiplied by the gates. I wonder whether stitched should first be multiplied by the gates and then transformed by exp, which would be consistent with the TF implementation.

def combine(self, expert_out, multiply_by_gates=True):
        # apply exp to expert outputs, so we are no longer in log space
        stitched = torch.cat(expert_out, 0).exp()

        if multiply_by_gates:
            stitched = stitched.mul(self._nonzero_gates)

Log and Exp Space

Hey,

Thank you for the PyTorch implementation. In your combine implementation you transform the expert outputs with exp and afterwards transform them back with log. Can you explain why you do this?

