
Neko: a Library for Exploring Neuromorphic Learning Rules

Update:

Added a batch trainer: the trainer can read inputs directly from a PyTorch DataLoader instead of loading .npy files, which saves a lot of memory.
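The memory saving comes from streaming: a DataLoader yields one batch at a time, whereas loading a .npy file typically brings the whole array into memory at once. A minimal, library-agnostic sketch of the two access patterns, assuming a dataset that can fetch a single item by index (`load_all` and `iterate_batches` are hypothetical helpers, not part of neko):

```python
from typing import Callable, Iterator, List, TypeVar

Item = TypeVar("Item")


def load_all(n_items: int, get_item: Callable[[int], Item]) -> List[Item]:
    """Eager loading: every sample is held in memory at once (like loading a full .npy file)."""
    return [get_item(i) for i in range(n_items)]


def iterate_batches(
    n_items: int, get_item: Callable[[int], Item], batch_size: int
) -> Iterator[List[Item]]:
    """Lazy loading: only the current batch is held in memory (like iterating a DataLoader)."""
    for start in range(0, n_items, batch_size):
        stop = min(start + batch_size, n_items)
        yield [get_item(i) for i in range(start, stop)]


# 250 items in batches of 100 -> batch sizes 100, 100, 50
print([len(batch) for batch in iterate_batches(250, lambda i: i, 100)])
```

The generator never materializes more than `batch_size` items, so peak memory scales with the batch size rather than with the dataset size.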

Paper

https://arxiv.org/abs/2105.00324

Installation

git clone https://github.com/byin-cwi/neko.git
cd neko
pip install -e .

Code Example

Train a recurrent spiking neural network (RSNN) with adaptive leaky integrate-and-fire (ALIF) neurons using e-prop on MNIST:

from neko.backend import pytorch_backend as backend
# from neko.datasets import MNIST
from neko.evaluator import Evaluator
from neko.layers import ALIFRNNModel
from neko.learning_rules import Eprop
from neko.trainers import Trainer
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torch

trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
# Download the MNIST dataset if it is not already present
train_set = dset.MNIST(root='./', train=True, transform=trans, download=True)
test_set = dset.MNIST(root='./', train=False, transform=trans, download=True)

batch_size = 100

train_loader = torch.utils.data.DataLoader(
                 dataset=train_set,
                 batch_size=batch_size,
                 shuffle=True)
test_loader = torch.utils.data.DataLoader(
                dataset=test_set,
                batch_size=batch_size,
                shuffle=False)

print('==>>> total training batch number: {}'.format(len(train_loader)))
print('==>>> total testing batch number: {}'.format(len(test_loader)))

model = ALIFRNNModel(128, 10, backend=backend, task_type='classification', return_sequence=False)
evaluated_model = Evaluator(model=model, loss='categorical_crossentropy', metrics=['accuracy', 'firing_rate'])
algo = Eprop(evaluated_model, mode='symmetric')
trainer = Trainer(algo)
trainer.train(train_loader, input_size=[28, 28], T=28, n_classes=10, epochs=30)
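The call above passes `input_size=[28,28]` and `T = 28`, which suggests (an assumption, not confirmed by the source) that each 28×28 image is presented as a sequence of 28 time steps, one 28-pixel row per step. A minimal sketch of that reshaping in plain Python; `image_to_sequence` is a hypothetical helper, not part of neko:

```python
from typing import List


def image_to_sequence(pixels: List[float], height: int = 28, width: int = 28) -> List[List[float]]:
    """Split a flattened row-major image into `height` time steps of `width` inputs each.

    Assumption: time step t receives image row t, so the sequence length T equals `height`.
    """
    if len(pixels) != height * width:
        raise ValueError(f"expected {height * width} pixels, got {len(pixels)}")
    return [pixels[t * width:(t + 1) * width] for t in range(height)]


# A dummy flattened 28x28 image whose pixel values are their own indices.
seq = image_to_sequence([float(i) for i in range(28 * 28)])
print(len(seq), len(seq[0]))  # 28 28  (time steps, inputs per step)
print(seq[1][0])              # 28.0   (first pixel of row 1)
```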

Example Scripts

Learning with e-prop

Training on the MNIST dataset with the same settings as above, but with more options available. For example, you can train with BPTT or with any of three variants of e-prop.

python examples/mnist.py
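For context on the BPTT/e-prop choice: BPTT computes exact gradients by propagating errors backward through time, whereas e-prop, following the formulation of Bellec et al., approximates each recurrent weight gradient as a sum over time of a learning signal times a locally computed eligibility trace, both of which are available forward in time. For a regression-style output:

$$
\frac{dE}{dW_{ji}} \approx \sum_{t} L_j^t \, e_{ji}^t,
\qquad
L_j^t = \sum_{k} B_{jk}\,\bigl(y_k^t - y_k^{*,t}\bigr)
$$

Here $e_{ji}^t$ is the eligibility trace of synapse $i \to j$, $y_k^t$ and $y_k^{*,t}$ are the network output and target, and $B_{jk}$ are feedback weights. In the original work, symmetric e-prop sets $B_{jk} = W^{\text{out}}_{kj}$, random e-prop uses fixed random $B_{jk}$, and adaptive e-prop learns $B_{jk}$ during training; that these are the three variants exposed by neko (cf. `mode='symmetric'` above) is an assumption.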

Training on the TIMIT dataset. Place the timit_processed folder, which contains the dataset preprocessed with a script from the original authors of e-prop, in the same directory as the script.

python examples/timit.py

With regularization enabled:

python timit.py --reg --eprop_mode symmetric --reg_coeff 5e-7
# Test: {'loss': 0.8918977379798889, 'accuracy': 0.7501428091397849, 'firing_rate': 12.973159790039062}

Faster training (about 7.5x speedup, 28 s per epoch on an RTX 3090) with regularization enabled:

python timit.py --reg --eprop_mode symmetric --reg_coeff 3e-8 --batch_size 256  --learning_rate 0.01
# Test: {'loss': 0.8605409860610962, 'accuracy': 0.7542506720430108, 'firing_rate': 13.105131149291992}
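The `--reg` and `--reg_coeff` flags suggest that an activity penalty, scaled by the coefficient, is added to the task loss; its exact form in neko is not documented here. As an assumption, the sketch below uses a quadratic firing-rate penalty similar in spirit to the firing-rate regularization described for e-prop. `firing_rate_penalty`, `total_loss`, and `target_hz` are hypothetical names and values:

```python
from typing import Sequence


def firing_rate_penalty(rates_hz: Sequence[float], target_hz: float, reg_coeff: float) -> float:
    """Quadratic penalty pulling each neuron's average firing rate toward a target rate.

    Assumption: this mirrors a common firing-rate regularizer; neko's --reg may differ.
    """
    return reg_coeff * sum((r - target_hz) ** 2 for r in rates_hz)


def total_loss(task_loss: float, rates_hz: Sequence[float], target_hz: float, reg_coeff: float) -> float:
    # The regularizer is added to the task loss; reg_coeff plays the role of --reg_coeff.
    return task_loss + firing_rate_penalty(rates_hz, target_hz, reg_coeff)


# Two neurons at 12 Hz and 14 Hz with a hypothetical 10 Hz target:
# penalty = 5e-7 * (2**2 + 4**2) = 1e-5
print(total_loss(0.9, [12.0, 14.0], target_hz=10.0, reg_coeff=5e-7))
```

A larger coefficient trades task accuracy for sparser activity, which is consistent with the two runs above reporting slightly different firing rates for different `--reg_coeff` values.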

Probabilistic learning with HMC

Training on the MNIST-1D dataset with Hamiltonian Monte Carlo (HMC):

python examples/mnist_1d_hmc.py
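HMC draws samples of the parameters from a posterior instead of optimizing a single point estimate. As an illustration of the generic technique only (not neko's implementation), here is a minimal HMC sampler for a one-dimensional target using leapfrog integration and a Metropolis accept/reject step; `hmc_sample` and its parameters are hypothetical:

```python
import math
import random
from typing import Callable, List


def hmc_sample(
    log_prob: Callable[[float], float],
    grad_log_prob: Callable[[float], float],
    x0: float,
    n_samples: int,
    step_size: float = 0.1,
    n_leapfrog: int = 20,
    seed: int = 0,
) -> List[float]:
    """Minimal Hamiltonian Monte Carlo for a 1-D target density (illustrative only)."""
    rng = random.Random(seed)
    x = x0
    samples = []
    for _ in range(n_samples):
        p = rng.gauss(0.0, 1.0)  # resample momentum from N(0, 1)
        x_new, p_new = x, p
        # Leapfrog integration of Hamiltonian dynamics with potential U(x) = -log_prob(x).
        p_new += 0.5 * step_size * grad_log_prob(x_new)
        for i in range(n_leapfrog):
            x_new += step_size * p_new
            if i < n_leapfrog - 1:
                p_new += step_size * grad_log_prob(x_new)
        p_new += 0.5 * step_size * grad_log_prob(x_new)
        # Metropolis correction using the Hamiltonian H = -log_prob(x) + p^2 / 2.
        h_old = -log_prob(x) + 0.5 * p * p
        h_new = -log_prob(x_new) + 0.5 * p_new * p_new
        if rng.random() < math.exp(min(0.0, h_old - h_new)):
            x = x_new
        samples.append(x)
    return samples


# Target: standard normal, log p(x) = -x^2 / 2 up to a constant.
draws = hmc_sample(lambda x: -0.5 * x * x, lambda x: -x, x0=3.0, n_samples=2000)
mean = sum(draws) / len(draws)
var = sum((d - mean) ** 2 for d in draws) / len(draws)
print(round(mean, 2), round(var, 2))  # should be close to 0 and 1
```

In a neural network setting, `x` becomes the full weight vector and `log_prob` the log posterior (log likelihood plus log prior), with gradients supplied by autodiff.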

Analogue Neural Network Training with Manhattan Rule

Training on the MNIST dataset with the simple Manhattan rule or the Manhattan material rule:

python examples/mnist_manhattan.py
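The simple Manhattan rule updates each weight by a fixed-size step opposite to the sign of its gradient, ignoring the gradient's magnitude, which suits analogue hardware that applies fixed-size conductance changes. A minimal sketch of the simple rule only (the Manhattan material rule is not sketched here); `manhattan_update` and `step` are hypothetical names:

```python
from typing import List


def sign(x: float) -> float:
    """Return 1.0, -1.0, or 0.0 depending on the sign of x."""
    return float((x > 0) - (x < 0))


def manhattan_update(weights: List[float], grads: List[float], step: float) -> List[float]:
    """Simple Manhattan rule: w <- w - step * sign(dE/dw).

    Every weight moves by exactly `step` (or not at all if its gradient is zero),
    regardless of how large or small the gradient is.
    """
    return [w - step * sign(g) for w, g in zip(weights, grads)]


# A large and a tiny gradient produce the same size of update.
print(manhattan_update([0.5, -0.2, 0.0], [3.7, -0.001, 0.0], step=0.01))
```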

Gradient Comparison Tool

Compare the gradients from BPTT with those from the three variants of e-prop:

python examples/mnist_gradcompare.py

[Figure: visualization of the results produced by the script above]
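One simple way to quantify how closely an approximate gradient (such as e-prop's) matches the exact BPTT gradient is the cosine similarity between the two, alongside a relative norm error. Whether `mnist_gradcompare.py` uses these particular metrics is an assumption; the helper names and gradient values below are hypothetical:

```python
import math
from typing import Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two flattened gradient vectors (1.0 = same direction)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def relative_error(approx: Sequence[float], reference: Sequence[float]) -> float:
    """||approx - reference|| / ||reference||, e.g. an e-prop gradient vs. the BPTT gradient."""
    diff = math.sqrt(sum((x - y) ** 2 for x, y in zip(approx, reference)))
    return diff / math.sqrt(sum(y * y for y in reference))


bptt_grad = [0.2, -0.5, 0.1]   # hypothetical values for illustration
eprop_grad = [0.4, -1.0, 0.2]  # same direction, twice the magnitude
print(cosine_similarity(bptt_grad, eprop_grad))  # approximately 1.0
print(relative_error(eprop_grad, bptt_grad))     # approximately 1.0
```

The two metrics are complementary: cosine similarity tells you whether an update points in a useful direction, while the relative error also penalizes differences in scale.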

Contributors

byin-cwi, water-vapor, levinas