
hazyresearch / state-spaces


Structured state space sequence models

License: Apache License 2.0

Python 4.13% C++ 0.01% Cuda 0.05% C 0.01% Shell 0.01% Jupyter Notebook 95.79% Makefile 0.01%
state-space-models sequence-models pytorch

state-spaces's Introduction

Structured State Spaces for Sequence Modeling

This repository provides the official implementations and experiments for models related to S4, including HiPPO, LSSL, SaShiMi, DSS, HTTYH, S4D, and S4ND.

Project-specific information for each of these models, including overview of the source code and specific experiment reproductions, can be found under models/.

Changelog

See CHANGELOG.md

Roadmap

  • More documentation for training from scratch using this repository
  • Compilation of S4 resources and implementations
  • pip package

Setup

Requirements

This repository requires Python 3.9+ and PyTorch 1.10+. It has been tested up to PyTorch 1.13.1. Other packages are listed in requirements.txt. Some care may be needed to make the library versions compatible, particularly torch/torchvision/torchaudio/torchtext.

Example installation:

conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.6 -c pytorch -c nvidia
pip install -r requirements.txt

Structured Kernels

The core operations of S4 are the Cauchy and Vandermonde kernels described in the paper. These are very simple matrix multiplications; naive implementations of these operations can be found in the standalone S4 file, in the functions cauchy_naive and log_vandermonde_naive. However, as the paper describes, this has suboptimal memory usage that currently requires a custom kernel to overcome in PyTorch.
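
For reference, the two naive operations can be sketched in NumPy as below. This is an illustration, not the repository's code: the argument names `v`, `z`, `w`, `x` and the `conj` flag are assumptions modeled on the function names above, and the repository's own versions may differ in signature and batching. Both functions materialize a full `(N, L)` matrix before reducing over `N`, which is the memory cost the custom kernels are meant to avoid.

```python
import numpy as np

def cauchy_naive(v, z, w):
    """Naive Cauchy kernel: out[l] = sum_n v[n] / (z[l] - w[n]).

    v, w: complex arrays of shape (N,); z: complex array of shape (L,).
    The full (N, L) Cauchy matrix is built before summing over N.
    """
    cauchy_matrix = v[:, None] / (z[None, :] - w[:, None])  # (N, L)
    return cauchy_matrix.sum(axis=0)                          # (L,)

def log_vandermonde_naive(v, x, L, conj=True):
    """Naive Vandermonde kernel with log-parameterized nodes:
    out[l] = sum_n v[n] * exp(x[n] * l) for l = 0..L-1.

    If conj is True, the inputs are assumed to hold one member of each
    complex-conjugate pair, so the result is 2 * Re(...).
    """
    vandermonde_matrix = np.exp(x[:, None] * np.arange(L)[None, :])  # (N, L)
    out = v @ vandermonde_matrix                                      # (L,)
    return 2 * out.real if conj else out
```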

Two more efficient methods are supported. The code will automatically detect if either of these is installed and call the appropriate kernel.

Custom CUDA Kernel

This version is faster but requires manual compilation for each machine environment. Run python setup.py install from the directory extensions/kernels/.

Pykeops

This version is provided by the pykeops library. Installation usually works out of the box with pip install pykeops cmake; both packages are also listed in the requirements file.
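
The automatic selection can be pictured as a simple import probe. The sketch below is hypothetical (the repository's actual detection code may differ): it assumes the compiled extension is importable as `cauchy_mult` (the module name that appears in the extension's build output quoted in one of the issues on this page) and that PyKeOps is importable as `pykeops`, and it falls back to the naive implementation otherwise.

```python
import importlib.util

def pick_kernel_backend(candidates=(("cuda", "cauchy_mult"), ("keops", "pykeops"))):
    """Return the first backend whose module can be found, else "naive".

    Hypothetical helper: the module names are assumptions and are checked in
    priority order (custom CUDA kernel first, since it is the faster option).
    """
    for backend, module_name in candidates:
        if importlib.util.find_spec(module_name) is not None:
            return backend
    return "naive"

print(pick_kernel_backend())
```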

Getting Started with S4

S4 Module

Self-contained files for the S4 layer and variants can be found in models/s4/, which includes instructions for calling the module.

See notebooks/ for visualizations explaining some concepts behind HiPPO and S4.

Example Train Script (External Usage)

example.py is a self-contained training script for MNIST and CIFAR that imports the standalone S4 file. With the default settings, python example.py reaches 88% accuracy on sequential CIFAR with a very simple S4D model of 200k parameters. This script can be used as an example for using S4 variants in external repositories.

Training with this Repository (Internal Usage)

This repository aims to provide a very flexible framework for training sequence models. Many models and datasets are supported.

The basic entrypoint is python -m train, or equivalently

python -m train pipeline=mnist model=s4

which trains an S4 model on the Permuted MNIST dataset. This should reach around 90% accuracy after 1 epoch, which takes 1-3 minutes depending on the GPU.

More examples of using this repository are documented throughout. See Training for an overview.

Optimizer Hyperparameters

One important feature of this codebase is support for parameters that require different optimizer hyperparameters. In particular, the SSM kernel is especially sensitive to the $(A, B)$ parameters (and sometimes $\Delta$), so the learning rate on these parameters is sometimes lowered and their weight decay is always set to $0$.

See the method register in the model (e.g. s4d.py) and the function setup_optimizer in the training script (e.g. example.py) for an example of how to implement this in external repos.
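
The pattern can be sketched as follows. This is a minimal, self-contained PyTorch sketch in the spirit of the register / setup_optimizer approach described above, not a copy of the repository's code: the `_optim` attribute name, the toy layer, and the default learning rate and weight decay values are assumptions. Each tagged parameter carries its own hyperparameter dict, and the optimizer gets one extra param group per distinct dict.

```python
import torch
import torch.nn as nn

class TinySSMLayer(nn.Module):
    """Hypothetical toy layer with two 'SSM' parameters that need special hyperparameters."""

    def __init__(self, d_model=8, d_state=4, ssm_lr=1e-3):
        super().__init__()
        self.register("A", torch.randn(d_model, d_state), lr=ssm_lr)
        self.register("B", torch.randn(d_model, d_state), lr=ssm_lr)
        self.out = nn.Linear(d_model, d_model)  # ordinary parameters

    def register(self, name, tensor, lr=None):
        """Register a trainable parameter and tag it with its own optimizer settings."""
        self.register_parameter(name, nn.Parameter(tensor))
        optim = {"weight_decay": 0.0}  # never apply weight decay to SSM parameters
        if lr is not None:
            optim["lr"] = lr           # optionally lower the learning rate
        setattr(getattr(self, name), "_optim", optim)

def setup_optimizer(model, lr=1e-2, weight_decay=0.01):
    """Build AdamW with one default group plus one group per distinct `_optim` dict."""
    all_params = list(model.parameters())
    default = [p for p in all_params if not hasattr(p, "_optim")]
    optimizer = torch.optim.AdamW(default, lr=lr, weight_decay=weight_decay)

    special_hps = []  # distinct hyperparameter dicts, in first-seen order
    for p in all_params:
        hp = getattr(p, "_optim", None)
        if hp is not None and hp not in special_hps:
            special_hps.append(hp)
    for hp in special_hps:
        group = [p for p in all_params if getattr(p, "_optim", None) == hp]
        optimizer.add_param_group({"params": group, **hp})
    return optimizer

opt = setup_optimizer(TinySSMLayer())
for g in opt.param_groups:
    print(len(g["params"]), g["lr"], g["weight_decay"])
```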

Training

The core training infrastructure of this repository is based on PyTorch Lightning with a configuration scheme based on Hydra.

The main entrypoint is train.py and configs are found in configs/.

Data

Basic datasets are auto-downloaded, including MNIST, CIFAR, and Speech Commands. All logic for creating and loading datasets is in the src/dataloaders directory. The README inside this subdirectory documents how to download and organize other datasets.

Models

Models are defined in src/models. See the README in this subdirectory for an overview.

Configs and Hyperparameters

Pre-defined configs reproducing end-to-end experiments from the papers are provided; they are listed with the project-specific information in models/, for example for the original S4 paper.

Configs can also be easily modified through the command line. An example experiment is

python -m train pipeline=mnist dataset.permute=True model=s4 model.n_layers=3 model.d_model=128 model.norm=batch model.prenorm=True wandb=null

This uses the Permuted MNIST task with an S4 model with a specified number of layers, backbone dimension, and normalization type.
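
To make the dotted-override syntax concrete, here is a toy sketch of what such overrides do to a nested config. It is not Hydra's parser and ignores config-group selection (in Hydra, `pipeline=mnist` or `model=s4` selects a whole config file); the `base` defaults below are made up for illustration.

```python
import copy

def apply_overrides(config, overrides):
    """Toy illustration of dotted `key.sub=value` overrides on a nested dict.

    Only converts ints, floats, True/False and null; everything else stays a string.
    """
    cfg = copy.deepcopy(config)  # leave the caller's config untouched
    for item in overrides:
        key, raw = item.split("=", 1)
        if raw in ("True", "true", "False", "false"):
            value = raw.lower() == "true"
        elif raw == "null":
            value = None
        else:
            try:
                value = int(raw)
            except ValueError:
                try:
                    value = float(raw)
                except ValueError:
                    value = raw
        *parents, leaf = key.split(".")
        node = cfg
        for part in parents:  # walk (or create) the nested sections
            node = node.setdefault(part, {})
        node[leaf] = value
    return cfg

base = {"dataset": {"permute": False},
        "model": {"n_layers": 4, "d_model": 256, "norm": "layer", "prenorm": False}}
cfg = apply_overrides(base, ["dataset.permute=True", "model.n_layers=3",
                             "model.d_model=128", "model.norm=batch", "model.prenorm=True"])
print(cfg)
```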

See configs/README.md for more detailed documentation about the configs.

Hydra

It is recommended to read the Hydra documentation to fully understand the configuration framework. For help launching specific experiments, please file an issue.

Resuming

Each experiment is logged to its own directory (generated by Hydra) of the form ./outputs/<date>/<time>/. Checkpoints are saved inside this folder, and their paths are printed to the console whenever a new checkpoint is created. To resume training, point to the desired .ckpt file (a PyTorch Lightning checkpoint, e.g. ./outputs/<date>/<time>/checkpoints/val/loss.ckpt) by appending the flag train.ckpt=<path>/<to>/<checkpoint>.ckpt to the original training command.
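
A small helper can locate a checkpoint and build the resume command. This is a hypothetical sketch that assumes the ./outputs/<date>/<time>/checkpoints/ layout described above; it picks the newest .ckpt across all runs by modification time, so check that it belongs to the run you actually want to resume.

```python
from pathlib import Path

def latest_checkpoint(outputs_dir="outputs"):
    """Return the most recently modified .ckpt under outputs/<date>/<time>/checkpoints/, or None."""
    ckpts = list(Path(outputs_dir).glob("*/*/checkpoints/**/*.ckpt"))
    return max(ckpts, key=lambda p: p.stat().st_mtime) if ckpts else None

def resume_command(original_command, outputs_dir="outputs"):
    """Append train.ckpt=<path> to the original training command."""
    ckpt = latest_checkpoint(outputs_dir)
    if ckpt is None:
        raise FileNotFoundError(f"no .ckpt files under {outputs_dir}")
    return f"{original_command} train.ckpt={ckpt}"
```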

PyTorch Lightning Trainer

The PTL Trainer class controls the overall training loop and also provides many useful pre-defined flags. Some useful examples are explained below. The full list of allowable flags can be found in the PTL documentation, as well as our trainer configs. See the default trainer config configs/trainer/default.yaml for the most useful options.

Multi-GPU training

Simply pass in trainer.gpus=2 to train with 2 GPUs.

Inspect model layers

trainer.weights_summary=full prints out every layer of the model with their parameter counts. Useful for debugging internals of models.

Data subsampling

trainer.limit_{train,val}_batches={10,0.1} trains (validates) on only 10 batches (0.1 fraction of all batches). Useful for testing the train loop without going through all the data.
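
The int-versus-float distinction matters: `10` means ten batches, `0.1` means a tenth of them. The sketch below spells out that arithmetic, assuming PyTorch Lightning's convention that float limits are fractions rounded down; it is an illustration, not Lightning's own code.

```python
def num_batches_used(total_batches, limit):
    """How many batches a limit_{train,val}_batches value selects (sketch).

    An int is an absolute count (capped at the total); a float in [0, 1]
    is a fraction of the total, rounded down.
    """
    if isinstance(limit, bool):
        raise TypeError("limit must be an int or a float")
    if isinstance(limit, int):
        return min(limit, total_batches)
    if isinstance(limit, float):
        if not 0.0 <= limit <= 1.0:
            raise ValueError("a float limit must be in [0, 1]")
        return int(total_batches * limit)
    raise TypeError("limit must be an int or a float")

print(num_batches_used(1200, 10))   # 10
print(num_batches_used(1200, 0.1))  # 120
print(num_batches_used(1200, 1.0))  # 1200
```

Note that `1` (one batch) and `1.0` (all batches) select very different amounts of data.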

WandB

Logging with WandB is built into this repository. In order to use this, simply set your WANDB_API_KEY environment variable, and change the wandb.project attribute of configs/config.yaml (or pass it on the command line e.g. python -m train .... wandb.project=s4).

Set wandb=null to turn off WandB logging.

Generation

Autoregressive generation can be performed with the generate.py script. This script can be used in two ways after training a model using this codebase.

Option 1: Checkpoint Path

The more flexible option requires the checkpoint path of the trained PyTorch Lightning model. The generation script accepts the same config options as the train script, with a few additional flags that are documented in configs/generate.yaml. After training with python -m train <train flags>, generate with

python -m generate <train flags> checkpoint_path=<path/to/model.ckpt> <generation flags>

Any of the flags found in the config can be overridden.

Note: This option can be used with either .ckpt checkpoints (PyTorch Lightning, which includes information for the Trainer) or .pt checkpoints (PyTorch, which is just a model state dict).
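
The two checkpoint flavors can be told apart when loading. This sketch assumes a PyTorch Lightning .ckpt stores the model weights under a "state_dict" key next to trainer state, while a .pt file is a bare state dict; the helper name is hypothetical. Lightning state-dict keys may also carry a submodule prefix that has to be stripped before loading into a standalone model.

```python
import torch

def load_model_state_dict(path):
    """Return a plain model state dict from either checkpoint flavor (sketch).

    On recent PyTorch versions, a full Lightning checkpoint containing
    non-tensor objects may need torch.load(..., weights_only=False),
    which should only be used for trusted files.
    """
    ckpt = torch.load(path, map_location="cpu")
    if isinstance(ckpt, dict) and "state_dict" in ckpt:
        return ckpt["state_dict"]  # Lightning-style .ckpt
    return ckpt                    # bare .pt state dict
```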

Option 2: Experiment Path

The second option for generation does not require passing in training flags again, and instead reads the config from the Hydra experiment folder, along with a PyTorch Lightning checkpoint within the experiment folder.

Example 1 (Language)

Download the WikiText-103 model checkpoint, for example to ./checkpoints/s4-wt103.pt. This model was trained with the command python -m train experiment=lm/s4-wt103. Note that from the config we can see that the model was trained with a receptive field of length 8192.

To generate, run

python -m generate experiment=lm/s4-wt103 checkpoint_path=checkpoints/s4-wt103.pt n_samples=1 l_sample=16384 l_prefix=8192 decode=text

This generates a sample of length 16384 conditioned on a prefix of length 8192.
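
To make the prefix-then-sample structure concrete, here is a generic sketch of autoregressive generation with a recurrent step function. It is not the logic of generate.py: the `step(state, token)` interface and the toy model are hypothetical, the number of new tokens is an explicit `n_new` argument, and it decodes greedily so the output is deterministic.

```python
def generate(step, state, prefix, n_new):
    """Autoregressive generation with a recurrent step function (sketch).

    `step(state, token) -> (new_state, logits)` is a hypothetical interface:
    the prefix is fed one token at a time to build up the state, then each
    new token is chosen from the latest logits and fed back in.
    """
    if not prefix:
        raise ValueError("need at least one prefix token")
    tokens = list(prefix)
    for tok in prefix:              # warm up the recurrent state on the prefix
        state, logits = step(state, tok)
    for _ in range(n_new):          # then extend the sequence one token at a time
        tok = max(range(len(logits)), key=logits.__getitem__)  # greedy choice
        tokens.append(tok)
        state, logits = step(state, tok)
    return tokens

# Hypothetical toy "model": the state is the last token seen, and the logits
# always favor the next token id modulo a vocabulary of size 5.
def toy_step(state, token, vocab=5):
    logits = [1.0 if i == (token + 1) % vocab else 0.0 for i in range(vocab)]
    return token, logits

print(generate(toy_step, state=None, prefix=[0, 1, 2], n_new=4))  # [0, 1, 2, 3, 4, 0, 1]
```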

Example 2 (Audio)

Let's train a small SaShiMi model on the SC09 dataset. We can also reduce the number of training and validation batches to get a checkpoint faster:

python -m train experiment=audio/sashimi-sc09 model.n_layers=2 trainer.limit_train_batches=0.1 trainer.limit_val_batches=0.1

After the first epoch completes, a message is printed indicating where the checkpoint is saved.

Epoch 0, global step 96: val/loss reached 3.71754 (best 3.71754), saving model to "<repository>/outputs/<date>/<time>/checkpoints/val/loss.ckpt"

Option 1:

python -m generate experiment=audio/sashimi-sc09 model.n_layers=2 checkpoint_path=<repository>/outputs/<date>/<time>/checkpoints/val/loss.ckpt n_samples=4 l_sample=16000

This option redefines the full config so that the model and dataset can be constructed.

Option 2:

python -m generate experiment_path=<repository>/outputs/<date>/<time> checkpoint_path=checkpoints/val/loss.ckpt n_samples=4 l_sample=16000

This option only needs the path to the Hydra experiment folder and the desired checkpoint within.

Overall Repository Structure

configs/         Config files for model, data pipeline, training loop, etc.
data/            Default location of raw data
extensions/      CUDA extensions (Cauchy and Vandermonde kernels)
src/             Main source code for models, datasets, etc.
  callbacks/     Training loop utilities (e.g. checkpointing)
  dataloaders/   Dataset and dataloader definitions
  models/        Model definitions
  tasks/         Encoder/decoder modules to interface between data and model backbone
  utils/
models/          Model-specific information (code, experiments, additional resources)
example.py       Example training script for using S4 externally
train.py         Training entrypoint for this repo
generate.py      Autoregressive generation script

Citation

If you use this codebase, or otherwise found our work valuable, please cite S4 and other relevant papers.

@inproceedings{gu2022efficiently,
  title={Efficiently Modeling Long Sequences with Structured State Spaces},
  author={Gu, Albert and Goel, Karan and R\'e, Christopher},
  booktitle={The International Conference on Learning Representations ({ICLR})},
  year={2022}
}

state-spaces's People

Contributors

ad12, albertfgu, hongyuhe, jchia, krandiash, rogerni, telmop, trellixvulnteam

state-spaces's Issues

S4D Memory Requirements

Hey, I wanted to give S4D a quick try in my research as a drop-in replacement for S4 (which, as far as I gathered, should be a good way to start), but I'm running into some hard memory limitations. I'm trying to train the DiffWave version of SaShiMi as a first experiment, but the memory requirements seem to increase significantly when replacing S4 with an equivalent S4D layer (with default settings), causing the model to go OOM in my case (so I don't have any precise measurements, but it's at least a 20% increase in overall memory consumption). I use the parameters as discussed in #46. Is this something you'd expect?

Dropout2d and residual

Dear authors and contributors,

There is an observation that I would be happy to get your confirmation on :-)
In all of the model hierarchy (SequenceModel, SequenceResidualBlock and S4), you are using Dropout2d, which zeros at the batch dimension, i.e. ignores the entire sample. Without a residual link, with multiple layers, the probability that each sample is not ignored through the model becomes negligible. Consequently, the model does not see the inputs and will not train!
In the SequenceResidualBlock, the dropout is applied only if a residual link is present. The residual link of SequenceResidualBlock also takes care of the dropout from S4.
So my issue is two-fold:

  • When using dropout > 0, we never should set residual = None in the parameters of SequenceResidualBlock, right? Is it possible to add a check in the initialization to avoid possible misconfigurations?
  • The dropinp input of SequenceModel should not be used, as there is no residual link there. I've seen in all of the configs we have dropinp: 0.0. So why is it there at all?

Thanks and regards,

Error when running example.py

Dear all, many thanks for this extremely interesting code (and maths)!

Running
example.py --grayscale
with python 3.8 gives the following error on my system:
"RuntimeError: view_as_real doesn't work on unresolved conjugated tensors. To resolve the conjugate tensor so you can view it as real, use self.resolve_conj(); however, be warned that the resulting tensor will NOT alias the original."

Maybe some requirements need specific versions for the example to run?
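
This error comes from PyTorch's lazy conjugation (added in PyTorch 1.10): `Tensor.conj()` on a complex tensor returns a view with a conjugate bit set instead of a copy, and `torch.view_as_real` refuses such views. The minimal reproduction below is not taken from the repository; it only demonstrates the workaround the error message itself suggests, `resolve_conj()`, which materializes the conjugate first.

```python
import torch

z = torch.randn(3, dtype=torch.cfloat)
zc = z.conj()  # lazy conjugate: sets a flag, the data is not copied

try:
    torch.view_as_real(zc)  # raises the RuntimeError quoted above
except RuntimeError as err:
    print("error:", err)

real_view = torch.view_as_real(zc.resolve_conj())  # materialize, then view as (real, imag)
print(real_view.shape)  # torch.Size([3, 2])
```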

PyTorch-only Cauchy kernel for easier testing

Hello,

I loved Structured State Spaces, and obtained fantastic performance compared to LSTM/SRU/Transformers.

I want to introduce S4 to some researchers and students, and the self-contained S4 layer is super great! However, it requires the "Cauchy kernel".

There are two versions of the Cauchy kernel: a CUDA version and a Pykeops version.

However, extensions/cauchy requires CUDA, Pykeops does not support Windows, and the target people have very diverse environments.

Would it be possible for the HazyResearch team to implement a self-contained S4 layer including a simpler PyTorch Cauchy kernel?

Training on .mat files

Hi, I am trying to train the Sashimi model on .mat files. How should I go about doing this?

n_epoch_double

Hello, thank you for the great library. I was wondering about the n_epoch_double flag and the best way to use it. If you use that flag, won't it just give you a CUDA OOM error when it doubles? Or is the appropriate usage to begin with a very small batch size so that it has room to double without going OOM? When running the wt103 model, if I went over L_max=4096 with a batch size of 1 on a 4xV100 32GB GPU machine, I ran out of memory when it got to the eval stage of the first epoch. I have a language modeling use case that requires very long sequence lengths (8192+) and wanted to try the S4 model on it, because the results seem pretty good for shorter sequence lengths. Any help would be appreciated. Thank you!

Unused parameters in training

Hi! I'm running some experiments using your code. For my use-case, I'm using torch.nn.DistributedDataParallel, which automatically detects unused parameters, i.e., parameters that get no gradients.

The unused parameters are:

  • D (from the S4 module)
  • output_linear.weight and output_linear.bias (from the S4 module). These are instances of the TransposedLinear layer.
  • kernel.C (from SSKernelNPLR).

I have manually confirmed these parameters don't get gradients by running the following code after computing the loss:

for name, param in model.named_parameters():
    if param.grad is None:
        print(name)

Usually, the above means the parameters are instantiated but not used. In this case, surprisingly, all the parameters get used in the forward method. However, none of them get used in "vanilla" PyTorch ops. D, output_linear.weight and output_linear.bias get used through opt_einsum.contract, and kernel.C gets used through your Cauchy GPU op.

Can you confirm the issue on your end? These parameters all look important for the model.

S4 for Seq2seq tasks?

Hello,
You have shown S4's great ability and efficiency on classification and unconditional generation (wikitext-103) tasks. I am wondering if S4 can be applied to conditional generation tasks such as summarization and machine translation? A simple idea is to re-organize these tasks to language modeling tasks, but I am not sure whether the generation quality would be affected.

decoding LM vocab

hello, I trained a model using something like the wt103 task and modified the sashimi generate script to generate text like a CLM. So basically conditioning on a text string, generate the next N words sequentially in the same loop like the Sashimi generation script. I believe that I have it working however I don't know what integer output corresponds to what word in the vocab. Is there a hash table or something that stores the vocab somewhere that's easily accessible? Sorry I can't seem to find any obvious place that it would reside. Thank you for your help.

Can't compile the custom cauchy kernel

Dear all,
Sorry for a silly question. I'm having trouble installing the Cauchy kernel via the custom CUDA kernel.
I ran python setup.py install like this:

~/LongSeq/state-spaces/extensions/cauchy$ python setup.py install
running install
~/miniconda3/envs/lightning/lib/python3.9/site-packages/setuptools/command/install.py:34: SetuptoolsDeprecationWarning: setup.py install is deprecated. Use build and pip and other standards-based tools.
  warnings.warn(
~/miniconda3/envs/lightning/lib/python3.9/site-packages/setuptools/command/easy_install.py:156: EasyInstallDeprecationWarning: easy_install command is deprecated. Use build and pip and other standards-based tools.
  warnings.warn(
running bdist_egg
running egg_info
creating cauchy_mult.egg-info
writing cauchy_mult.egg-info/PKG-INFO
writing dependency_links to cauchy_mult.egg-info/dependency_links.txt
writing top-level names to cauchy_mult.egg-info/top_level.txt
writing manifest file 'cauchy_mult.egg-info/SOURCES.txt'
reading manifest file 'cauchy_mult.egg-info/SOURCES.txt'
writing manifest file 'cauchy_mult.egg-info/SOURCES.txt'
installing library code to build/bdist.linux-x86_64/egg
running install_lib
warning: install_lib: 'build/lib' does not exist -- no Python modules to install

creating build
creating build/bdist.linux-x86_64
creating build/bdist.linux-x86_64/egg
creating build/bdist.linux-x86_64/egg/EGG-INFO
copying cauchy_mult.egg-info/PKG-INFO -> build/bdist.linux-x86_64/egg/EGG-INFO
copying cauchy_mult.egg-info/SOURCES.txt -> build/bdist.linux-x86_64/egg/EGG-INFO
copying cauchy_mult.egg-info/dependency_links.txt -> build/bdist.linux-x86_64/egg/EGG-INFO
copying cauchy_mult.egg-info/top_level.txt -> build/bdist.linux-x86_64/egg/EGG-INFO
zip_safe flag not set; analyzing archive contents...
creating dist
creating 'dist/cauchy_mult-0.0.0-py3.9.egg' and adding 'build/bdist.linux-x86_64/egg' to it
removing 'build/bdist.linux-x86_64/egg' (and everything under it)
Processing cauchy_mult-0.0.0-py3.9.egg
Copying cauchy_mult-0.0.0-py3.9.egg to ~/miniconda3/envs/lightning/lib/python3.9/site-packages
Adding cauchy-mult 0.0.0 to easy-install.pth file

Installed ~/miniconda3/envs/lightning/lib/python3.9/site-packages/cauchy_mult-0.0.0-py3.9.egg
Processing dependencies for cauchy-mult==0.0.0
Finished processing dependencies for cauchy-mult==0.0.0

but I cannot import cauchy_mult when running test_cauchy.py:

~/LongSeq/state-spaces/extensions/cauchy$ python test_cauchy.py 
Traceback (most recent call last):
  File "~/LongSeq/state-spaces/extensions/cauchy/test_cauchy.py", line 8, in <module>
    from cauchy import cauchy_mult_torch, cauchy_mult_keops, cauchy_mult
  File "~/LongSeq/state-spaces/extensions/cauchy/cauchy.py", line 5, in <module>
    from cauchy_mult import cauchy_mult_fwd, cauchy_mult_bwd, cauchy_mult_sym_fwd, cauchy_mult_sym_bwd
ModuleNotFoundError: No module named 'cauchy_mult'

I also tried running python -m pip install and switching to Python 3.8, but neither worked.

ValueError when running on Pathfinder

Hi, I am getting the following error when trying to train S4 on the pathfinder dataset. Any help would be greatly appreciated.

Traceback (most recent call last):
  File "/data/al451/state-spaces/train.py", line 553, in main
    train(config)
  File "/data/al451/state-spaces/train.py", line 498, in train
    trainer.fit(model)
  File "/home/al451/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 768, in fit
    self._call_and_handle_interrupt(
  File "/home/al451/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 721, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/al451/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 809, in _fit_impl
    results = self._run(model, ckpt_path=self.ckpt_path)
  File "/home/al451/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1172, in _run
    self._call_setup_hook()  # allow user to setup lightning_module in accelerator environment
  File "/home/al451/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1492, in _call_setup_hook
    self._call_lightning_module_hook("setup", stage=fn)
  File "/home/al451/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1593, in _call_lightning_module_hook
    output = fn(*args, **kwargs)
  File "/data/al451/state-spaces/train.py", line 56, in setup
    self.dataset.setup()
  File "/data/al451/state-spaces/src/dataloaders/datasets.py", line 1234, in setup
    dataset = PathFinderDataset(self.data_dir, transform=self.default_transforms())
  File "/data/al451/state-spaces/src/dataloaders/datasets.py", line 1130, in __init__
    path_list = sorted(
  File "/data/al451/state-spaces/src/dataloaders/datasets.py", line 1132, in
    key=lambda path: int(path.stem),
ValueError: invalid literal for int() with base 10: '._142'

Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.
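
Names like '._142' are typically AppleDouble metadata files that macOS writes alongside real files when copying to some filesystems, so `int(path.stem)` fails on them. Deleting those files from the data directory is the simplest workaround; alternatively, a sorting helper can skip them, as in the sketch below. The function name and the `*.png` glob pattern are assumptions, not the repository's code.

```python
from pathlib import Path

def sorted_numeric_images(directory, pattern="*.png"):
    """Sort files named <int>.<ext> numerically, skipping macOS '._*' metadata files.

    Hypothetical stand-in for a `sorted(..., key=lambda path: int(path.stem))` call.
    """
    paths = [p for p in Path(directory).glob(pattern) if not p.name.startswith("._")]
    return sorted(paths, key=lambda path: int(path.stem))
```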

Training on own data

I'd like to train a SaShiMi model on my own data. Could you please let me know what structure the data needs to be in to be compatible with the dataloader?

load_dataset process getting Killed

The following piece of code in src/dataloaders/datasets.py gives an error, and the data-loading process is aborted:

    dataset = load_dataset(
        "csv",
        data_files={
            "train": str(self.data_dir / "basic_train.tsv"),
            "val": str(self.data_dir / "basic_val.tsv"),
            "test": str(self.data_dir / "basic_test.tsv"),
        },
        delimiter="\t",
        keep_in_memory=True,
    )

Inconsistent results of forward (training) and step (inference)

Hi, I did a simple test to verify the difference between forward and step (mode="dense") on a single unidirectional S4 layer. Given a random sequence, the absolute error between their outputs is around 1e-2 and the squared error is around 1e-4. I suspect these results are wrong. My verification follows test_step() in //src/models/sequence/ss/kernel.py. I'd love to know if you have examples that clearly compare the two. Thanks :)
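
For comparisons like this, a self-contained baseline can help. The NumPy sketch below (not the repository's test_step) computes the same diagonal SSM in convolution mode and in recurrent (step) mode, using zero-order-hold discretization. The S4D-Lin-style initialization A_n = -1/2 + iπn, B = 1, random C, the step size, and the 2·Re(·) convention for conjugate pairs are all assumptions. In float64, with identical discretization on both paths, the two outputs agree up to rounding error.

```python
import numpy as np

def discretize_zoh(A, B, dt):
    """Zero-order-hold discretization of a diagonal SSM: dA = exp(dt*A), dB = (dA - 1)/A * B."""
    dA = np.exp(dt * A)
    dB = (dA - 1.0) / A * B
    return dA, dB

def ssm_convolution(u, A, B, C, dt):
    """Convolution mode: K[l] = 2*Re(sum_n C_n dA_n^l dB_n), then causal y[t] = sum_j K[j] u[t-j]."""
    L = len(u)
    dA, dB = discretize_zoh(A, B, dt)
    K = 2 * ((C * dB) @ (dA[:, None] ** np.arange(L))).real  # (L,)
    return np.array([np.dot(K[: t + 1][::-1], u[: t + 1]) for t in range(L)])

def ssm_recurrence(u, A, B, C, dt):
    """Recurrent (step) mode: x_t = dA * x_{t-1} + dB * u_t, y_t = 2*Re(C . x_t)."""
    dA, dB = discretize_zoh(A, B, dt)
    x = np.zeros_like(A)
    ys = []
    for u_t in u:
        x = dA * x + dB * u_t
        ys.append(2 * np.dot(C, x).real)
    return np.array(ys)

rng = np.random.default_rng(0)
N, L, dt = 8, 64, 0.01
A = -0.5 + 1j * np.pi * np.arange(N)  # assumed S4D-Lin-style initialization
B = np.ones(N, dtype=complex)
C = rng.standard_normal(N) + 1j * rng.standard_normal(N)
u = rng.standard_normal(L)
diff = np.abs(ssm_convolution(u, A, B, C, dt) - ssm_recurrence(u, A, B, C, dt)).max()
print(diff)  # tiny: float64 rounding error only
```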

RNN-style train and eval for S4/S4D

Excellent idea and great paper!
Could you please provide a concrete example on how to both train and eval using the stateful RNN version of S4/S4D? I only find an evaluation example in the SaShiMi code but I have not found an example for training.
Thank you!

Time-series experiments

The experiments using datasets like etth, ettm, and ecl don't run (despite the definitions in /config). I think it's because the datasets cannot be downloaded automatically like the LRA experiments.

Would it be possible to add this, or explain how these experiments could be run?

Should the A, B, C, dt parameters be set to be trainable?

Hi!

When I was checking the code, I found that for most of the experimental configurations for S4, the A, B, C, dt parameters are not set to be trainable, while I originally thought that these parameters are trained in S4 presented in the paper. I don't know if I understand it correctly, but isn't this setting similar to HiPPO if these parameters are not trained? Or do you have any empirical findings on this?

Thanks!

Conditioning on the diffusion step

Hello,

I'm considering trying your Sashimi model as the backbone of a diffusion model for audio generation. There is a detail I couldn't find in the paper, neither in the code (maybe I didn't look enough). How do you condition the model to the diffusion timestep? Do you do the same as in diffwave? (Element-wise addition of the diffusion-step embedding at each layer) Or use something similar to a FiLM layer , as in WaveGrad?

sashimi: LinearActivation initialization gives TypeError for weight_norm argument

Hi,

I ran into an issue while attempting to run the sashimi.py script: the LinearActivation function in the ./standalone/s4.py script does not accept the weight_norm argument that the DownPool and UpPool classes pass on through **kwargs upon initialization.

The error trace:

File ".\PycharmProjects\state-spaces\sashimi\sashimi.py", line 450, in <module>
    model = Sashimi(n_layers=2).cuda()

File ".\PycharmProjects\state-spaces\sashimi\sashimi.py", line 298, in __init__
    d_layers.append(DownPool(H, expand, p))

File ".\PycharmProjects\state-spaces\sashimi\sashimi.py", line 28, in __init__
    self.linear = LinearActivation(

File ".\PycharmProjects\state-spaces\src\models\sequence\ss\standalone\s4.py", line 137, in LinearActivation
    linear = linear_cls(d_input, d_output, bias=bias, **kwargs)

TypeError: __init__() got an unexpected keyword argument 'weight_norm'

Looking a bit further, I noticed that neither nn.Conv1d (chosen in DownPool due to transposed = True) nor nn.Linear, either of which could be called within LinearActivation(), has an explicit weight_norm argument in PyTorch.

Am I overlooking something?

Python: 3.9
Pytorch: 1.12 (latest stable release)

Thanks a lot for publishing your code with the papers!

Cheers,
Bavo
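
One way to reconcile this, sketched below under assumptions, is to consume the `weight_norm` flag instead of forwarding it, and apply `torch.nn.utils.weight_norm` after the layer is built. The helper name, its arguments, and the `kernel_size=1` convolution for the transposed case are hypothetical, not the repository's LinearActivation. Newer PyTorch releases deprecate `torch.nn.utils.weight_norm` in favor of `torch.nn.utils.parametrizations.weight_norm`.

```python
import torch.nn as nn
from torch.nn.utils import weight_norm as apply_weight_norm

def linear_with_optional_weight_norm(d_input, d_output, bias=True, transposed=False,
                                     weight_norm=False, **kwargs):
    """Hypothetical helper: pop the `weight_norm` flag rather than passing it to the layer.

    nn.Linear / nn.Conv1d do not accept a `weight_norm` keyword, so it is applied
    with torch.nn.utils.weight_norm once the layer exists.
    """
    if transposed:  # channels-first input (B, H, L): a 1x1 convolution acts as a linear map over H
        layer = nn.Conv1d(d_input, d_output, kernel_size=1, bias=bias, **kwargs)
    else:
        layer = nn.Linear(d_input, d_output, bias=bias, **kwargs)
    if weight_norm:
        layer = apply_weight_norm(layer)  # weight = weight_g * weight_v / ||weight_v||
    return layer
```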

Exception when running the language model

When I tried to run the code for the language model (wt103) with the command

HYDRA_FULL_ERROR=1 python3 -m train wandb=null experiment=s4-wt103

an exception happened!

In the adaptive_softmax.py module, the left index is greater than r_idx, so nn.Parameter(torch.zeros(abs(r_idx - l_idx))) can't pass (line 116)! So I used the abs() function to get past it.
Then in self.out_layers_weights (line 182) another exception arises: list index out of range, so I used try-except to skip those lines...

You can find the full error in the attached file:
error.txt

Test Accuracy

Hello,

Sorry for another silly question. While training on s4-lra-cifar (on A100) I get the following final log:

Epoch 94, global step 85499: val/accuracy was not in top 1
Epoch 96:  80% 966/1200 [01:16<00:18, 12.63it/s, loss=0.144, v_num=xyj4, val/accuracy=0.866, val/loss=0.527, test/accuracy=0.865, test/loss=0.533, train/accuracy=0.946, train/loss=0.151]
Epoch 95, global step 86399: val/accuracy reached 0.86580 (best 0.86580), saving model to "blah/checkpoints/val/accuracy.ckpt" as top 1
Epoch 97:  78% 938/1200 [01:15<00:21, 12.38it/s, loss=0.142, v_num=xyj4, val/accuracy=0.864, val/loss=0.536, test/accuracy=0.864, test/loss=0.539, train/accuracy=0.945, train/loss=0.156]
Epoch 96, global step 87299: val/accuracy was not in top 1
Epoch 98:  75% 900/1200 [01:14<00:24, 12.06it/s, loss=0.143, v_num=xyj4, val/accuracy=0.866, val/loss=0.529, test/accuracy=0.865, test/loss=0.535, train/accuracy=0.945, train/loss=0.153]
Epoch 97, global step 88199: val/accuracy was not in top 1
Epoch 99:  79% 945/1200 [01:15<00:20, 12.45it/s, loss=0.137, v_num=xyj4, val/accuracy=0.863, val/loss=0.534, test/accuracy=0.864, test/loss=0.539, train/accuracy=0.946, train/loss=0.151]
Epoch 98, global step 89099: val/accuracy was not in top 1

Epoch 99, global step 89999: val/accuracy was not in top 1
Saving latest checkpoint...
  1. Does this mean that the test accuracy measured at the checkpoint with the best validation accuracy is 86.5?
  2. If no, what command should I use to measure the test accuracy that you report in the paper? (I am assuming that the accuracy you report in the paper is indeed the test acc at the checkpoint with best val acc)
  3. You report 87.26 on lra-cifar - is this normal to get this gap? I hope I'm not making a mistake in interpreting the metrics.

I would be grateful if you could help with these questions and apologize in advance if they seem too basic.

Thank you again for sharing your wonderful repo,
Ankit

Multi GPU training

Hi,

Is there any way to train S4 with multiple GPUs? I have 2 GPUs, but only one of them is working.

Thanks

Resuming suspended training

Hey folks,

Congrats on the amazing and inspiring work!!

I have a quick question - how do I resume training + wandb logging if the training got terminated before completion. E.g. say the original command was CUDA_VISIBLE_DEVICES=0 python -m train experiment=s4-lra-pathx loader.batch_size=16 trainer.accumulate_grad_batches=2 and the training got suspended halfway through the full training. What command should I run to resume the training and the wandb logging?

Thanks in advance,
Ankit

minGPT like training

Hi! Very impressive results!

I'd like to try applying the S4 model on some toy example from NLP, like text generation (replicate examples in minGPT from @karpathy). I'm not very familiar with state-space models, so I don't understand a few things and have few questions:

  1. As far as I understand, such model doesn't need positional encodings/embeddings?
  2. How do I properly train such a model in causal mode, so that the model doesn't look into the future? Is there some equivalent to masking in Transformers? Or is it the default mode out of the box (like in a vanilla RNN)?

Thanks!
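
On the causality question, a toy check helps: when a unidirectional long convolution is computed with FFTs and zero-padded to twice the sequence length, output position t depends only on inputs up to t, so no mask is needed. The sketch below is plain NumPy with a random kernel standing in for an SSM kernel; padding to 2L is an assumption about how such a layer is computed, not a description of the repository's code.

```python
import numpy as np

def causal_fft_conv(u, k):
    """Causal long convolution y[t] = sum_{j<=t} k[j] u[t-j], computed with FFTs.

    Zero-padding to length 2L turns the FFT's circular convolution into a linear
    one, so no output position can see future inputs.
    """
    L = len(u)
    n = 2 * L
    y = np.fft.irfft(np.fft.rfft(u, n=n) * np.fft.rfft(k, n=n), n=n)
    return y[:L]

rng = np.random.default_rng(0)
L = 32
u, k = rng.standard_normal(L), rng.standard_normal(L)
y = causal_fft_conv(u, k)

u_future = u.copy()
u_future[20:] += 100.0                      # perturb only "future" inputs
y_future = causal_fft_conv(u_future, k)
print(np.allclose(y[:20], y_future[:20]))  # True: outputs before t=20 are unchanged
```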

Comparison of S4 with stateful transformers

Hi,

Wu et al. recently published a paper on Memorizing Transformers (transformers with states/memory), which extends their perceptive field to unbounded contexts (https://www.youtube.com/watch?v=5AoOpFFjW28&list=PL0NRmB0fnLJQJ3fuIk3yVULtm6_JnQ_zI, https://arxiv.org/abs/2203.08913). I am curious to hear what you think about how S4/Sashimi might compare with this new transformer model. My hunch is that S4 might be theoretically similar if you use the exponential measure density.

Faster Cauchy Kernels?

Breathtaking work, absolutely amazing application of linear algebra. Beautiful.

A few questions.

The appendix of the S4 article mentions "... implementation of S4 uses the naive O(NL) algorithm ..." and README.md mentions custom kernels.

Question 0. Did you benchmark the naive O(NL) against the custom kernel?

Question 1. Is the 60x speedup in Table 8 with naive O(NL) or custom kernel? Or is the custom kernel only used during training?

Question 2. How big a percentage of compute is spent on S4 compared to mlp/lnorm/others in generation mode?

Apologies for any misunderstandings
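For reference on what "naive O(NL)" refers to: the S4 paper's Cauchy kernel is a Cauchy matrix-vector product, k[l] = sum_n v[n] / (z[l] - w[n]), i.e. N terms summed for each of L evaluation points. A minimal pure-Python sketch follows; the name `cauchy_naive_sketch` and its arguments are hypothetical, and this is not the repository's own implementation.

```python
from typing import List, Sequence


def cauchy_naive_sketch(
    v: Sequence[complex], z: Sequence[complex], w: Sequence[complex]
) -> List[complex]:
    """Naive Cauchy matrix-vector product: k[l] = sum_n v[n] / (z[l] - w[n]).

    O(N * L) arithmetic. Written as nested loops it needs only O(L) extra memory
    but is slow; a vectorized tensor version instead materializes the full
    L x N matrix of terms before reducing over n.
    """
    return [sum(v_n / (z_l - w_n) for v_n, w_n in zip(v, w)) for z_l in z]
```

That memory trade-off is the point the README makes: a broadcasted implementation is fast but holds an L x N intermediate, which is what custom kernels avoid.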

Experiment reproduction issue with updated modules

Hi, I was trying to reproduce some of your results using the SaShiMi model by running the command

python -m train experiment=sashimi-sc09 wandb=null

but I get the error

TypeError: __init__() got an unexpected keyword argument 'pool'

because the DownPool class no longer takes a pool parameter in its constructor.

Are there any plans to fix these issues so that the experiment configs work with the current implementations of the different modules?
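Until the configs are updated, one local workaround is to drop constructor arguments that a class no longer accepts. A minimal sketch: `DownPoolStandIn` is a hypothetical stand-in whose signature is invented for illustration (it is not the repository's DownPool), and `instantiate_lenient` is likewise a hypothetical helper built on the standard `inspect.signature`.

```python
import inspect
from typing import Any, List, Tuple


class DownPoolStandIn:
    """Hypothetical module whose constructor no longer accepts `pool`."""

    def __init__(self, d_input: int, expand: int = 1, stride: int = 1):
        self.d_input, self.expand, self.stride = d_input, expand, stride


def instantiate_lenient(cls, **kwargs: Any) -> Tuple[Any, List[str]]:
    """Construct cls, dropping kwargs its __init__ does not accept; return them for logging."""
    params = inspect.signature(cls.__init__).parameters
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return cls(**kwargs), []  # accepts **kwargs, so nothing needs dropping
    dropped = sorted(k for k in kwargs if k not in params)
    kept = {k: v for k, v in kwargs.items() if k in params}
    return cls(**kept), dropped
```

Silently discarding arguments can hide genuine config mistakes, which is why the sketch returns the dropped names; removing the stale `pool` argument where it is passed is the cleaner fix.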

Error Running Basic Test Script with v2 Tag (Works With v1)

Hi there,

I recently tried upgrading my S4 setup/environment to the v2 tag, but ran into the following issue when running the basic test script:

(base) ray@test-python:~/state-spaces$ python -m train wandb=null pipeline=mnist model=s4
CONFIG
├── train
│   └── seed: 0                                                                                                                                                                                        
│       interval: epoch                                                                                                                                                                                
│       monitor: val/accuracy                                                                                                                                                                          
│       mode: max                                                                                                                                                                                      
│       ema: 0.0                                                                                                                                                                                       
│       test: false                                                                                                                                                                                    
│       debug: false                                                                                                                                                                                   
│       ignore_warnings: false                                                                                                                                                                         
│       state:                                                                                                                                                                                         
│         mode: null                                                                                                                                                                                   
│         chunk_len: null                                                                                                                                                                              
│         overlap_len: null                                                                                                                                                                            
│         n_context: 0                                                                                                                                                                                 
│         n_context_eval: 0                                                                                                                                                                            
│       sweep: null                                                                                                                                                                                    
│       group: null                                                                                                                                                                                    
│       benchmark_step: false                                                                                                                                                                          
│       benchmark_step_k: 1                                                                                                                                                                            
│       benchmark_step_T: 1                                                                                                                                                                            
│       checkpoint_path: null                                                                                                                                                                          
│       visualizer: filters                                                                                                                                                                            
│       disable_dataset: false                                                                                                                                                                         
│                                                                                                                                                                                                      
├── wandb
│   └── None                                                                                                                                                                                           
├── trainer
│   └── gpus: 1                                                                                                                                                                                        
│       accumulate_grad_batches: 1                                                                                                                                                                     
│       max_epochs: 200                                                                                                                                                                                
│       gradient_clip_val: 0.0                                                                                                                                                                         
│       log_every_n_steps: 10                                                                                                                                                                          
│       limit_train_batches: 1.0                                                                                                                                                                       
│       limit_val_batches: 1.0                                                                                                                                                                         
│       weights_summary: top                                                                                                                                                                           
│       progress_bar_refresh_rate: 1                                                                                                                                                                   
│       track_grad_norm: -1                                                                                                                                                                            
│       resume_from_checkpoint: null                                                                                                                                                                   
│                                                                                                                                                                                                      
├── loader
│   └── batch_size: 50                                                                                                                                                                                 
│       num_workers: 4                                                                                                                                                                                 
│       pin_memory: true                                                                                                                                                                               
│       drop_last: true                                                                                                                                                                                
│       train_resolution: 1                                                                                                                                                                            
│       eval_resolutions:                                                                                                                                                                              
│       - 1                                                                                                                                                                                            
│                                                                                                                                                                                                      
├── dataset
│   └── _name_: mnist                                                                                                                                                                                  
│       permute: true                                                                                                                                                                                  
│       val_split: 0.1                                                                                                                                                                                 
│       seed: 42                                                                                                                                                                                       
│                                                                                                                                                                                                      
├── task
│   └── _name_: base                                                                                                                                                                                   
│       loss: cross_entropy                                                                                                                                                                            
│       metrics:                                                                                                                                                                                       
│       - accuracy                                                                                                                                                                                     
│       torchmetrics: null                                                                                                                                                                             
│                                                                                                                                                                                                      
├── optimizer
│   └── _name_: adamw                                                                                                                                                                                  
│       lr: 0.001                                                                                                                                                                                      
│       weight_decay: 0.0                                                                                                                                                                              
│                                                                                                                                                                                                      
├── scheduler
│   └── _name_: plateau                                                                                                                                                                                
│       mode: max                                                                                                                                                                                      
│       factor: 0.2                                                                                                                                                                                    
│       patience: 20                                                                                                                                                                                   
│       min_lr: 0.0                                                                                                                                                                                    
│                                                                                                                                                                                                      
├── encoder
│   └── linear                                                                                                                                                                                         
├── decoder
│   └── _name_: sequence                                                                                                                                                                               
│       mode: pool                                                                                                                                                                                     
│                                                                                                                                                                                                      
├── model
│   └── layer:                                                                                                                                                                                         
│         _name_: s4                                                                                                                                                                                   
│         d_state: 64                                                                                                                                                                                  
│         channels: 1                                                                                                                                                                                  
│         bidirectional: false                                                                                                                                                                         
│         activation: gelu                                                                                                                                                                             
│         postact: null                                                                                                                                                                                
│         hyper_act: null                                                                                                                                                                              
│         dropout: 0.0                                                                                                                                                                                 
│         measure: legs                                                                                                                                                                                
│         rank: 1                                                                                                                                                                                      
│         dt_min: 0.001                                                                                                                                                                                
│         dt_max: 0.1                                                                                                                                                                                  
│         trainable:                                                                                                                                                                                   
│           dt: true                                                                                                                                                                                   
│           A: true                                                                                                                                                                                    
│           P: true                                                                                                                                                                                    
│           B: true                                                                                                                                                                                    
│         lr: 0.001                                                                                                                                                                                    
│         length_correction: true                                                                                                                                                                      
│         tie_state: true                                                                                                                                                                              
│         hurwitz: true                                                                                                                                                                                
│         resample: false                                                                                                                                                                              
│         deterministic: false                                                                                                                                                                         
│         l_max: 784                                                                                                                                                                                   
│         verbose: false                                                                                                                                                                               
│       _name_: model                                                                                                                                                                                  
│       prenorm: false                                                                                                                                                                                 
│       transposed: true                                                                                                                                                                               
│       n_layers: 4                                                                                                                                                                                    
│       d_model: 256                                                                                                                                                                                   
│       residual: R                                                                                                                                                                                    
│       pool:                                                                                                                                                                                          
│         _name_: sample                                                                                                                                                                               
│         pool: 1                                                                                                                                                                                      
│         expand: 1                                                                                                                                                                                    
│       norm: layer                                                                                                                                                                                    
│       dropout: 0.0                                                                                                                                                                                   
│                                                                                                                                                                                                      
└── callbacks
    └── learning_rate_monitor:                                                                                                                                                                         
          logging_interval: epoch                                                                                                                                                                      
        timer:                                                                                                                                                                                         
          step: true                                                                                                                                                                                   
          inter_step: false                                                                                                                                                                            
          epoch: true                                                                                                                                                                                  
          val: true                                                                                                                                                                                    
        params:                                                                                                                                                                                        
          total: true                                                                                                                                                                                  
          trainable: true                                                                                                                                                                              
          fixed: true                                                                                                                                                                                  
        model_checkpoint:                                                                                                                                                                              
          monitor: val/accuracy                                                                                                                                                                        
          mode: max                                                                                                                                                                                    
          save_top_k: 1                                                                                                                                                                                
          save_last: true                                                                                                                                                                              
          dirpath: checkpoints/                                                                                                                                                                        
          filename: val/accuracy                                                                                                                                                                       
          auto_insert_metric_name: false                                                                                                                                                               
          verbose: true                                                                                                                                                                                
                                                                                                                                                                                                       
Global seed set to 0
[2022-05-25 13:40:50,814][__main__][INFO] - Instantiating callback <src.callbacks.timer.Timer>
[2022-05-25 13:40:50,815][__main__][INFO] - Instantiating callback <src.callbacks.params.ParamsLog>
[2022-05-25 13:40:50,816][__main__][INFO] - Instantiating callback <pytorch_lightning.callbacks.ModelCheckpoint>
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_train_batches=1.0)` was configured so 100% of the batches per epoch will be used..
`Trainer(limit_val_batches=1.0)` was configured so 100% of the batches will be used..
[2022-05-25 13:40:50,848][torch.distributed.nn.jit.instantiator][INFO] - Created a temporary directory at /tmp/tmpm51hqe7x
[2022-05-25 13:40:50,849][torch.distributed.nn.jit.instantiator][INFO] - Writing /tmp/tmpm51hqe7x/_remote_module_non_sriptable.py
Error executing job with overrides: ['wandb=null', 'pipeline=mnist', 'model=s4']
Traceback (most recent call last):
  File "/home/ray/state-spaces/train.py", line 553, in main
    train(config)
  File "/home/ray/state-spaces/train.py", line 498, in train
    trainer.fit(model)
  File "/home/ray/anaconda3/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 768, in fit
    self._call_and_handle_interrupt(
  File "/home/ray/anaconda3/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 721, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/ray/anaconda3/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 809, in _fit_impl
    results = self._run(model, ckpt_path=self.ckpt_path)
  File "/home/ray/anaconda3/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1172, in _run
    self._call_setup_hook()  # allow user to setup lightning_module in accelerator environment
  File "/home/ray/anaconda3/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1492, in _call_setup_hook
    self._call_lightning_module_hook("setup", stage=fn)
  File "/home/ray/anaconda3/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1593, in _call_lightning_module_hook
    output = fn(*args, **kwargs)
  File "/home/ray/state-spaces/train.py", line 74, in setup
    self.model = utils.instantiate(registry.model, self.hparams.model)
  File "/home/ray/state-spaces/src/utils/config.py", line 99, in instantiate
    return obj()
  File "/home/ray/state-spaces/src/models/sequence/model.py", line 69, in __init__
    block = SequenceResidualBlock(d, l+1, prenorm=prenorm, dropout=dropout, layer=layer, residual=residual, norm=norm, pool=pool)
  File "/home/ray/state-spaces/src/models/sequence/block.py", line 36, in __init__
    self.layer = utils.instantiate(registry.layer, layer, d_input)
  File "/home/ray/state-spaces/src/utils/config.py", line 99, in instantiate
    return obj()
  File "/home/ray/state-spaces/src/models/sequence/ss/s4.py", line 86, in __init__
    self.kernel = HippoSSKernel(self.h, N=self.n, L=l_max, channels=channels, verbose=verbose, **kernel_args)
  File "/home/ray/state-spaces/src/models/sequence/ss/kernel.py", line 712, in __init__
    self.kernel = SSKernelNPLR(
  File "/home/ray/state-spaces/src/models/sequence/ss/kernel.py", line 217, in __init__
    self.C = nn.Parameter(_c2r(_resolve_conj(C)))
RuntimeError: view_as_real doesn't work on unresolved conjugated tensors.  To resolve the conjugate tensor so you can view it as real, use self.resolve_conj(); however, be warned that the resulting tensor will NOT alias the original.

Is this something you've seen before? I'd be happy to provide a fuller description of my package versions, system architecture, etc. if you let me know what would help get to the bottom of this bug.

Best,
Matthew
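The underlying PyTorch behavior can be reproduced in isolation. Since PyTorch 1.10, `Tensor.conj()` on a complex tensor returns a lazy conjugate view, and `torch.view_as_real` refuses such views; calling `resolve_conj()` first, as the error message suggests, materializes the conjugation. This minimal sketch is independent of the repository's code. The traceback shows the repo already wraps `C` in a `_resolve_conj` helper, so one thing worth checking (an assumption, not a confirmed diagnosis) is whether that helper actually calls `resolve_conj()` under the installed PyTorch version.

```python
import torch

# A complex tensor and its lazily conjugated view (no data is copied yet).
x = torch.tensor([1 + 2j, 3 - 4j], dtype=torch.complex64)
x_conj = x.conj()

try:
    torch.view_as_real(x_conj)  # fails: the conjugation is still unresolved
except RuntimeError as err:
    print("view_as_real failed:", err)

# resolve_conj() writes the conjugated values to new memory (no longer aliasing x),
# after which view_as_real exposes (real, imag) pairs in a trailing dimension of size 2.
x_real = torch.view_as_real(x_conj.resolve_conj())
print(x_real)
```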

how to save and load checkpoints

Hi,
I'm running experiments with default settings using this command:

python -m train wandb=null experiment=sashimi-sc09

I'd like to take the checkpoint of the best model from training and load it to generate .wav files. However, the only checkpoint generated is last.ckpt under the directory outputs/yyyy-mm-dd/xx-xx-xx/checkpoints, and I have no idea how to load it (the checkpoint loading in sashimi/generation.py expects a .pt file).
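A hedged sketch of converting a Lightning `.ckpt` into a plain weights file. PyTorch Lightning stores model weights in the checkpoint dict under `"state_dict"`, keyed by attribute path on the LightningModule. The traceback in the thread above shows `train.py` assigning the network to `self.model`, so keys are assumed to start with `"model."`. Whether `sashimi/generation.py` expects exactly this stripped format is not confirmed here. Both function names are hypothetical.

```python
from typing import Any, Dict


def extract_state_dict(ckpt: Dict[str, Any], prefix: str = "model.") -> Dict[str, Any]:
    """Pull model weights out of a PyTorch Lightning checkpoint dict.

    ASSUMPTION: the network lives at `self.model`, so keys start with `prefix`;
    the prefix is stripped so keys match the bare nn.Module.
    """
    return {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in ckpt["state_dict"].items()
    }


def convert_ckpt_to_pt(ckpt_path: str, pt_path: str, prefix: str = "model.") -> None:
    """Load a .ckpt with torch and save the stripped state_dict as a .pt file."""
    import torch  # imported lazily so extract_state_dict works without torch

    # On PyTorch >= 2.6, torch.load defaults to weights_only=True, and Lightning
    # checkpoints may then need weights_only=False to load.
    ckpt = torch.load(ckpt_path, map_location="cpu")
    torch.save(extract_state_dict(ckpt, prefix), pt_path)
```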

Error when training on youtubemix

When running CUDA_VISIBLE_DEVICES=1,2,3,4,7 python -m train wandb=null experiment=sashimi-youtubemix dataset=youtubemix, I get the following error:

Traceback (most recent call last):
  File "/data/al451/state-spaces/train.py", line 553, in main
    train(config)
  File "/data/al451/state-spaces/train.py", line 498, in train
    trainer.fit(model)
  File "/home/al451/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 740, in fit
    self._call_and_handle_interrupt(
  File "/home/al451/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 685, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/al451/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 777, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/al451/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1199, in _run
    self._dispatch()
  File "/home/al451/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1279, in _dispatch
    self.training_type_plugin.start_training(self)
  File "/home/al451/.local/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 202, in start_training
    self._results = trainer.run_stage()
  File "/home/al451/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1289, in run_stage
    return self._run_train()
  File "/home/al451/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1311, in _run_train
    self._run_sanity_check(self.lightning_module)
  File "/home/al451/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1371, in _run_sanity_check
    self._evaluation_loop._reload_evaluation_dataloaders()
  File "/home/al451/.local/lib/python3.8/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 170, in _reload_evaluation_dataloaders
    self.trainer.reset_val_dataloader()
  File "/home/al451/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/data_loading.py", line 551, in reset_val_dataloader
    self.num_val_batches, self.val_dataloaders = self._reset_eval_dataloader(
  File "/home/al451/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/data_loading.py", line 508, in _reset_eval_dataloader
    if has_len_all_ranks(dataloader, self.training_type_plugin, module)
  File "/home/al451/.local/lib/python3.8/site-packages/pytorch_lightning/utilities/data.py", line 118, in has_len_all_ranks
    raise MisconfigurationException(
pytorch_lightning.utilities.exceptions.MisconfigurationException: Total length of Dataloader across ranks is zero. Please make sure that it returns at least 1 batch.

Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.
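The exception says the validation dataloader produced no batches. One quick sanity check is to estimate how many batches each rank receives. The sketch below assumes `DistributedSampler`'s default padding behaviour (each rank gets ⌈N / world_size⌉ samples) and a `DataLoader` with the given `drop_last`; the helper name is hypothetical. Note that an empty validation split (for example, a dataset path that resolves to no files) produces the same error regardless of the GPU count.

```python
import math


def batches_per_rank(dataset_len: int, batch_size: int, world_size: int, drop_last: bool) -> int:
    """Estimate validation batches per rank under DDP (hypothetical helper)."""
    if dataset_len == 0:
        return 0  # an empty split yields no batches on any rank
    # DistributedSampler (drop_last=False) pads so each rank gets ceil(N / world_size) samples
    per_rank = math.ceil(dataset_len / world_size)
    # DataLoader: drop_last discards a final partial batch
    if drop_last:
        return per_rank // batch_size
    return math.ceil(per_rank / batch_size)
```

For example, 8 validation clips split over 5 GPUs with a batch size of 16 and `drop_last=True` gives `batches_per_rank(8, 16, 5, True) == 0`.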

Model cannot converge on the LRA Pathfinder

Hi,

Thanks for the great work! When I ran your code on the LRA Pathfinder dataset (using your config), the model still had not converged by the end of the 200th epoch, as shown in the following log: loss=0.693, val/accuracy=0.499, val/loss=0.693, test/accuracy=0.495, test/loss=0.693, train/accuracy=0.501, train/loss=0.693. The loss stays at 0.693 throughout training.

Do you have any thoughts on this? Thanks!
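For context, a constant loss of 0.693 is exactly ln 2, the cross-entropy of a classifier that assigns equal probability to both classes of a binary task such as Pathfinder. Together with the ~50% accuracies in the log, this indicates the model is predicting at chance. The arithmetic:

```python
import math


def chance_cross_entropy(num_classes: int) -> float:
    """Cross-entropy (in nats) of a model that always predicts the uniform distribution."""
    # -log(1 / K) = log(K)
    return math.log(num_classes)


print(f"{chance_cross_entropy(2):.3f}")  # prints 0.693
```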

Replication of Diffusion Results

Hi, I'm trying to replicate your results for applying SaShiMi in a diffusion context, and have run into some questions about implementation details along the way. It'd be awesome if you could help me out with them.

  1. I have found the diffusion version of the SaShiMi model at https://github.com/HazyResearch/state-spaces/blob/diffwave/sashimi/sashimi.py. I assume that is the reference implementation. If so, what parameters did you use? Did you just set bidirectional=True, unet=True, diffwave=True, and set the rest to the values specified in Appendix C.2.2 of the paper or to their respective defaults?
  2. In the original model, you use mu-law quantization. Do you also use it in the diffusion implementation? And are you using an embedding encoder and sequence decoder as in the AR model? If so, how do you implement this setup, for example with regard to the additive noise?
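For reference, standard 8-bit mu-law companding (as popularized by WaveNet) maps a waveform in [-1, 1] to 256 integer levels and back. This is a generic sketch with μ = 255, not necessarily the exact quantization code used in this repository's SaShiMi pipeline.

```python
import numpy as np


def mu_law_encode(x: np.ndarray, mu: int = 255) -> np.ndarray:
    """Map a waveform in [-1, 1] to integer levels in [0, mu]."""
    x = np.clip(x, -1.0, 1.0)
    # Companding: F(x) = sign(x) * ln(1 + mu|x|) / ln(1 + mu), still in [-1, 1]
    fx = np.sign(x) * np.log1p(mu * np.abs(x)) / np.log1p(mu)
    # Shift to [0, 1], scale to [0, mu], and round to the nearest integer level
    return np.floor((fx + 1.0) / 2.0 * mu + 0.5).astype(np.int64)


def mu_law_decode(q: np.ndarray, mu: int = 255) -> np.ndarray:
    """Invert mu_law_encode up to quantization error."""
    # Map integer levels back to [-1, 1] in the companded domain
    y = 2.0 * q.astype(np.float64) / mu - 1.0
    # Inverse companding: x = sign(y) * ((1 + mu)^|y| - 1) / mu
    return np.sign(y) * np.expm1(np.abs(y) * np.log1p(mu)) / mu
```

The companding spends more levels on small amplitudes, so quantization error is smallest near zero and largest near ±1.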

Best,
Stefan

SaShiMi generation script errors out with own models

Hey, first of all, great work with the repository. I don't think I've worked with a paper's repository that is so extensive and well-structured.

I'm currently trying to train the SaShiMi model on my own dataset (following your guide here: #23), and I ran into some issues when trying to generate samples with the trained model.
In case this is relevant, I'm trying to run inference from the checkpoint files, and I changed the number of layers (model.n_layers) to 4 to accommodate the memory limitations of my GPU. Apart from that, I have made no changes to the training or model code except for switching the dataset to my own.
When I try to call the generation.py script now, I run into a range of errors:

  • The config overrides cause some errors: the hurwitz parameter no longer exists, and the setup_step methods don't seem to accept the mode argument (or rather, don't pass it downstream). I "fixed" this by removing the hurwitz override and by adding the mode argument to all module.setup_step() methods and passing it downstream as required.
  • Additionally, setting model.layer.postact=null prevents the state_dict from loading, giving me the following error:
Missing key(s) in state_dict: "model.c_layers.0.layer.output_linear.weight", "model.c_layers.0.layer.output_linear.bias", "model.c_layers.2.layer.output_linear.weight", "model.c_layers.2.layer.output_linear.bias", "model.c_layers.4.layer.output_linear.weight", "model.c_layers.4.layer.output_linear.bias", "model.c_layers.6.layer.output_linear.weight", "model.c_layers.6.layer.output_linear.bias", "model.u_layers.0.1.layer.output_linear.weight", "model.u_layers.0.1.layer.output_linear.bias", "model.u_layers.0.3.layer.output_linear.weight", "model.u_layers.0.3.layer.output_linear.bias", "model.u_layers.0.5.layer.output_linear.weight", "model.u_layers.0.5.layer.output_linear.bias", "model.u_layers.0.7.layer.output_linear.weight", "model.u_layers.0.7.layer.output_linear.bias", "model.u_layers.1.1.layer.output_linear.weight", "model.u_layers.1.1.layer.output_linear.bias", "model.u_layers.1.3.layer.output_linear.weight", "model.u_layers.1.3.layer.output_linear.bias", "model.u_layers.1.5.layer.output_linear.weight", "model.u_layers.1.5.layer.output_linear.bias", "model.u_layers.1.7.layer.output_linear.weight", "model.u_layers.1.7.layer.output_linear.bias". 
Unexpected key(s) in state_dict: "model.c_layers.0.layer.output_linear.0.weight", "model.c_layers.0.layer.output_linear.0.bias", "model.c_layers.2.layer.output_linear.0.weight", "model.c_layers.2.layer.output_linear.0.bias", "model.c_layers.4.layer.output_linear.0.weight", "model.c_layers.4.layer.output_linear.0.bias", "model.c_layers.6.layer.output_linear.0.weight", "model.c_layers.6.layer.output_linear.0.bias", "model.u_layers.0.1.layer.output_linear.0.weight", "model.u_layers.0.1.layer.output_linear.0.bias", "model.u_layers.0.3.layer.output_linear.0.weight", "model.u_layers.0.3.layer.output_linear.0.bias", "model.u_layers.0.5.layer.output_linear.0.weight", "model.u_layers.0.5.layer.output_linear.0.bias", "model.u_layers.0.7.layer.output_linear.0.weight", "model.u_layers.0.7.layer.output_linear.0.bias", "model.u_layers.1.1.layer.output_linear.0.weight", "model.u_layers.1.1.layer.output_linear.0.bias", "model.u_layers.1.3.layer.output_linear.0.weight", "model.u_layers.1.3.layer.output_linear.0.bias", "model.u_layers.1.5.layer.output_linear.0.weight", "model.u_layers.1.5.layer.output_linear.0.bias", "model.u_layers.1.7.layer.output_linear.0.weight", "model.u_layers.1.7.layer.output_linear.0.bias".

Does this mean that I should rename those keys manually (there's a fairly clear correspondence) to make it work after changing the activation?
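If you do go the renaming route, the correspondence in the error message (`...output_linear.0.weight` in the checkpoint vs. `...output_linear.weight` in the model) can be handled mechanically. This is a minimal sketch with a hypothetical helper name, and it only renames keys. If `postact` changes what `output_linear` computes (for example, a GLU-style activation typically needs the preceding linear layer to double its output width), the tensor shapes or the model's behaviour may still not match the checkpoint, so `load_state_dict` can still fail or the loaded model may not be equivalent to the trained one.

```python
import re
from collections import OrderedDict


def drop_sequential_index(state_dict: dict, module_name: str = "output_linear") -> OrderedDict:
    """Rename '<...>.output_linear.0.<param>' keys to '<...>.output_linear.<param>' (hypothetical helper)."""
    # Matches the index that nn.Sequential inserts for its first submodule
    pattern = re.compile(rf"\.{re.escape(module_name)}\.0\.")
    renamed = OrderedDict()
    for key, value in state_dict.items():
        renamed[pattern.sub(f".{module_name}.", key)] = value
    return renamed
```

Usage would look like `model.load_state_dict(drop_sequential_index(checkpoint_state_dict))`.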

  • Finally, even when I pass through the mode parameter in module.setup_step(), I still get this error:
Traceback (most recent call last):
  File "/home/debaumas/state-spaces/sashimi/generation.py", line 192, in main
    module.setup_step(mode='dense')
  File "/home/debaumas/state-spaces/src/models/sequence/ss/kernel.py", line 1038, in setup_step
    self.kernel.setup_step(mode=mode)
  File "/home/debaumas/state-spaces/src/models/sequence/ss/kernel.py", line 515, in setup_step
    dC = torch.linalg.solve(
torch._C._LinAlgError: linalg.solve: (Batch element 0): The diagonal element 1 is zero, the solve could not be completed because the input matrix is singular.

Do you have any idea what might be causing this, and how to fix or work around it?

It'd be awesome if you could help point me in the right direction with this.

Best,
Stefan

Request to upload logs for experiments

Hi, can you please upload the logs of the experiments that were reported in the paper?

I tried to reproduce the wikitext-103 experiment but had to change some configurations due to hardware constraints. Even though the changes were minor, the results were not as I expected them to be. I think that the logs from the original experiments might help me to reproduce the results more easily.

Thank you so much!

WikiText-103

Hi, I'm interested in recreating your WikiText-103 LM experiment. Is it possible you could make that easier for me? Thanks! CJ

Conceptual Questions regarding S4/HiPPO

Dear authors/contributors,

First of all, thank you so much for publishing such great work. I find it really inspirational, and I think we will see this model (or its variants) deployed to solve a variety of real-world problems in the coming years.

I have gone through your recent papers, starting from HiPPO, and I would like to ask some conceptual questions to deepen my understanding. As I couldn't find sources of information other than your papers (and a couple of your recorded talks on YouTube and the Annotated S4), I think this could be an appropriate place to ask them. If you prefer another discussion platform, please let me know.

PS: These questions turned out to be a bit longer than I intended, but I don't expect you to clarify them all at once :)

  1. A matrix <-> polynomial basis: From my understanding of your HiPPO paper, you derive the A matrix for various measures and polynomial bases. Therefore, for a given A (and hence a given polynomial basis), we know how to reconstruct the original signal u(t) from the state/coefficients x(t). My question is: what does the model learn when we initialize A as HiPPO but train it (i.e., when A is not fixed)? In other words, how does the polynomial basis change, and how can the model still reconstruct the original signal u(t) as A varies?
  2. Learning the step size: In the Annotated S4, the step size is another parameter that is learned during training. (I am not sure whether you do the same, as I haven't gone through your code yet.)
    • What is the intuition for learning this step size, and what are its potential effects? For instance, if we use a measure that decays exponentially over time, can we say that a larger step size prioritizes the more recent history, while a smaller step size gives more weight to the distant past (because its weight decays more slowly)?
    • If we work on a signal that has a natural notion of time (e.g., an ECG signal), should we still make the step size trainable (in the first and all intermediate layers), given that the formulation (to my understanding) has no notion of the units of the step size (e.g., seconds or days)?
  3. Irregular sampling of time series. The continuous-time view of S4 convinces me that it can naturally handle irregularly sampled time series arising from underlying continuous dynamics. However, I am confused by the discretization step, where we leverage convolution for training and recurrence for fast inference. If I have an irregular time series, how can I train S4?
    • Small comment: I think if the training data is regularly sampled, we can still handle irregular time series in real-time inference based on the bilinear transform of A_bar, B_bar etc. into their continuous equivalent. Is that true?
  4. The effect of "deep" S4 layers. In Figure 2 of your paper “Efficiently Modeling Long Sequences with Structured State Spaces”, we see visualizations of the kernels of the first and last layers for the Path-X task. The first layers mostly capture local context, whereas the last layers capture more global context. Why is this the case if HiPPO offers continuous-time memorization? In other words, why can't the first layers memorize the distant past, and why are more stacked layers needed to aggregate context from further back? I assume it is related to the chosen measure and/or the step size itself, but I am really curious about your opinion.
    • For deep CNN-based models, the usual explanation is that the receptive field grows as more layers are stacked (exponentially with dilated convolutions, as in TCNs, and linearly for some other types). Is there an analogous explanation for S4?
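Questions 2 and 3 both hinge on the discretization step. As a concrete reference, the bilinear transform described in the S4 paper maps continuous (A, B) and a step size Δ to Ā = (I − Δ/2·A)⁻¹(I + Δ/2·A) and B̄ = (I − Δ/2·A)⁻¹ΔB. The NumPy sketch below (helper names hypothetical) illustrates two points under that assumption: for a stable scalar mode a < 0, a larger Δ gives a smaller |Ā|, i.e. faster forgetting per step; and in recurrent mode, (Ā, B̄) can be recomputed with a different Δ at every step, which is one possible way to feed irregularly spaced samples. It is not presented as the authors' recommended treatment of irregular sampling.

```python
import numpy as np


def bilinear(A: np.ndarray, B: np.ndarray, dt: float):
    """Bilinear (Tustin) discretization of x'(t) = A x(t) + B u(t) with step dt."""
    I = np.eye(A.shape[0])
    BL = np.linalg.inv(I - (dt / 2.0) * A)  # (I - dt/2 A)^{-1}
    Ab = BL @ (I + (dt / 2.0) * A)
    Bb = (BL * dt) @ B
    return Ab, Bb


def run_recurrence(A, B, C, u, dts):
    """Recurrent mode with a (possibly different) step size for every sample."""
    x = np.zeros((A.shape[0], 1))
    ys = []
    for u_k, dt_k in zip(u, dts):
        Ab, Bb = bilinear(A, B, dt_k)  # re-discretize for this time gap
        x = Ab @ x + Bb * u_k
        ys.append((C @ x).item())
    return np.array(ys)
```

For a = −1, `bilinear(np.array([[-1.0]]), np.array([[1.0]]), dt)[0]` is roughly 0.905 for dt = 0.1 and 0.333 for dt = 1.0, so each step retains less of the past state as dt grows.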

It is a great pleasure for me to learn more about your exciting work. Many thanks in advance. I would also be happy to hear about any other resources you can suggest.

How to visualize the results of prediction?

Hi,

Thank you for sharing this.

I am doing some experiments on the ETT dataset using S4.
I wonder how to load the best checkpoint, test it with new data, and then visualize the predictions against the ground truth.

Is there an easy way to do this with PyTorch Lightning?
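Independently of the repository's own tooling, the generic PyTorch pattern is: load the weights, switch to eval mode, run the test loader under `torch.no_grad()`, and plot predictions against targets. The sketch assumes the loader yields `(inputs, targets)` pairs and the model returns a single tensor; the repository's data loaders and model wrappers may yield or return extra items, so the unpacking may need adapting. Function names are hypothetical, and matplotlib is only needed for the plotting helper.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader


@torch.no_grad()
def collect_predictions(model: nn.Module, loader: DataLoader, device: str = "cpu"):
    """Run the model over a loader and concatenate predictions and targets on the CPU."""
    model.eval()  # disable dropout; norm layers use running statistics
    preds, targets = [], []
    for inputs, y in loader:
        preds.append(model(inputs.to(device)).cpu())
        targets.append(y.cpu())
    return torch.cat(preds), torch.cat(targets)


def plot_prediction(pred: torch.Tensor, target: torch.Tensor, channel: int = 0,
                    path: str = "pred_vs_truth.png") -> None:
    """Plot one channel of a single (length, channels) forecast against the ground truth."""
    import matplotlib
    matplotlib.use("Agg")  # render to a file without a display
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots(figsize=(8, 3))
    ax.plot(target[:, channel].numpy(), label="truth")
    ax.plot(pred[:, channel].numpy(), label="prediction")
    ax.legend()
    fig.savefig(path, bbox_inches="tight")
    plt.close(fig)
```

A typical call is `preds, targets = collect_predictions(model, test_loader)` followed by `plot_prediction(preds[0], targets[0])`; how `model` is rebuilt from the best checkpoint depends on the repository's classes.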

Cheers,
Max

can't run train.py w/o compiling Cauchy kernel for CUDA

Dear all,

I am having trouble compiling the Cauchy kernel, and although I have installed pykeops, running train.py always results in errors like this:

RuntimeError: [KeOps] This KeOps shared object has been compiled without cuda support:

  1. to perform computations on CPU, simply set tagHostDevice to 0
  2. to perform computations on GPU, please recompile the formula with a working version of cuda.

The only thing that fixed the issue for me was commenting out the following try/except block. Without that change (sorry for its ugliness...), the code never fell back to the slow kernel. Now it does, but that is certainly not the right way to go about it ;)

I wonder whether the try/except block should check that the kernel actually runs, rather than only that it can be imported?

'''
try:
    import pykeops
    from src.models.functional.cauchy import cauchy_conj
    has_pykeops = True
except ImportError:
    has_pykeops = False
    from src.models.functional.cauchy import cauchy_conj_slow
    if not has_cauchy_extension:
        log.error(
            "Falling back on slow Cauchy kernel. Install at least one of pykeops or the CUDA extension for efficiency."
        )
'''

has_pykeops = False
from src.models.functional.cauchy import cauchy_conj_slow
if not has_cauchy_extension:
    log.error(
        "Falling back on slow Cauchy kernel. Install at least one of pykeops or the CUDA extension for efficiency."
    )
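The suggestion above (check that the fast kernel actually runs, not only that it imports) can be sketched as a small selection helper. Everything here is hypothetical: `load_fast` would wrap the `pykeops` / `cauchy_conj` import, and `smoke_test` would call the kernel once on tiny inputs on the target device. The exact call signature of the repository's Cauchy functions is deliberately not assumed.

```python
import logging
from typing import Callable

log = logging.getLogger(__name__)


def select_kernel(load_fast: Callable[[], Callable], slow_kernel: Callable,
                  smoke_test: Callable[[Callable], None]) -> Callable:
    """Return the fast kernel only if it imports AND runs; otherwise fall back to the slow one."""
    try:
        fast_kernel = load_fast()  # may raise ImportError
        smoke_test(fast_kernel)    # may raise e.g. RuntimeError if compiled without CUDA
        return fast_kernel
    except Exception as exc:  # broad on purpose: any failure means "not usable here"
        log.warning("Fast kernel unusable (%s); falling back on slow kernel.", exc)
        return slow_kernel
```

The trade-off is a one-time test call at import, in exchange for failing over cleanly when a library such as KeOps is installed but was built without CUDA support.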

standalone S4 module usage

Hi! I really want to try S4D in my research, but first I want to make a little proof of concept for myself, without going into too much theory or detail.

Do I understand correctly that the standalone models can be used out of the box? What concerns me is that in your publications you stress the importance of model initialization and of tuning the optimizer parameters. Is proper initialization taken into account in the standalone S4D? Is there an example of how to properly set up the optimizer for the standalone S4D?
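One common pattern for giving SSM parameters their own optimizer settings (for example, a smaller learning rate and no weight decay) is to tag those parameters with a dictionary of hyperparameters and build one optimizer group per distinct dictionary. The sketch below assumes such an `_optim` attribute on parameters; the standalone code and `example.py` in this repository appear to follow a similar convention, but check your version before relying on the attribute name. The values in the usage note are illustrative only.

```python
import torch
from torch import nn


def setup_optimizer(model: nn.Module, lr: float, weight_decay: float) -> torch.optim.Optimizer:
    """AdamW with a default group plus one group per distinct `_optim` hyperparameter dict."""
    all_params = list(model.parameters())
    # Default group: parameters without special hyperparameters
    default = [p for p in all_params if not hasattr(p, "_optim")]
    optimizer = torch.optim.AdamW(default, lr=lr, weight_decay=weight_decay)

    # Collect the distinct hyperparameter dicts attached to special parameters
    special_hps = []
    for p in all_params:
        hp = getattr(p, "_optim", None)
        if hp is not None and hp not in special_hps:
            special_hps.append(hp)

    # One extra group per dict; keys it omits fall back to the optimizer defaults
    for hp in special_hps:
        group = [p for p in all_params if getattr(p, "_optim", None) == hp]
        optimizer.add_param_group({"params": group, **hp})
    return optimizer
```

For example, setting `ssm_param._optim = {"lr": 1e-3, "weight_decay": 0.0}` on the SSM parameters before calling `setup_optimizer(model, lr=1e-2, weight_decay=0.05)` would train them with their own learning rate and without weight decay.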

Thanks for your research!

Shape '[]' is invalid for input of size

When I run:
python -m train wandb=null experiment=s4-cifar

I get this error:
  File "/root/anaconda3/envs/statespaces/lib/python3.8/site-packages/torch/_tensor.py", line 307, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/root/anaconda3/envs/statespaces/lib/python3.8/site-packages/torch/autograd/__init__.py", line 154, in backward
    Variable._execution_engine.run_backward(
  File "/root/anaconda3/envs/statespaces/lib/python3.8/site-packages/torch/autograd/function.py", line 199, in apply
    return user_fn(self, *args)
  File "/root/anaconda3/envs/statespaces/lib/python3.8/site-packages/pykeops/torch/generic/generic_red.py", line 263, in backward
    grad = grad.reshape(
RuntimeError: shape '[1024, 2, 2, 32, 2]' is invalid for input of size 4202496

Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.

A has become time-invariant compared to HiPPO

Hi,

Thanks for your excellent work!
In the HiPPO paper, the transition matrix A of the SSM is time-varying, but it is time-invariant in the S4 model. I couldn't find any theory or experiments discussing why it still works.
Can you kindly share your thoughts?
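For reference, the contrast being asked about can be written out. In the HiPPO paper's scaled-Legendre (LegS) measure, the memory dynamics carry an explicit 1/t factor, whereas S4 uses a linear time-invariant SSM. As far as the S4 paper's notation goes, its HiPPO matrix is defined with a leading minus sign, so S4's A is initialized from the negative of the LegS matrix; treat that sign bookkeeping as an assumption to double-check against both papers.

```latex
\begin{align}
x'(t) &= -\tfrac{1}{t}\,A_{\mathrm{LegS}}\,x(t) + \tfrac{1}{t}\,B_{\mathrm{LegS}}\,u(t)
  && \text{(HiPPO-LegS: time-varying through } 1/t) \\
x'(t) &= A\,x(t) + B\,u(t), \quad A \leftarrow -A_{\mathrm{LegS}} \text{ at initialization}
  && \text{(S4: time-invariant)}
\end{align}
```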

Thanks again,
Ziwei

S4 Module Distribution

I think it would be quite useful if one could install this repository, either directly from git or preferably using pip.

I envision a setup where people only have to:

  1. pip install state-spaces
  2. from state_spaces import S4
  3. replace use of nn.LSTM or nn.Transformer with S4

There is a package on PyPI (I'm not sure whether you published it), but its latest version is older than the current code.
https://pypi.org/project/state-space/#history

GPU Out of Memory

I was wondering which parameters I could change to run it on GPUs with limited memory. I tried reducing the number of layers to 4, which did not help. Also, it seems the batch size is already set to 1 by default. I am using 4x TITAN RTX 24GB.

Error while using DataParallel

Hello,
Thanks for sharing this awesome work. When I try to run example.py, I get the following error:

CUDA extension for cauchy multiplication not found. Install by going to extensions/cauchy/ and running python setup.py install. This should speed up end-to-end training by 10-50%
Falling back on slow Cauchy kernel. Install at least one of pykeops or the CUDA extension for efficiency.
cuda
==> Preparing cifar10 data..
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
==> Building model..
Optimizer group 0 | 28 tensors
0it [00:11, ?it/s] | 0/200 [00:00<?, ?it/s]
Epoch: 0: 0%| | 0/200 [00:12<?, ?it/s]
Traceback (most recent call last):
  File "example.py", line 373, in <module>
    train()
  File "example.py", line 312, in train
    outputs = model(inputs)
  File "/media/data2/ameenali/anaconda3/envs/ameen/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/media/data2/ameenali/anaconda3/envs/ameen/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 168, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/media/data2/ameenali/anaconda3/envs/ameen/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 178, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/media/data2/ameenali/anaconda3/envs/ameen/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 86, in parallel_apply
    output.reraise()
  File "/media/data2/ameenali/anaconda3/envs/ameen/lib/python3.8/site-packages/torch/_utils.py", line 425, in reraise
    raise self.exc_type(msg)
RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/media/data2/ameenali/anaconda3/envs/ameen/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
    output = module(*input, **kwargs)
  File "/media/data2/ameenali/anaconda3/envs/ameen/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "example.py", line 196, in forward
    z, _ = layer(z)
  File "/media/data2/ameenali/anaconda3/envs/ameen/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/media/data1/ameen/state-spaces/src/models/sequence/ss/standalone/s4.py", line 842, in forward
    k = self.kernel(L=L) # (C H L) (B C H L)
  File "/media/data2/ameenali/anaconda3/envs/ameen/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/media/data1/ameen/state-spaces/src/models/sequence/ss/standalone/s4.py", line 752, in forward
    k = self.kernel(L=L)
  File "/media/data2/ameenali/anaconda3/envs/ameen/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/media/data1/ameen/state-spaces/src/models/sequence/ss/standalone/s4.py", line 427, in forward
    C = _r2c(self.C)
RuntimeError: Output 3 of BroadcastBackward is a view and its base or another view of its base has been modified inplace. This view is the output of a function that returns multiple views. Such functions do not allow the output views to be modified inplace. You should replace the inplace operation by an out-of-place one.
Any idea how to solve this?
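The failing line converts a real parameter into a complex tensor (`_r2c`, presumably via `torch.view_as_complex`), and `nn.DataParallel` replicates parameters through a broadcast that returns views, which seems to be what the error message objects to. One workaround to try, offered as an untested assumption rather than a confirmed fix, is an out-of-place conversion built with `torch.complex`, which allocates a new tensor instead of returning a view. Switching from `nn.DataParallel` to `DistributedDataParallel`, which the PyTorch documentation recommends in general, is another option.

```python
import torch


def r2c_copy(x: torch.Tensor) -> torch.Tensor:
    """Out-of-place alternative to torch.view_as_complex for a real tensor of shape (..., 2).

    Hypothetical helper: builds a new complex tensor from the real and imaginary parts
    instead of returning a view, so the result does not alias `x`; gradients still flow to `x`.
    """
    # torch.complex needs matching float dtypes; both slices share x's dtype
    return torch.complex(x[..., 0], x[..., 1])
```

The cost is one extra allocation per forward pass for the converted parameter, which is small compared with the kernel computation itself.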
