Comments (14)

looper99 commented on June 26, 2024

Yes, of course, returning both the result and xs[-1].
Thank you once again for the fast replies and for resolving my problem.

from s5-pytorch.

i404788 commented on June 26, 2024

Hey,

apply_ssm is the parallel formulation of S5. If you want to carry state, you should use forward_rnn

def forward_rnn(self, signal, prev_state, step_scale: float | torch.Tensor = 1.0):
with initial_state
def initial_state(self, batch_size: Optional[int] = None):
as prev_state.

This code path hasn't been tested much, though, so there may be bugs (let me know if you find any).

For training you usually want the parallel formulation, which speeds up training and reduces memory use (if using regular autograd); use forward_rnn at inference time for speed and memory.
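To make the relationship between the two code paths concrete, here is a minimal sketch, assuming a single complex state dimension and plain Python numbers instead of tensors. The names scan_states and rnn_states are hypothetical and are not part of s5-pytorch; the scan is written sequentially for readability, whereas a real associative scan evaluates the same prefixes in parallel. The binary operator is the usual composition rule for the linear recurrence x_k = lambda_bar * x_{k-1} + Bu_k.

```python
def binary_operator(elem_i, elem_j):
    """Compose two steps of x -> a * x + b, with elem_i applied first."""
    a_i, b_i = elem_i
    a_j, b_j = elem_j
    return a_j * a_i, a_j * b_i + b_j


def scan_states(lambda_bar, bu_elements):
    """Inclusive prefix scan over (lambda_bar, Bu_k) pairs.

    Written as a sequential loop for clarity; the second component of
    each prefix is the state x_k when starting from a zero state.
    """
    acc = None
    states = []
    for bu in bu_elements:
        elem = (lambda_bar, bu)
        acc = elem if acc is None else binary_operator(acc, elem)
        states.append(acc[1])
    return states


def rnn_states(lambda_bar, bu_elements, prev_state=0j):
    """Step-by-step recurrence x_k = lambda_bar * x_{k-1} + Bu_k."""
    x = prev_state
    states = []
    for bu in bu_elements:
        x = lambda_bar * x + bu
        states.append(x)
    return states


# Illustrative values only.
lambda_bar = 0.9 + 0.1j
bu_elements = [1.0 + 0j, 0.5 - 0.5j, -0.25 + 1j, 2.0 + 0j]

parallel = scan_states(lambda_bar, bu_elements)
recurrent = rnn_states(lambda_bar, bu_elements)
max_diff = max(abs(p - r) for p, r in zip(parallel, recurrent))
```

With a zero initial state the two lists agree up to floating-point error. The recurrent form takes prev_state as an explicit argument, which is why carrying state across calls is straightforward there.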

i404788 commented on June 26, 2024

If you want to mix them, I think you'll need to return the last element of xs from apply_ssm as well:

- return torch.vmap(lambda x: (C_tilde @ x).real)(xs) + Du
+ return torch.vmap(lambda x: (C_tilde @ x).real)(xs) + Du, xs[-1]

looper99 commented on June 26, 2024

I understand, but what I meant to ask is: if I now return xs[-1], how can it be reused in apply_ssm? I tried:
Lambda_bars = Lambda_bars * prev_state (where prev_state is the returned xs[-1])
_, xs = associative_scan(binary_operator, (Lambda_bars, Bu_elements))
but this didn't work.

What I am after is functionality like S4's state-forwarding mode:
https://github.com/HazyResearch/state-spaces/tree/main/models/s4#state-forwarding

i404788 commented on June 26, 2024

Ah, I see. Referencing the paper, it seems you would need to inject it as the initial state of associative_scan; however, the JAX implementation I ported doesn't actually support that.

I think Lambda_bars[0] = Lambda_bars[0] * prev_state, with prev_state = xs[-1], should do the equivalent.
Note that this would be done after Lambda_bars has been tiled (just before the first associative_scan call).

looper99 commented on June 26, 2024

I see, so:

if Lambda_bars.ndim == 1:  # Repeat for associative_scan
    Lambda_bars = Lambda_bars.tile(input_sequence.shape[0], 1)

Lambda_bars[0] = Lambda_bars[0] * prev_state
_, xs = associative_scan(binary_operator, (Lambda_bars, Bu_elements))

should do it?

Then I can have apply_ssm with prev_state:
apply_ssm(Lambda_bars: torch.Tensor, B_bars, C_tilde, D, input_sequence, prev_state, bidir: bool = False)
and carry it on:
LOOP:
x, states = s5(x, states)
END:

i404788 commented on June 26, 2024

Yes, with the other return patch, that should be correct. I'll probably add this functionality to the main repo after it's validated.
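Since the patch still needs validating, one way to check it is to compare the scan output against a step-by-step reference that starts from prev_state. The sketch below assumes a single complex state dimension, plain Python numbers, and a sequentially written scan standing in for associative_scan; scan_states, reference_states and apply_chunk are hypothetical names, not s5-pytorch functions. In this toy setting, with the composition operator (a_j * a_i, a_j * b_i + b_j), the first transition coefficient never enters the state component, so scaling only Lambda_bars[0] leaves xs unchanged. Folding the carried state into the first input term, Bu_elements[0] + Lambda_bars[0] * prev_state, does reproduce the reference. Whether the same holds for the library's binary_operator is worth confirming before relying on either variant.

```python
def binary_operator(elem_i, elem_j):
    """Compose two steps of x -> a * x + b, with elem_i applied first."""
    a_i, b_i = elem_i
    a_j, b_j = elem_j
    return a_j * a_i, a_j * b_i + b_j


def scan_states(lambda_bars, bu_elements):
    """Inclusive prefix scan (sequential stand-in); returns the state per step."""
    acc = None
    xs = []
    for elem in zip(lambda_bars, bu_elements):
        acc = elem if acc is None else binary_operator(acc, elem)
        xs.append(acc[1])
    return xs


def reference_states(lambda_bar, bu_elements, prev_state):
    """Ground truth: x_k = lambda_bar * x_{k-1} + Bu_k, starting at prev_state."""
    x = prev_state
    xs = []
    for bu in bu_elements:
        x = lambda_bar * x + bu
        xs.append(x)
    return xs


def apply_chunk(lambda_bar, bu_chunk, prev_state):
    """Hypothetical state-carrying call: returns (states, last state)."""
    folded = [bu_chunk[0] + lambda_bar * prev_state] + list(bu_chunk[1:])
    xs = scan_states([lambda_bar] * len(bu_chunk), folded)
    return xs, xs[-1]


# Illustrative values only.
lambda_bar = 0.9 + 0.1j
bu = [1.0 + 0j, 0.5 - 0.5j, -0.25 + 1j, 2.0 + 0j]
prev_state = 0.3 - 0.7j
tiled = [lambda_bar] * len(bu)
expected = reference_states(lambda_bar, bu, prev_state)

# Variant A: scale only the first transition coefficient.
xs_a = scan_states([tiled[0] * prev_state] + tiled[1:], bu)
# Variant B: fold the carried state into the first input term.
xs_b, _ = apply_chunk(lambda_bar, bu, prev_state)
# Baseline: no carried state at all.
xs_zero = scan_states(tiled, bu)

err_a = max(abs(x - e) for x, e in zip(xs_a, expected))
err_b = max(abs(x - e) for x, e in zip(xs_b, expected))

# Processing the sequence in two chunks should match one full pass.
state, chunked = 0j, []
for chunk in (bu[:2], bu[2:]):
    xs, state = apply_chunk(lambda_bar, chunk, state)
    chunked.extend(xs)
chunk_err = max(abs(x - e) for x, e in zip(chunked, xs_zero))
```

Here err_b and chunk_err come out at floating-point noise level, while xs_a is identical to xs_zero. The loop at the end mirrors the `x, states = s5(x, states)` pattern from the question.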
