thomas-bouvier / distributed-continual-learning

Towards Rehearsal-based Continual Learning at Scale: distributed CL with Horovod + PyTorch

License: MIT License

Languages: Python 99.06%, Shell 0.94%
Topics: continual-learning, data-parallelism, deep-learning, experience-replay, hpc, ptychography, rehearsal

distributed-continual-learning

This is a PyTorch + Horovod implementation of the continual learning experiments with deep neural networks described in the following article:

Continual learning approaches implemented here are based on rehearsal, which is delegated to a separate high-performance C++ backend, Neomem.

This repository primarily supports experiments in the academic continual learning setting, in which a classification problem is split into multiple non-overlapping tasks that must be learned sequentially (class-incremental scenario). Instance-incremental scenarios are supported too.

Some Python code has been inspired by the mammoth and convNet.pytorch repositories.

Installation

The current version of the code has been tested with Python 3.10 and the following package versions:

  • pytorch 2.2
  • timm 0.9.2
  • horovod 0.28.1
  • continuum 1.2.7
  • nvidia-dali-cuda110 1.27.0 (optional)

Make sure to install Neomem to benefit from global sampling of representatives. Symlink it using ln -s ../neomem cpp_loader. If Neomem is not available, this code will fall back to a Python-based, local-only, low-performance rehearsal buffer implementation.

Further Python packages used are listed in requirements.txt. Assuming Python and pip are set up, these packages can be installed using:

pip install -r requirements.txt

In an HPC environment, we strongly advise using Spack to manage dependencies.
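As a minimal sketch, a Spack environment for this project could look like the following; the package names are illustrative and may differ across Spack releases:

# Illustrative Spack environment; adapt package specs to your site
spack env create cl-env
spack env activate cl-env
spack add python@3.10 py-torch py-horovod
spack install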

Usage

Parameters defined in config.yaml override CLI parameters. However, values for backbone_config, buffer_config and tasksets_config will be concatenated with those defined on the CLI, instead of overriding them.

Values for optimizer_regime will override the regimes defined by backbone/ in Python, as illustrated below.
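As an illustration only, a config.yaml could look like the following; the keys mirror the CLI parameters above, and the exact optimizer_regime format shown here is an assumption that may differ in this codebase:

# Hypothetical config.yaml sketch; values are examples, not recommended settings
backbone: resnet18
backbone_config: {'lr': 0.01, 'lr_min': 1e-6}
buffer_config: {'rehearsal_ratio': 20}
tasksets_config: {'scenario': 'class', 'initial_increment': 5, 'increment': 5}
optimizer_regime: [{'epoch': 0, 'optimizer': 'SGD', 'lr': 0.1, 'momentum': 0.9}]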

| Parameter name | Required | Description | Possible values |
| --- | --- | --- | --- |
| --backbone | Yes | DL backbone model to instantiate | mnistnet, resnet18, resnet50, mobilenetv3, efficientnetv2, convnext, ghostnet, ptychonn |
| --backbone-config | | Backbone-specific parameters | "{'lr': 0.01, 'lr_min': 1e-6}" |
| --model | Default: Vanilla | Continual learning strategy | Vanilla, Er, Agem, Der, Derpp |
| --model-config | | Reset strategy and CL model-specific parameters | "{'reset_state_dict': True}" resets the model's internal state between tasks; "{'alpha': 0.2}" is needed for the Der model; "{'alpha': 0.2, 'beta': 0.8}" are needed for the Derpp model |
| --buffer-config | | Rehearsal buffer parameters | "{'rehearsal_ratio': 20}" sets the proportion of the input dataset to be stored in the rehearsal buffer |
| --tasksets-config | | Scenario configuration, as defined in the continuum package | Class-incremental scenario with 2 tasks: "{'scenario': 'class', 'initial_increment': 5, 'increment': 5}"; instance-incremental scenario with 5 tasks: "{'scenario': 'instance', 'num_tasks': 5}"; "{'concatenate_tasksets': True}" concatenates previous tasksets before the next task |
| --dataset | | Dataset | mnist, cifar10, cifar100, tinyimagenet, imagenet, imagenet_blurred, ptycho |
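For example, these parameters can be combined in a single command as follows; the hyperparameter values and class split are illustrative only:

python main.py --backbone resnet18 --backbone-config "{'lr': 0.01}" --model Derpp --model-config "{'alpha': 0.2, 'beta': 0.8}" --buffer-config "{'rehearsal_ratio': 20}" --dataset cifar100 --tasksets-config "{'scenario': 'class', 'initial_increment': 50, 'increment': 10}"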

WandB sweeps

To run a hyperparameter search, first adapt the sweep.py file (located in this directory) if needed, then configure your optimization objective in sweep.yaml.

Make sure you have exported your WandB API key (export WANDB_API_KEY=<key>) and set WANDB_MODE=run before running anything. Once you are ready, execute the sweep_launcher.sh <hostname> <wandb_project> <sweep_conf> [<existing_sweep_id>] script on the master machine, providing the following parameters (see the example after the list):

  • hostname: the address of the current machine e.g., chifflot-7.lille.grid5000.fr:1
  • wandb_project: the name of an existing W&B project where the run will be saved
  • sweep_conf: the name of a sweep config defined in sweep.py
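A launch could look like the following; the hostname is taken from the example above, while the project and sweep config names are placeholders and the exact invocation path may differ:

export WANDB_API_KEY=<key>
export WANDB_MODE=run
./sweep_launcher.sh chifflot-7.lille.grid5000.fr:1 <wandb_project> <sweep_conf>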

To stop a sweep run, go to the online WandB dashboard and click "Stop run". To stop the whole sweep process, run ps aux | grep agent on the machine and kill that process, then run ps aux | grep wandb and kill that process too.

Continual Learning Strategies

A specific implementation has to be selected using --buffer-config "{'implementation': <implementation>}". ER with the standard implementation was used in the paper (see the example after the table below).

| Approach | Name | Available implementations |
| --- | --- | --- |
| Experience Replay (ER) | Er | standard, flyweight, python |
| Averaged GEM (A-GEM) | Agem | python |
| Dark Experience Replay (DER) | Der | standard, flyweight, python |
| Dark Experience Replay++ (DER++) | Derpp | standard, flyweight, python |
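For instance, ER with the standard implementation could be run as follows; the backbone, dataset and class split shown here are illustrative, not the exact setting used in the paper:

python main.py --backbone resnet18 --model Er --buffer-config "{'implementation': 'standard', 'rehearsal_ratio': 20}" --dataset cifar100 --tasksets-config "{'scenario': 'class', 'initial_increment': 50, 'increment': 10}"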

Baselines

From Scratch

python main.py --backbone <backbone_model> --dataset <dataset> --model Vanilla --model-config "{'reset_state_dict': True}" --tasksets-config "{<tasksets-config>, 'concatenate_tasksets': True}"

Incremental

python main.py --backbone <backbone_model> --dataset <dataset> --model Vanilla --tasksets-config "{<tasksets-config>}"
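As a concrete instance of these two templates (the backbone, dataset and class split are illustrative only):

python main.py --backbone resnet18 --dataset cifar100 --model Vanilla --model-config "{'reset_state_dict': True}" --tasksets-config "{'scenario': 'class', 'initial_increment': 10, 'increment': 10, 'concatenate_tasksets': True}"
python main.py --backbone resnet18 --dataset cifar100 --model Vanilla --tasksets-config "{'scenario': 'class', 'initial_increment': 10, 'increment': 10}"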

Examples

Deep learning

Regular deep learning can be done using this project. The Vanilla model will be instantiated by default:

python main.py --backbone mnistnet --dataset mnist
python main.py --backbone resnet18 --dataset cifar100
python main.py --backbone resnet50 --dataset tinyimagenet
python main.py --backbone efficientnetv2 --dataset imagenet_blurred

Continual learning

python main.py --backbone mnistnet --dataset mnist --tasksets-config "{'scenario': 'class', 'initial_increment': 5, 'increment': 5}"
python main.py --backbone resnet18 --dataset cifar10 --tasksets-config "{'scenario': 'class', 'initial_increment': 4, 'increment': 3}"
python main.py --backbone resnet18 --model Er --dataset cifar100 --tasksets-config "{'scenario': 'instance', 'num_tasks': 5}"
python main.py --backbone resnet18 --model Der --buffer-config "{'rehearsal_ratio': 20}" --dataset cifar10 --tasksets-config "{'scenario': 'class', 'initial_increment': 4, 'increment': 3}"
python main.py --backbone resnet18 --model Derpp --buffer-config "{'rehearsal_ratio': 20}" --dataset imagenet100small --tasksets-config "{'scenario': 'class', 'initial_increment': 40, 'increment': 30}"
python main.py --backbone resnet50 --model Agem --dataset tinyimagenet --tasksets-config "{'scenario': 'instance', 'num_tasks': 5}"
