mmcaulif / cardio

Cardio reduces boilerplate code by providing simple, lightweight environment interaction loops to make implementing deep reinforcement learning algorithms easy, intuitive, and readable.

License: Apache License 2.0

Python 98.89% Makefile 0.45% JavaScript 0.65%


🏃 Cardio: Runners for Deep Reinforcement Learning in Gym Environments 🏃


Motivation | Installation | Usage | Under the hood | Development | Contributing

So many reinforcement learning libraries, what makes Cardio different?

  • Easy and readable: Focus on the agent and leave the boilerplate code to Cardio
  • Extensible: Easy progression from simple algorithms all the way up to Rainbow and beyond
  • Research friendly: Cardio was designed to be a whiteboard for your RL research

Cardio aims to make new algorithm implementations easy to do, readable and framework agnostic by providing a collection of modular environment interaction loops for the research and implementation of deep reinforcement learning (RL) algorithms in Gymnasium environments. Out of the box these loops are capable of more complex experience collection approaches such as n-step transitions, trajectories, and storing of auxiliary values to a replay buffer. Accompanying these core components are helpful utilities (such as replay buffers and data transformations), and single-file reference implementations for state-of-the-art algorithms.

Motivation

In the spectrum of RL libraries, Cardio lies between large, complete packages such as stable-baselines3, which deliver full algorithm implementations but lack modularity and extensibility, and research-friendly repositories like CleanRL, which repeat boilerplate code across implementations. Its design paradigm is similar to Google's Dopamine and DeepMind's Acme.

To achieve the desired structure and API, Cardio makes some concessions, the first of which is speed. There's no competing with end-to-end jitted implementations, but going in that direction greatly hinders modularity and the application of implementations to arbitrary environments. If you are interested in lightning-quick training of agents on established baselines, please look towards the likes of Stoix.

Secondly, taking a modular approach leaves Cardio less immediately extensible than the likes of CleanRL. Despite the features in place to make the environment loops transparent, there will inevitably be edge cases where Cardio is not the best choice.

Installation

NOTE: Jax is a major internal requirement of the runners. The installation process will be updated soon to better distinguish between setting up Cardio with Jax for GPUs, CPUs or TPUs; for now, please manually install whichever Jax version best suits your environment. The commands below use the CPU version, but swapping "cpu" for "gpu" should work the same way.

Prerequisites (to be expanded):

  • Python == 3.10

Via pip, with the CPU version of Jax:

pip install cardio-rl[cpu]

To install from source:

git clone https://github.com/mmcaulif/Cardio.git
cd Cardio
pip install .[cpu]

Alternatively, you can do an editable install with all requirements, e.g. for testing, experimenting and development:

pip install -e .[dev,exp,cpu]

Or use the provided Makefile (which also sets up the pre-commit hooks):

make install_cpu

Usage

Below is a simple example (using the CartPole environment) that leverages Cardio's off-policy runner to write a simple implementation of a core deep RL algorithm, Deep Q-Networks (DQN). It is assumed that you have a beginner's understanding of deep RL; this section just serves to demonstrate how Cardio might fit into different algorithm implementations.

DQN

In this algorithm our agent performs a fixed number of environment steps (aka a rollout) and saves the experienced transitions in a replay buffer. Once the rollout is done, we sample from the replay buffer and pass the sampled transitions to the agent's update method. To implement our agent we will use the provided Cardio Agent class and override the __init__, update and step methods:

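import copy

import gymnasium as gym
import jax
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F

import cardio_rl as crl  # assumed import name of the cardio-rl package


class DQN(crl.Agent):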
    def __init__(
        self,
        env: gym.Env,
        critic: nn.Module,
        gamma: float = 0.99,
        targ_freq: int = 1_000,
        optim_kwargs: dict = {"lr": 1e-4},
        init_eps: float = 0.9,
        min_eps: float = 0.05,
        schedule_len: int = 5000,
        use_rmsprop: bool = False,
    ):
        self.env = env
        self.critic = critic
        self.targ_critic = copy.deepcopy(critic)
        self.gamma = gamma
        self.targ_freq = targ_freq
        self.update_count = 0

        if not use_rmsprop:
            self.optimizer = th.optim.Adam(self.critic.parameters(), **optim_kwargs)
        else:
            # TODO: fix mypy crying about return type
            self.optimizer = th.optim.RMSprop(self.critic.parameters(), **optim_kwargs)

        self.eps = init_eps
        self.min_eps = min_eps
        self.ann_coeff = self.min_eps ** (1 / schedule_len)

    def update(self, batches):
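        # Convert the sampled batch (a dict of numpy arrays) into torch tensors
        data = jax.tree.map(th.from_numpy, batches)
        s, a, r, s_p, d = data["s"], data["a"], data["r"], data["s_p"], data["d"]

        # Q-values of the taken actions and the bootstrapped TD target
        q = self.critic(s).gather(-1, a)
        q_p = self.targ_critic(s_p).max(dim=-1, keepdim=True).values
        y = r + self.gamma * q_p * ~d  # ~d (boolean done) removes the bootstrap at episode ends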

        loss = F.mse_loss(q, y.detach())
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.update_count += 1
        if self.update_count % self.targ_freq == 0:
            self.targ_critic.load_state_dict(self.critic.state_dict())

        return {}

    def step(self, state):
        if np.random.rand() > self.eps:
            th_state = th.from_numpy(state)
            action = self.critic(th_state).argmax().numpy(force=True)
        else:
            action = self.env.action_space.sample()

        self.eps = max(self.min_eps, self.eps * self.ann_coeff)
        return action, {}
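The epsilon schedule above is easy to misread, so here is a small stand-alone sketch that replays its arithmetic. `eps_after` is a hypothetical helper (not part of Cardio) that applies the same rule as DQN.step: multiply by ann_coeff = min_eps ** (1 / schedule_len) on every call, then clamp at min_eps.

```python
def eps_after(steps: int, init_eps: float = 0.9, min_eps: float = 0.05,
              schedule_len: int = 5000) -> float:
    """Epsilon after `steps` calls to DQN.step() (up to floating-point rounding).

    Each call multiplies eps by ann_coeff and clamps at min_eps. Because the
    product only shrinks, clamping once at the end gives the same result as
    clamping at every step.
    """
    ann_coeff = min_eps ** (1 / schedule_len)
    return max(min_eps, init_eps * ann_coeff**steps)


for t in (0, 1_000, 2_500, 5_000):
    print(f"after {t:>5} steps: eps = {eps_after(t):.3f}")
```

Since ann_coeff ** schedule_len equals min_eps, the schedule shrinks epsilon by a factor of min_eps over schedule_len steps rather than interpolating down to it. With the defaults, epsilon would reach 0.9 × 0.05 = 0.045 at step 5,000, so in practice it hits the 0.05 floor after roughly 4,800 steps.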

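Next we instantiate our runner, passing it our environment, our agent, the rollout length and the batch size (there are also other arguments you may want to tweak). Note that Q_critic(4, 2) is not defined in this example; it is assumed to be your own PyTorch Q-network mapping CartPole's 4-dimensional observation to its 2 action values.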

env = gym.make("CartPole-v1")
runner = crl.OffPolicyRunner(
    env=env,
    agent=DQN(env, Q_critic(4, 2)),
    rollout_len=4,
    batch_size=32,
)

And finally, to run 50,000 rollouts (in this case, 50,000 x 4 environment steps) and perform an agent update after each one, we just use the run method:

runner.run(rollouts=50_000, eval_freq=1_250)

Sprinter

Sprinter acts as an extension of Cardio, applying the library to create a zoo of different algorithm implementations and providing simple boilerplate code examples for research-focused tasks such as hyperparameter optimisation or benchmarking. It is intended to be cloned/forked and built on, as opposed to pip installed.

Sprinter is further behind in its development and currently just acts as a collection of modern algorithm implementations using Jax/Rlax/Flax/Optax.

If you are looking for more practical examples of Cardio in use (or just an assortment of Jax algorithm implementations), Cardio's sibling repo, Sprinter, should be all you need.

Under the hood

Below we'll go over the inner workings of Cardio. The intention was to make Cardio quite minimal and easy to parse, akin to Dopamine, and I hope it is of interest to practitioners; I'm eager to hear any feedback or opinions on the design paradigm. This section also serves to highlight a couple of the nuances of Cardio's components.

Diagram pending creation

Transition

Borrowing an idea from TorchRL, the core building block that Cardio centers around is a dictionary that represents an MDP transition. By default the transition dict has the following keys: s, a, r, s_p, d corresponding to state, action, reward, state' (state prime or next state) and done. Two important concepts to be aware of are:

  1. A Cardio Transition dictionary does not necessarily correspond to a single environment step. For example, for n-step transitions s corresponds to s_t but s_p corresponds to s_(t+n), with the reward key having n entries. Furthermore, the replay buffer stores data as a transition dictionary with keys pointing to multiple states, actions, rewards etc.
  2. The done value used in Cardio is the OR of the terminated and truncated values used in Gymnasium. Empirically, decoupling termination and truncation has been shown to have a negligible effect. However, this is a trivial feature to change and it's possible that leaving it up to the user is best.
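To make the two points above concrete, here is a minimal sketch of what such transition dictionaries could look like. The helper `make_transition` is hypothetical (not part of Cardio's API), and plain Python lists stand in for the arrays Cardio would actually store.

```python
def make_transition(s, a, rewards, s_p, terminated, truncated):
    """Hypothetical helper: build a transition dict with the default Cardio keys.

    `d` is the OR of Gymnasium's `terminated` and `truncated` flags, and `r`
    holds one reward per environment step folded into the transition.
    """
    return {
        "s": s,
        "a": a,
        "r": list(rewards),
        "s_p": s_p,
        "d": bool(terminated or truncated),
    }


# Single-step transition (n = 1): s_p is s_(t+1) and r has one entry
one_step = make_transition([0.0, 0.1], 1, [1.0], [0.0, 0.2],
                           terminated=False, truncated=False)

# 3-step transition: s is s_t, s_p is s_(t+3) and r has three entries;
# a truncated (not terminated) episode still yields d = True
three_step = make_transition([0.0, 0.1], 1, [1.0, 1.0, 1.0], [0.3, 0.4],
                             terminated=False, truncated=True)
```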

By using dictionaries, new entries are easy to add, so storing user-defined variables (such as intrinsic rewards or policy probabilities) is built into the framework, whereas this would be nontrivial to implement in more abstract libraries like stable-baselines3.
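Because a transition is just a dictionary, storing an extra value only means adding a key. Below is a hypothetical column-wise replay buffer (`DictReplayBuffer` is an illustrative name, not Cardio's replay buffer class) that keeps whatever keys the first transition has, including an assumed auxiliary `log_prob` entry.

```python
import random
from collections import deque


class DictReplayBuffer:
    """Hypothetical FIFO replay buffer storing transitions column-wise (one deque per key)."""

    def __init__(self, capacity: int):
        self.capacity = capacity
        self.data = {}

    def __len__(self) -> int:
        return len(next(iter(self.data.values()))) if self.data else 0

    def store(self, transition: dict) -> None:
        if not self.data:
            # The first transition fixes the key set, including any user-defined extras
            self.data = {k: deque(maxlen=self.capacity) for k in transition}
        for k in self.data:
            self.data[k].append(transition[k])  # a missing key raises KeyError

    def sample(self, batch_size: int) -> dict:
        """Return a batch in the same dict layout, with a list of values per key."""
        idxs = random.sample(range(len(self)), batch_size)
        return {k: [col[i] for i in idxs] for k, col in self.data.items()}


buffer = DictReplayBuffer(capacity=100)
for t in range(150):
    # "log_prob" is an example of an auxiliary value an agent might want to keep
    buffer.store({"s": t, "a": 0, "r": 1.0, "s_p": t + 1, "d": False, "log_prob": -0.69})

batch = buffer.sample(batch_size=8)
```

A sampled batch keeps the same dictionary layout as a single transition, so the extra key reaches the agent's update method without any changes to the buffer; a real implementation would stack each column into an array rather than a list.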

Agent

Much like Acme, the Cardio agent class is very minimal, simply defining some base methods that are used by the environment interaction loops. The most important thing to know is when these methods are called, what data they are provided, and which component calls them. The key methods are step (given a state, return an action and any extras), view (given a single-step transition, return any extras) and update (given a batch of transitions).
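The contract above can be sketched as a tiny base class. `SketchAgent` and `RandomAgent` are hypothetical names and their signatures are assumptions modelled on the DQN example (step returns an (action, extras) pair, update returns a dict); this is not Cardio's actual crl.Agent.

```python
import random
from typing import Any


class SketchAgent:
    """Hypothetical agent base class mirroring the three methods described above."""

    def step(self, state: Any) -> tuple[Any, dict]:
        """Called while stepping the env: given a state, return (action, extras)."""
        raise NotImplementedError

    def view(self, transition: dict) -> dict:
        """Called with a single-step transition; return any extras to store with it."""
        return {}

    def update(self, batch: dict) -> dict:
        """Called with a batch of transitions; return any extras."""
        return {}


class RandomAgent(SketchAgent):
    """Toy subclass: a uniform random policy over `n_actions` discrete actions."""

    def __init__(self, n_actions: int, seed: int = 0):
        self.n_actions = n_actions
        self.rng = random.Random(seed)
        self.num_updates = 0

    def step(self, state):
        return self.rng.randrange(self.n_actions), {}

    def update(self, batch):
        self.num_updates += 1
        return {}


agent = RandomAgent(n_actions=2)
action, extras = agent.step([0.0, 0.0, 0.0, 0.0])
```

Only step has to be overridden here; view and update default to returning empty extras, which matches how the DQN example returns {} from update.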

Gatherer

The gatherer is the primary component in Cardio and steps through the environment directly with a provided agent or a random policy. It has two buffers that package the transitions for the Runner in the desired manner. The step buffer collects transitions obtained from single environment steps and has a capacity equal to n. When the step buffer is full, it transforms its elements into one n-step transition and adds that transition to the transition buffer.

The step buffer is emptied after terminal states to prevent transitions overlapping across episodes. When n > 1, the step buffer needs to be "flushed", i.e. transitions must be created from steps that would otherwise be thrown away. Please refer to the example below, provided by my esteemed colleague, ChatGPT:

If you are collecting 3-step transitions, here's how you handle the transitions where s_3 is a terminal state:

  1. Transition from s_0: (s_0, a_0, [r_0, r_1, r_2], s_3)
  2. Transition from s_1: (s_1, a_1, [r_1, r_2], s_3)
  3. Transition from s_2: (s_2, a_2, [r_2], s_3)
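The README does not spell out how an agent should consume these variable-length reward lists. One common choice, sketched here as an assumption rather than Cardio's behaviour, is the standard discounted n-step TD target; `n_step_target` is a hypothetical helper.

```python
def n_step_target(rewards, gamma, done, bootstrap_value):
    """Discounted n-step TD target for one (possibly flushed) transition.

    rewards: the transition's reward list, e.g. [r_0, r_1, r_2]
    done: the done flag for s_p; if True there is nothing to bootstrap from
    bootstrap_value: an estimate of the value of s_p, e.g. max_a Q(s_p, a)
    """
    target = sum(gamma**k * r for k, r in enumerate(rewards))
    if not done:
        # Use len(rewards) rather than a fixed n so any list length is handled
        target += gamma ** len(rewards) * bootstrap_value
    return target


# The three flushed transitions from the example all end at the terminal s_3,
# so their targets are plain discounted reward sums
for rewards in ([1.0, 1.0, 1.0], [1.0, 1.0], [1.0]):
    print(rewards, n_step_target(rewards, gamma=0.99, done=True, bootstrap_value=0.0))
```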

The transition buffer is even simpler, just containing the processed transitions from the step buffer. The transition buffer starts empty each time the gatherer's step method is called and maintains its data across terminal steps. Both of these characteristics are the opposite of the step buffer, which persists across gatherer.step calls but not across terminal steps.

Due to the nature of n-step transitions, the gatherer's transition buffer will sometimes hold fewer transitions than environment steps taken (while the step buffer fills up) and other times more (when the step buffer gets flushed), but overall there is a rough one-to-one mapping between environment steps taken and transitions collected. Lastly, rollout lengths can be less than n.
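The gatherer behaviour described in this section can be sketched in a few dozen lines. Everything here is hypothetical and not Cardio's code: `ToyEnv` is a stand-in environment that mimics Gymnasium's reset/step return signatures, and `SketchGatherer` keeps a persistent step buffer (a deque of length n) plus a transition buffer that is rebuilt on every call, flushing the step buffer on terminal steps.

```python
from collections import deque


class ToyEnv:
    """Hypothetical env with Gymnasium-style returns; terminates after `length` steps."""

    def __init__(self, length: int = 5):
        self.length = length
        self.t = 0

    def reset(self):
        self.t = 0
        return self.t, {}  # (observation, info)

    def step(self, action):
        self.t += 1
        terminated = self.t >= self.length
        return self.t, 1.0, terminated, False, {}  # obs, reward, terminated, truncated, info


class SketchGatherer:
    """Hypothetical gatherer: persistent n-step step buffer plus a per-call transition buffer."""

    def __init__(self, env, n: int = 1):
        self.env = env
        self.n = n
        self.step_buffer = deque(maxlen=n)  # persists across step() calls
        self.state, _ = env.reset()

    @staticmethod
    def _combine(steps) -> dict:
        """Fold consecutive single-step transitions into one multi-step transition."""
        steps = list(steps)
        return {"s": steps[0]["s"], "a": steps[0]["a"], "r": [x["r"] for x in steps],
                "s_p": steps[-1]["s_p"], "d": steps[-1]["d"]}

    def step(self, policy, rollout_len: int) -> list:
        transitions = []  # the transition buffer starts empty on every call
        for _ in range(rollout_len):
            action = policy(self.state)
            s_p, r, terminated, truncated, _ = self.env.step(action)
            done = terminated or truncated
            self.step_buffer.append({"s": self.state, "a": action, "r": r, "s_p": s_p, "d": done})
            if done:
                # Flush: one transition per remaining start state, then empty the buffer
                while self.step_buffer:
                    transitions.append(self._combine(self.step_buffer))
                    self.step_buffer.popleft()
                self.state, _ = self.env.reset()
            else:
                if len(self.step_buffer) == self.n:
                    transitions.append(self._combine(self.step_buffer))
                self.state = s_p
        return transitions


gatherer = SketchGatherer(ToyEnv(length=5), n=3)
first = gatherer.step(lambda s: 0, rollout_len=2)   # step buffer still filling
second = gatherer.step(lambda s: 0, rollout_len=4)  # includes a terminal flush
```

Over six environment steps the two calls return 0 and then 5 transitions: the first call only fills the step buffer, while the second includes a flush that emits reward lists of length 3, 2 and 1, and one step is left buffered for the next call.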

Runner

The runner is the high-level orchestrator that manages the different components and data; it contains a gatherer, your agent and any replay buffer you might have. The runner's step function calls the gatherer's step function, as do its built-in warmup (for collecting a large amount of initial data with your agent) and burnin (for randomly stepping through an environment without collecting data, e.g. for initialising normalisation values) methods. The runner can either be used via its run method (which iteratively calls runner.step and agent.update) or just via its step method if you'd like more fine-grained control.
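A minimal sketch of that orchestration, under stated assumptions: `SketchRunner`, `CountingGatherer` and `CountingAgent` are hypothetical stand-ins rather than Cardio classes, the replay buffer is a plain list sampled uniformly, and burnin is omitted. It shows warmup data collection followed by the run loop alternating runner.step and agent.update.

```python
import random


class SketchRunner:
    """Hypothetical runner owning a gatherer, an agent and a list-based replay buffer."""

    def __init__(self, gatherer, agent, rollout_len: int, batch_size: int, warmup_len: int = 0):
        self.gatherer = gatherer
        self.agent = agent
        self.rollout_len = rollout_len
        self.batch_size = batch_size
        self.buffer = []
        if warmup_len > 0:
            # Warmup: collect initial data with the agent before any updates
            self.buffer.extend(self.gatherer.step(self.agent, warmup_len))

    def step(self) -> list:
        """Collect one rollout via the gatherer and store it in the replay buffer."""
        transitions = self.gatherer.step(self.agent, self.rollout_len)
        self.buffer.extend(transitions)
        return transitions

    def run(self, rollouts: int) -> None:
        """Alternate runner.step() with agent.update() on a uniformly sampled batch."""
        for _ in range(rollouts):
            self.step()
            if len(self.buffer) >= self.batch_size:
                batch = random.sample(self.buffer, self.batch_size)
                self.agent.update(batch)


class CountingGatherer:
    """Stub gatherer producing one dummy single-step transition per environment step."""

    def __init__(self):
        self.t = 0

    def step(self, agent, n_steps: int) -> list:
        out = []
        for _ in range(n_steps):
            action, _ = agent.step(self.t)
            out.append({"s": self.t, "a": action, "r": 0.0, "s_p": self.t + 1, "d": False})
            self.t += 1
        return out


class CountingAgent:
    """Stub agent that always picks action 0 and counts its updates."""

    def __init__(self):
        self.updates = 0

    def step(self, state):
        return 0, {}

    def update(self, batch):
        self.updates += 1
        return {}


runner = SketchRunner(CountingGatherer(), CountingAgent(),
                      rollout_len=4, batch_size=32, warmup_len=100)
runner.run(rollouts=10)
```

With a 100-step warmup, every one of the 10 rollouts triggers an update; without warmup, updates only begin once the buffer holds a full batch (here after the 8th rollout).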

Development

The main development goal for Cardio will be to make it as fast, easy to use, and extensible as possible. The aim is not to include many RL features or to cater to every domain. Far down the line I could imagine trying to incorporate async runners but that can get messy quickly. However, if you notice any bugs, or have any suggestions or feature requests, user input is greatly appreciated!

Some tentative tasks right now are:

  • Integrated loggers (WandB, Neptune, Tensorboard etc.)
  • Verify GymnasiumAtariWrapper works as intended and remove SB3 wrapper (removing SB3 as a requirement too).
  • Implement seeding for reproducibility.
  • Widespread and rigorous testing!
  • Properly document Prioritised Buffer Implementation details
  • Explore alternatives to jax.tree.map

A wider goal is to perform profiling and squash any immediate performance bottlenecks. Wrapping an environment in a Cardio runner should introduce as little overhead as possible.

Any new RL components (such as algorithm implementations) are likely to be better suited to Cardio's sibling repo, Sprinter.

Contributing

Cat pull request image

Jokes aside, given the roadmap described above, PRs related to bugs and performance are of most interest. If you would like a new feature, please create an issue first so we can discuss it.

License

This repository is licensed under the Apache 2.0 License
