i404788 / s5-pytorch

Pytorch implementation of Simplified Structured State-Spaces for Sequence Modeling (S5)

License: Mozilla Public License 2.0

Python 100.00%
pytorch s5 sequence-modeling state-space

s5-pytorch's Introduction

S5: Simplified State Space Layers for Sequence Modeling

This is a ported version derived from https://github.com/lindermanlab/S5 and https://github.com/kavorite/S5. It includes a bunch of functions ported from jax/lax/flax/whatever since they didn't exist yet.

Earlier versions required jax because they relied on its pytree structure, although jax was not used for any computation. Since version 0.2.0 jax is no longer required: the package uses PyTorch's native torch.utils._pytree instead (a private module, so it may become incompatible with future PyTorch versions). PyTorch 2 or later is required because the code makes heavy use of torch.vmap and torch.utils._pytree as substitutes for their jax counterparts. Python 3.10 or later is required because the code uses the match keyword.

---

Update:

My experiments agree with the results in the Hyena Hierarchy (& H3) paper: state spaces alone lack the recall capabilities required for LLMs, but they seem to work well for regular sequence feature extraction and offer linear complexity.

You can use a variable step size, as described in the paper, by passing a 1D tensor for step_scale. However, this uses a lot of memory because many intermediate values must be held (I believe this is also true of the official S5 repo, though the paper does not mention it unless I missed it).
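To make the memory cost concrete, here is a minimal sketch of zero-order-hold discretization for a diagonal SSM, written from the formulas in the S5 paper rather than taken from this package. The function name discretize_zoh, the shapes, and the assumption that step_scale multiplies a learned base step are all illustrative. With one step size per timestep, the discretized matrices gain a leading sequence dimension, which is where the extra intermediates come from.

```python
import torch

def discretize_zoh(Lambda, B_tilde, delta):
    """Zero-order-hold discretization of a diagonal SSM (illustrative sketch).

    Lambda:  (P,) complex diagonal state matrix
    B_tilde: (P, H) complex input matrix
    delta:   (P,) real step sizes, or (L, P) for one step size per timestep
    """
    Lambda_bar = torch.exp(Lambda * delta)                      # (P,) or (L, P)
    B_bar = ((Lambda_bar - 1.0) / Lambda)[..., None] * B_tilde  # (P, H) or (L, P, H)
    return Lambda_bar, B_bar

P, H, L = 4, 3, 5  # state size, feature size, sequence length
Lambda = torch.complex(-0.5 * torch.ones(P), torch.arange(P, dtype=torch.float32))
B_tilde = torch.randn(P, H, dtype=torch.cfloat)
base_step = torch.full((P,), 0.01)  # stands in for exp(log_step)

# Fixed step: one discretized system shared by every timestep.
fixed = discretize_zoh(Lambda, B_tilde, base_step)

# Variable step (assumed form): a per-timestep scale multiplies the base step,
# so every timestep gets its own Lambda_bar and B_bar.
step_scale = torch.rand(L) + 0.5
varying = discretize_zoh(Lambda, B_tilde, step_scale[:, None] * base_step)

print(fixed[1].shape, varying[1].shape)
```

In this sketch the variable-step B_bar holds L times as many elements as the fixed-step one, and autograd has to keep those tensors for the backward pass.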

Install

pip install s5-pytorch 

Example

import torch
from s5 import S5, S5Block

# Raw S5 operator
x = torch.rand([2, 256, 32])
model = S5(32, 32)
model(x) # [2, 256, 32]

# S5-former block (S5+FFN-GLU w/ layernorm, dropout & residual)
model = S5Block(32, 32, False)
model(x) # [2, 256, 32]

s5-pytorch's Issues

Pendulum task

Dear @i404788,
Do you have an S5 model that works with irregularly sampled data, i.e. one that can be trained on such data, as in the pendulum task?

If yes, could you share this branch?

Thanks.

Should the S5 layer be faster than a RNN?

I ran the following script and found that S5 is far slower than PyTorch's LSTM. Is this expected? Perhaps the scale at which I'm testing is too small to see the benefit?

from datetime import datetime
import os

import torch

from s5 import S5

os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

L = 1200
B = 256
x_dim = 128
m = S5(x_dim, 512).cuda()
# batch_first=True so the LSTM reads x as (batch, length, features), like S5
lstm = torch.nn.LSTM(x_dim, 512, batch_first=True).cuda()
x = torch.randn(B, L, x_dim).cuda()

t0 = datetime.now()
for i in range(10):
    y, _ = lstm(x)
    torch.sum(y).backward()
t1 = datetime.now()
print(t1 - t0)

t2 = datetime.now()
for i in range(10):
    y = m(x)
    torch.sum(y).backward()
t3 = datetime.now()
print(t3 - t2)

I would greatly appreciate any comment on this. Thanks in advance, and thanks for the implementation!
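One caveat when timing GPU code this way: CUDA kernels are launched asynchronously, so wall-clock timestamps taken without torch.cuda.synchronize() can attribute work to the wrong loop, and the first iterations include one-off setup cost. The helper below is a hedged sketch of a fairer measurement. The name time_fwd_bwd and its parameters are hypothetical, and the demo uses a small CPU LSTM only so that it runs anywhere.

```python
import time
import torch

def time_fwd_bwd(model, x, iters=10, warmup=3):
    """Average seconds per forward+backward pass (illustrative helper)."""
    def step():
        out = model(x)
        if isinstance(out, tuple):  # nn.LSTM returns (output, state)
            out = out[0]
        out.sum().backward()

    for _ in range(warmup):         # exclude one-off setup cost
        step()
    if x.is_cuda:
        torch.cuda.synchronize()    # wait for queued kernels before starting the clock
    start = time.perf_counter()
    for _ in range(iters):
        step()
    if x.is_cuda:
        torch.cuda.synchronize()    # wait for the timed kernels to finish
    return (time.perf_counter() - start) / iters

# Small CPU demo; move the model and input to the GPU for a real comparison.
lstm = torch.nn.LSTM(8, 16, batch_first=True)
x = torch.randn(4, 10, 8)
print(f"{time_fwd_bwd(lstm, x, iters=2, warmup=1):.6f} s per iteration")
```

The same helper can be called with an S5 instance and the same input to compare the two layers under identical conditions.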

How to carry state in apply_ssm?

How can I use prev_state in the apply_ssm function, given that it now appears to be purely forward?
Ideally I would like to write:
x, states = s5(x, states), where apply_ssm carries the state so that I can train with memory.
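As an illustration of what carrying state means for a diagonal SSM, here is a minimal sequential sketch of the recurrence x_k = Lambda_bar * x_{k-1} + Bu_k with an explicit initial state. It is not this package's apply_ssm, which uses a parallel scan. The function name scan_with_state and the shapes are assumptions made for the example.

```python
import torch

def scan_with_state(Lambda_bar, Bu, prev_state):
    """Sequential diagonal linear recurrence (illustrative, not the library's scan).

    Lambda_bar: (P,) complex, Bu: (L, P) complex, prev_state: (P,) complex
    Returns all states (L, P) and the final state (P,).
    """
    x = prev_state
    states = []
    for k in range(Bu.shape[0]):
        x = Lambda_bar * x + Bu[k]
        states.append(x)
    states = torch.stack(states)
    return states, states[-1]

P, L = 4, 8
Lambda_bar = torch.polar(torch.full((P,), 0.9), torch.linspace(0.0, 1.0, P))
Bu = torch.randn(L, P, dtype=torch.cfloat)
zero = torch.zeros(P, dtype=torch.cfloat)

# Processing the whole sequence at once...
full, _ = scan_with_state(Lambda_bar, Bu, zero)
# ...matches processing it in two chunks while carrying the state across.
first, carried = scan_with_state(Lambda_bar, Bu[:4], zero)
second, _ = scan_with_state(Lambda_bar, Bu[4:], carried)
print(torch.allclose(torch.cat([first, second]), full))
```

Because the recurrence is linear, the same result can be obtained from a scan that always starts at zero by adding Lambda_bar * prev_state to the first input element, which is one way a parallel scan could accept a carried state.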

RuntimeError with complex parameter type, Adam and Weight Decay

Hi,

I was trying to use the S5Block with an Adam optimizer with weight decay. However, I hit a strange bug in which the sizes of parameters and gradients mismatch. The error only occurs with CUDA tensors/models and only when weight_decay is enabled. Below is a minimal script that reproduces the bug:

from s5 import S5Block
import torch

x = torch.randn(16, 64, 256).cuda()
a = S5Block(256, 128, False).cuda()
a.train()
# h = torch.optim.Adam(a.parameters(), lr=0.001)  # this works
h = torch.optim.Adam(a.parameters(), lr=0.001, weight_decay=0.0001)  # this doesn't work

out = a(x.cuda())
out.sum().backward()
h.step()

After a lot of digging I found the cause: the handling of complex-dtype parameters is faulty in _multi_tensor_adam in PyTorch 2.0.1, the newest version at the time of writing. Specifically, line 442 of torch/optim/adam.py uses the wrong variable when computing the weight decay.

However, this seems to have been fixed by a commit on May 9, so it should work with a newer PyTorch version. In the current release it remains broken.

Just posting this here in case anyone else is having this issue.
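For anyone stuck on an affected PyTorch version, one possible workaround follows from the observation above that Adam works when weight decay is off: put the complex parameters in their own parameter group with weight_decay=0. The sketch below is an assumption-laden illustration, not a tested fix for the CUDA bug. build_adam and Toy are hypothetical names, and the demo runs on CPU. Note that it changes behavior, because complex parameters are no longer decayed.

```python
import torch

def build_adam(model, lr=1e-3, weight_decay=1e-4):
    """Adam with weight decay applied to real-valued parameters only (sketch)."""
    real_params = [p for p in model.parameters() if not p.is_complex()]
    complex_params = [p for p in model.parameters() if p.is_complex()]
    groups = []
    if real_params:
        groups.append({"params": real_params, "weight_decay": weight_decay})
    if complex_params:
        # Complex parameters are excluded from weight decay.
        groups.append({"params": complex_params, "weight_decay": 0.0})
    return torch.optim.Adam(groups, lr=lr)

class Toy(torch.nn.Module):
    """Stand-in module with one real layer and one complex parameter."""
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(4, 4)
        self.scale = torch.nn.Parameter(torch.randn(4, dtype=torch.cfloat))

    def forward(self, x):
        return (self.proj(x) * self.scale).real

model = Toy()
opt = build_adam(model)
model(torch.ones(2, 4)).sum().backward()
opt.step()
```

Passing foreach=False to torch.optim.Adam, which selects the single-tensor code path instead of _multi_tensor_adam, may also avoid the error, but that is untested here.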

torchinfo.summary() in example.py fails

I cloned the repo, and running example.py as is fails with RuntimeError: accessing `data` under vmap transform is not allowed.

Stack trace:
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
File /opt/conda/lib/python3.11/site-packages/torchinfo/torchinfo.py:295, in forward_pass(model, x, batch_dim, cache_forward_pass, device, mode, **kwargs)
    294 if isinstance(x, (list, tuple)):
--> 295     _ = model(*x, **kwargs)
    296 elif isinstance(x, dict):

File /opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py:1568, in Module._call_impl(self, *args, **kwargs)
   1566     args = bw_hook.setup_input_hook(args)
-> 1568 result = forward_call(*args, **kwargs)
   1569 if _global_forward_hooks or self._forward_hooks:

File ~/ML-data/s5_pytorch/s5_model.py:426, in S5Block.forward(self, x)
    425 res = fx.clone()
--> 426 x = F.gelu(self.s5(fx)) + res
    427 x = self.attn_dropout(x)

File /opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py:1568, in Module._call_impl(self, *args, **kwargs)
   1566     args = bw_hook.setup_input_hook(args)
-> 1568 result = forward_call(*args, **kwargs)
   1569 if _global_forward_hooks or self._forward_hooks:

File ~/ML-data/s5_pytorch/s5_model.py:377, in S5.forward(self, signal, step_scale)
    375     step_scale = torch.ones(signal.shape[0], device=signal.device) * step_scale
--> 377 return torch.vmap(lambda s, ss: self.seq(s, step_scale=ss))(signal, step_scale)

File /opt/conda/lib/python3.11/site-packages/torch/_functorch/apis.py:188, in vmap.<locals>.wrapped(*args, **kwargs)
    187 def wrapped(*args, **kwargs):
--> 188     return vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs)

File /opt/conda/lib/python3.11/site-packages/torch/_functorch/vmap.py:266, in vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs)
    265 # If chunk_size is not specified.
--> 266 return _flat_vmap(
    267     func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs
    268 )

File /opt/conda/lib/python3.11/site-packages/torch/_functorch/vmap.py:38, in doesnt_support_saved_tensors_hooks.<locals>.fn(*args, **kwargs)
     37 with torch.autograd.graph.disable_saved_tensors_hooks(message):
---> 38     return f(*args, **kwargs)

File /opt/conda/lib/python3.11/site-packages/torch/_functorch/vmap.py:379, in _flat_vmap(func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs)
    378 batched_inputs = _create_batched_inputs(flat_in_dims, flat_args, vmap_level, args_spec)
--> 379 batched_outputs = func(*batched_inputs, **kwargs)
    380 return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func)

File ~/ML-data/s5_pytorch/s5_model.py:377, in S5.forward.<locals>.<lambda>(s, ss)
    375     step_scale = torch.ones(signal.shape[0], device=signal.device) * step_scale
--> 377 return torch.vmap(lambda s, ss: self.seq(s, step_scale=ss))(signal, step_scale)

File /opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py:1581, in Module._call_impl(self, *args, **kwargs)
   1580 else:
-> 1581     hook_result = hook(self, args, result)
   1583 if hook_result is not None:

File /opt/conda/lib/python3.11/site-packages/torchinfo/torchinfo.py:597, in construct_hook.<locals>.hook(module, inputs, outputs)
    596     info.calculate_num_params()
--> 597 info.input_size, _ = info.calculate_size(inputs, batch_dim)
    598 info.output_size, elem_bytes = info.calculate_size(outputs, batch_dim)

File /opt/conda/lib/python3.11/site-packages/torchinfo/layer_info.py:104, in LayerInfo.calculate_size(inputs, batch_dim)
    100 # pack_padded_seq and pad_packed_seq store feature into data attribute
    101 elif (
    102     isinstance(inputs, (list, tuple))
    103     and inputs
--> 104     and hasattr(inputs[0], "data")
    105     and hasattr(inputs[0].data, "size")
    106 ):
    107     size = list(inputs[0].data.size())

RuntimeError: accessing `data` under vmap transform is not allowed

The above exception was the direct cause of the following exception:

RuntimeError                              Traceback (most recent call last)
Cell In[7], line 11
      8 # model = S5(32, 32)
      9 model = S5Block(dim, 512, block_count=8, bidir=False)
---> 11 print(torchinfo.summary(model, (2, 8192, dim), device='cpu', depth=5))
     13 for i in range(5):
     14     y = model(x)

File /opt/conda/lib/python3.11/site-packages/torchinfo/torchinfo.py:223, in summary(model, input_size, input_data, batch_dim, cache_forward_pass, col_names, col_width, depth, device, dtypes, mode, row_settings, verbose, **kwargs)
    216 validate_user_params(
    217     input_data, input_size, columns, col_width, device, dtypes, verbose
    218 )
    220 x, correct_input_size = process_input(
    221     input_data, input_size, batch_dim, device, dtypes
    222 )
--> 223 summary_list = forward_pass(
    224     model, x, batch_dim, cache_forward_pass, device, model_mode, **kwargs
    225 )
    226 formatting = FormattingOptions(depth, verbose, columns, col_width, rows)
    227 results = ModelStatistics(
    228     summary_list, correct_input_size, get_total_memory_used(x), formatting
    229 )

File /opt/conda/lib/python3.11/site-packages/torchinfo/torchinfo.py:304, in forward_pass(model, x, batch_dim, cache_forward_pass, device, mode, **kwargs)
    302 except Exception as e:
    303     executed_layers = [layer for layer in summary_list if layer.executed]
--> 304     raise RuntimeError(
    305         "Failed to run torchinfo. See above stack traces for more details. "
    306         f"Executed layers up to: {executed_layers}"
    307     ) from e
    308 finally:
    309     if hooks:

RuntimeError: Failed to run torchinfo. See above stack traces for more details. Executed layers up to: [LayerNorm: 1]

Other toy models I tested with torchinfo.summary() work without issue. Any idea?
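Reading the trace, the failure happens inside torchinfo's forward hook: it calls hasattr(inputs[0], "data") on a submodule that S5.forward invokes through torch.vmap, and accessing `data` on a tensor under a vmap transform is not allowed. If only the parameter count is needed, a manual count avoids hooks altogether. The helper below is a minimal sketch, and count_parameters is a hypothetical name rather than part of torchinfo or this package.

```python
import torch

def count_parameters(model, trainable_only=True):
    """Total number of parameter elements, without running a forward pass."""
    return sum(
        p.numel()
        for p in model.parameters()
        if p.requires_grad or not trainable_only
    )

layer = torch.nn.Linear(3, 2)   # 3*2 weights + 2 biases
print(count_parameters(layer))  # prints 8
```

Each complex parameter element counts as one here, although it stores two floats.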

A bug seems to be caused by the compatibility of JAX and PyTorch

I created two instances of the S5 model and tried to print the gradients of both:

from s5 import S5
import torch.nn as nn
import torch

input_dim, output_dim, state_width, sequence_length = 1, 1, 1, 1
x1 = torch.randn(1, 1, 1)
y1 = torch.randn(1, 1, 1)
x2 = torch.randn(1, 1, 1)
y2 = torch.randn(1, 1, 1)
model1 = S5(width=input_dim, state_width=state_width)
model2 = S5(width=input_dim, state_width=state_width)

criterion = nn.MSELoss()

loss1 = criterion(model1(x1), y1)
loss2 = criterion(model2(x2), y2)
loss1.backward()
loss2.backward()
for name, param in model1.named_parameters():
    print(name, param.data, param.grad)

for name, param in model2.named_parameters():
    print(name, param.data, param.grad)

The output is as follows:

seq.Lambda tensor([-0.5000+0.j]) tensor([-3.8413e-08-6.8010e-07j])
seq.B tensor([[[0.1249, 0.0000]]]) tensor([[[-0.0001, -0.0022]]])
seq.C tensor([[0.0452-0.8008j]]) tensor([[-0.0003+0.j]])
seq.D tensor([0.9082]) tensor([-0.5527])
seq.log_step tensor([-5.3020]) tensor([-1.5514e-05])

seq.Lambda tensor([-0.5000+0.j]) None
seq.B tensor([[[-0.4863,  0.0000]]]) None
seq.C tensor([[0.2791+0.8068j]]) tensor([[-0.0208+0.j]])
seq.D tensor([0.5738]) tensor([4.2298])
seq.log_step tensor([-4.5913]) None

I don't understand why some of model2's gradients are missing.
