jax-rl-template

A minimal JAX-based reinforcement learning template, for rapidly spinning up RL projects!

All training and evaluation is JIT-compiled end-to-end in JAX. The template targets Python 3.8.12 and is built on top of:

  • JAX - Autograd and XLA
  • Flax - Neural network library
  • Optax - Gradient-based optimisation
  • Distrax - Probability distributions
  • Weights & Biases - Experiment tracking and visualisation
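
To illustrate what "JIT-compiled end-to-end" can look like, here is a minimal sketch and not the template's actual training loop. It assumes a toy linear-regression loss in place of an RL objective and plain SGD in place of an Optax optimiser; the names loss_fn, train_step and train are hypothetical. The point is only that wrapping a jax.lax.scan loop in jax.jit compiles every update step into a single XLA program.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    # Mean squared error of a linear model; stands in for an RL loss.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


def train_step(params, batch):
    # One plain SGD update (assumption: a real agent would use Optax here).
    x, y = batch
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
    return params, loss


@jax.jit
def train(params, batches):
    # lax.scan iterates over the leading axis of `batches`, so the whole
    # loop is traced once and compiled into a single XLA program.
    return jax.lax.scan(train_step, params, batches)


key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (50, 32, 3))  # 50 steps, batch size 32, 3 features
true_w = jnp.array([1.0, -2.0, 0.5])
y = x @ true_w + 0.3
params = {"w": jnp.zeros(3), "b": jnp.zeros(())}

final_params, losses = train(params, (x, y))
print(losses[0], losses[-1])  # loss at the first and last step
```

Because the loop lives inside the compiled function, there is no Python overhead between steps; the per-step losses come back as one array once training finishes.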

Features

Variants of this template are released as branches of this repository, each with different features:

| Branch | Description | Agents | Environments |
|---|---|---|---|
| main (here) | Basic training and evaluation functionality (e.g. training loop, logging, checkpointing), plus common online RL agents | PPO, SAC, DQN | Gymnax |
| offline (TBC) | Adds offline RL functionality (e.g. replay buffer, offline training) | CQL, EDAC | - |

This template is designed to provide only core functionality, as a solid foundation for RL projects. Whilst it is not intended to be a full-featured RL library, please raise an issue if you think a feature is missing that would be useful for many projects.

Setup

Running locally (CPU)

  1. Install the Python packages listed in requirements-base.txt and requirements-cpu.txt (both in the setup folder) with:
cd setup && pip install $(cat requirements-base.txt requirements-cpu.txt)
  2. Sign in to WandB to enable logging:
wandb login
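
After installing, a quick sanity check can confirm that JAX imports and can compile a function. This is an optional sketch and not part of the template; it only uses standard JAX calls (jax.default_backend, jax.devices, jax.jit).

```python
import jax
import jax.numpy as jnp

print(jax.__version__)
print(jax.default_backend())  # expected to report "cpu" on a CPU-only install
print(jax.devices())

# Compile and run a trivial function to confirm JIT works.
result = jax.jit(lambda x: x * 2.0)(jnp.arange(3.0))
print(result)
```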

Running via Docker

  1. Build the Docker container with the provided script:
cd setup/docker && ./build.sh
  2. Add your WandB key to the setup/docker folder (run from the repository root):
echo <wandb_key> > setup/docker/wandb_key

Automatic code formatting

After installing the Python packages, install the Black pre-commit hook with:

pre-commit install

This will check and fix formatting errors when you commit code.

Usage

Training locally

To train an agent, run:

python train.py <arguments>

For example, to train a PPO agent on the CartPole-v1 environment and log to WandB, run:

python train.py --agent ppo --env_name CartPole-v1 --log --wandb_entity wandb_username --wandb_project project_name

For the full list of arguments, see experiments/parse_args.py or run:

python train.py --help
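
For orientation, the flags used in the example above could be declared with argparse roughly as follows. This is a hypothetical sketch and not the contents of experiments/parse_args.py: it covers only the five flags shown in this README, and the defaults and help strings are assumptions.

```python
import argparse


def parse_args(argv=None):
    # Hypothetical parser covering only the flags shown in the README example.
    parser = argparse.ArgumentParser(description="Train an RL agent (illustrative sketch).")
    parser.add_argument("--agent", type=str, default="ppo", help="Agent name, e.g. ppo or dqn")
    parser.add_argument("--env_name", type=str, default="CartPole-v1", help="Environment name")
    parser.add_argument("--log", action="store_true", help="Enable WandB logging")
    parser.add_argument("--wandb_entity", type=str, default=None, help="WandB username or team")
    parser.add_argument("--wandb_project", type=str, default=None, help="WandB project name")
    return parser.parse_args(argv)


args = parse_args([
    "--agent", "ppo",
    "--env_name", "CartPole-v1",
    "--log",
    "--wandb_entity", "wandb_username",
    "--wandb_project", "project_name",
])
print(args.agent, args.env_name, args.log)
```

Declaring --log as a store_true flag matches how it is used in the command above, where it is passed without a value.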

Training via Docker

Launch training runs inside your built container with:

./run_docker.sh <gpu_id> python3 train.py <arguments>

For example, to train a DQN agent on the Asterix-MinAtar environment using GPU 3, run:

./run_docker.sh 3 python3 train.py --agent dqn --env_name Asterix-MinAtar
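
When scripting several runs, the same command can be assembled in Python. The helper below is a hypothetical convenience and not part of the template; it only rebuilds the ./run_docker.sh <gpu_id> python3 train.py <arguments> pattern shown above and prints it rather than executing it.

```python
import shlex


def docker_train_command(gpu_id, agent, env_name, extra_args=()):
    # Mirrors "./run_docker.sh <gpu_id> python3 train.py <arguments>".
    return [
        "./run_docker.sh", str(gpu_id), "python3", "train.py",
        "--agent", agent, "--env_name", env_name, *extra_args,
    ]


cmd = docker_train_command(3, "dqn", "Asterix-MinAtar")
print(shlex.join(cmd))
# prints: ./run_docker.sh 3 python3 train.py --agent dqn --env_name Asterix-MinAtar
```

To actually launch the run, the list can be passed to subprocess.run(cmd, check=True) from the directory containing run_docker.sh.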

Acknowledgements

Large parts of the training loop and PPO implementation are based on PureJaxRL, which contains high-performance, single-file implementations of RL agents in JAX.

Contributors

emptyjackson
