Comments (14)
Yes, of course: returning the result, xs[-1]. Thank you once again for the fast replies and for resolving my problem.
from s5-pytorch.
Hey,
apply_ssm is the parallel formulation of S5; if you want to carry state you should use forward_rnn (line 227 in f0fb132) and initial_state (line 313 in f0fb132). This code path hasn't been tested much though, so there might be bugs (let me know).
For training you usually want to use the parallel formulation, which increases the speed of training and reduces memory (if using regular autograd), and use forward_rnn for inference speed/memory.
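The equivalence of the two formulations can be sketched on toy scalars (plain Python, not the library's API): the parallel path computes the same states as the recurrence by cumulatively applying an S5-style associative operator over (A, b) pairs, which is why the scan and the RNN loop are interchangeable over a full sequence:

```python
from itertools import accumulate

def combine(qi, qj):
    """S5-style associative operator on (A, b) pairs."""
    a_i, b_i = qi
    a_j, b_j = qj
    return a_j * a_i, a_j * b_i + b_j

# Toy scalar stand-ins for the diagonal recurrence x_k = a_k * x_{k-1} + b_k.
a = [0.9, 0.8, 0.7, 0.6]   # per-step transition (like Lambda_bar)
b = [1.0, 2.0, 3.0, 4.0]   # per-step input (like Bu)

# Parallel formulation: cumulative combine (an associative scan).
xs_scan = [s for _, s in accumulate(zip(a, b), combine)]

# Recurrent formulation: the plain loop that forward_rnn would run.
x, xs_loop = 0.0, []
for ak, bk in zip(a, b):
    x = ak * x + bk
    xs_loop.append(x)

assert all(abs(p - q) < 1e-12 for p, q in zip(xs_scan, xs_loop))
```

The same identity is what lets a real implementation dispatch the combine in parallel (e.g. a log-depth tree) instead of the sequential loop.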
from s5-pytorch.
If you want to mix them, I think you'll need to extract the last element of xs from apply_ssm:
- return torch.vmap(lambda x: (C_tilde @ x).real)(xs) + Du
+ return torch.vmap(lambda x: (C_tilde @ x).real)(xs) + Du, xs[-1]
from s5-pytorch.
I understand, but what I meant by this question is: if I now return xs[-1], how can it be reused in apply_ssm? I tried:
Lambda_bars = Lambda_bars * prev_state (where prev_state is the returned xs[-1])
_, xs = associative_scan(binary_operator, (Lambda_bars, Bu_elements))
but this didn't work.
What I am talking about is functionality like S4's state-forwarding mode:
https://github.com/HazyResearch/state-spaces/tree/main/models/s4#state-forwarding
from s5-pytorch.
Ah, I see. Referencing the paper, it seems like you would need to inject it as the initial state of associative_scan; however, the JAX implementation I ported doesn't actually support that. I think
Lambda_bars[0] = Lambda_bars[0] * prev_state
with prev_state = xs[-1] should do the equivalent. Note this would be after Lambda_bars has been tiled (just before the first associative_scan call).
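As a hedged sanity check on toy scalars (plain Python; the names are stand-ins for Lambda_bars / Bu_elements, not the repo's code): with the usual (A, b) combine, the per-step states are accumulated through the b component, so a previous state x_{-1} = prev_state reproduces the recurrence when Lambda_bar * prev_state is folded into the first input element. Whatever injection point ends up in the patch is worth verifying numerically against the plain recurrence like this:

```python
from itertools import accumulate

def combine(qi, qj):
    """S5-style associative operator on (A, b) pairs."""
    a_i, b_i = qi
    a_j, b_j = qj
    return a_j * a_i, a_j * b_i + b_j

lam = [0.9, 0.8, 0.7]        # toy stand-in for tiled Lambda_bars
bu = [1.0, 2.0, 3.0]         # toy stand-in for Bu_elements
prev_state = 5.0             # hypothetical xs[-1] carried from the last chunk

# Reference: run the recurrence starting from x_{-1} = prev_state.
x, ref = prev_state, []
for ak, bk in zip(lam, bu):
    x = ak * x + bk
    ref.append(x)

# Injection: fold Lambda_bar * prev_state into the first input element,
# since x_0 = lam[0] * prev_state + bu[0].
bu_inj = [bu[0] + lam[0] * prev_state] + bu[1:]
xs = [s for _, s in accumulate(zip(lam, bu_inj), combine)]

assert all(abs(p - q) < 1e-12 for p, q in zip(xs, ref))
```

Any candidate patch can be dropped into the scan half of this check and compared against the reference loop on random inputs before being trusted.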
from s5-pytorch.
I see, so:
if Lambda_bars.ndim == 1:  # Repeat for associative_scan
    Lambda_bars = Lambda_bars.tile(input_sequence.shape[0], 1)
Lambda_bars[0] = Lambda_bars[0] * prev_state
_, xs = associative_scan(binary_operator, (Lambda_bars, Bu_elements))
should do it?
Then I can have apply_ssm with prev_state:
apply_ssm(Lambda_bars: torch.Tensor, B_bars, C_tilde, D, input_sequence, prev_state, bidir: bool = False)
and carry it on:
LOOP:
x, states = s5(x, states)
END:
from s5-pytorch.
Yes, with the other return patch, that should be correct. I'll probably add this functionality to the main repo once it's validated.
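The LOOP above can be sketched with a toy stand-in for the S5 block (hypothetical s5_toy, not the repo's API): processing the sequence in chunks while carrying the returned state should match a single full-length pass, which is the property state forwarding is meant to provide:

```python
def s5_toy(u, state=0.0, lam=0.9):
    """Hypothetical stand-in for s5(x, states): returns (outputs, final state)."""
    xs = []
    for uk in u:
        state = lam * state + uk
        xs.append(state)
    return xs, state

u = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]

# One pass over the whole sequence.
full, _ = s5_toy(u)

# LOOP: x, states = s5(x, states) over chunks, carrying the state.
state, chunked = 0.0, []
for chunk in (u[:2], u[2:4], u[4:]):
    xs, state = s5_toy(chunk, state)
    chunked.extend(xs)

assert all(abs(p - q) < 1e-12 for p, q in zip(chunked, full))
```

The chunk boundaries are arbitrary; with correct state forwarding, any partition of the input yields the same outputs as the unpartitioned run.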
from s5-pytorch.