See my article
Trains autoregressive models to predict the bits of an encoded image file
Currently it only works well with MNIST and the Llama architecture. Mamba support is a work in progress.
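A minimal sketch of the data side of this idea. It assumes each bit of the file becomes one token, that bytes are unpacked most-significant-bit first, and that training uses the usual shifted input/target pair for next-token prediction. The function names are hypothetical and are not this repo's API; the repo's actual bit ordering and tokenization may differ.

```python
def bytes_to_bits(data: bytes) -> list[int]:
    """Unpack bytes into a flat list of 0/1 ints, most significant bit first."""
    return [(byte >> shift) & 1 for byte in data for shift in range(7, -1, -1)]


def bits_to_bytes(bits: list[int]) -> bytes:
    """Pack a 0/1 list back into bytes (inverse of bytes_to_bits)."""
    if len(bits) % 8 != 0:
        raise ValueError("bit count must be a multiple of 8")
    out = bytearray()
    for i in range(0, len(bits), 8):
        byte = 0
        for bit in bits[i:i + 8]:
            byte = (byte << 1) | bit
        out.append(byte)
    return bytes(out)


def next_bit_example(bits: list[int]) -> tuple[list[int], list[int]]:
    """Shift the sequence by one: after reading inputs[:t+1], the model is
    trained to predict targets[t], i.e. the bit that comes next."""
    return bits[:-1], bits[1:]


if __name__ == "__main__":
    # The first byte of every PNG file is 0x89 (0b10001001).
    bits = bytes_to_bits(b"\x89PNG")
    inputs, targets = next_bit_example(bits)
    print(bits[:8])                   # [1, 0, 0, 0, 1, 0, 0, 1]
    print(len(inputs), len(targets))  # 31 31
```

For a real file you would read its raw bytes, e.g. `open(path, "rb").read()`. Generated bits can then be packed back with `bits_to_bytes` and decoded by an ordinary image library. Only the decoder ever sees pixels, so the model itself has to learn the file format.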
Video: llama-training-mnist.mp4
You will need Python >= 3.11 as well as a Rust compiler. Installing the dependencies with
pip install -r requirements.txt
should do it. You can optionally install Mamba from https://github.com/state-spaces/mamba to get the fast CUDA implementation. Otherwise, this repo includes a reference implementation of Mamba in pure Python/PyTorch, which is equivalent but slower.
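Because the fast Mamba kernels are optional, code has to decide at runtime which implementation to use. The sketch below shows one way to detect the optional package. It assumes the package installed from the state-spaces repository is importable as `mamba_ssm`. The helper names are hypothetical, and this is not how this repo necessarily selects its backend.

```python
import importlib.util


def fast_mamba_available() -> bool:
    """Return True if the optional CUDA-backed `mamba_ssm` package is importable."""
    return importlib.util.find_spec("mamba_ssm") is not None


def describe_mamba_backend() -> str:
    # Hypothetical helper: prefer the fast kernels when installed,
    # otherwise fall back to the slower pure python/pytorch reference.
    if fast_mamba_available():
        return "mamba_ssm (fast CUDA kernels)"
    return "pure python/pytorch reference implementation (slower)"


if __name__ == "__main__":
    print(describe_mamba_backend())
```

Checking with `find_spec` rather than a bare `import` avoids paying the import cost (and any CUDA initialization) just to test for presence.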
To download the MNIST dataset:
mkdir -p datasets/mnist/ && cd datasets/mnist
huggingface-cli download clip-benchmark/wds_mnist --repo-type dataset --local-dir ./ --local-dir-use-symlinks False
v0.0.1 - MNIST digit generation