twni2016 / memory-rl
When Do Transformers Shine in RL? Decoupling Memory from Credit Assignment, NeurIPS 2023 (oral)

Home Page: https://arxiv.org/abs/2307.03864

License: MIT License

Languages: Python 95.60%, Jupyter Notebook 4.40%
Topics: credit-assignment, memory, pomdp, transformer, dqn

memory-rl's Introduction

Evaluating Memory and Credit Assignment in Memory-Based RL

This is the official code for the paper (Sections 5.1 & 5.2: discrete control)

"When Do Transformers Shine in RL? Decoupling Memory from Credit Assignment", NeurIPS 2023 (oral)

by Tianwei Ni, Michel Ma, Benjamin Eysenbach, and Pierre-Luc Bacon.

Please switch to the dedicated branch for the code for Section 5.3 (PyBullet continuous control).

Modular Design

The code has a modular design that requires three configuration files. We hope this design facilitates future research on different environments, RL algorithms, and sequence models (a hypothetical config sketch follows the list below).

  • config_env: specify the environment, with config_env.env_name specifying the exact (memory / credit assignment) length of the task
    • Passive T-Maze (this work)
    • Active T-Maze (this work)
    • Passive Visual Match (based on [Hung et al., 2018])
    • Key-to-Door (based on [Raposo et al., 2021])
  • config_rl: specify the RL algorithm and its hyperparameters
    • DQN (with epsilon-greedy exploration)
    • SAC-Discrete (we find that --freeze_critic can prevent gradient explosion; see the discussion in Appendix C.1 of the latest arXiv version of the paper)
  • config_seq: specify the sequence model and its hyperparameters, including the training sequence length (config_seq.sampled_seq_len) and the number of layers (config_seq.model.seq_model_config.n_layer)
    • LSTM [Hochreiter and Schmidhuber, 1997]
    • Transformer (GPT-2) [Radford et al., 2019]
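
As a reference for adding new tasks, the flag syntax above (e.g. --config_env.env_name 50) suggests ml_collections-style config files that expose a get_config() function. The sketch below is a hypothetical illustration of that structure, not one of the repository's actual files (those live under configs/):

    from ml_collections import ConfigDict

    def get_config():
        # Hypothetical environment config, for illustration only;
        # the real environment configs live in configs/envs/.
        config = ConfigDict()
        # For the T-Maze / Visual Match / Key-to-Door tasks, env_name encodes the
        # memory or credit-assignment length and can be overridden on the command
        # line, e.g. --config_env.env_name 50.
        config.env_name = 50
        return config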

Installation

We use Python 3.7+; the basic requirements are listed in requirements.txt.
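
A typical setup, assuming a fresh virtual environment (adjust to your environment manager of choice):

    python -m venv venv
    source venv/bin/activate
    pip install -r requirements.txt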

Reproducing the Results

Below are example commands to reproduce the main results shown in Figures 3 and 6. For the ablation results, please adjust the corresponding hyperparameters.

To run Passive T-Maze with a memory length of 50 with an LSTM-based agent:

python main.py \
    --config_env configs/envs/tmaze_passive.py \
    --config_env.env_name 50 \
    --config_rl configs/rl/dqn_default.py \
    --train_episodes 20000 \
    --config_seq configs/seq_models/lstm_default.py \
    --config_seq.sampled_seq_len -1

To run Passive T-Maze with a memory length of 1500 with a Transformer-based agent:

python main.py \
    --config_env configs/envs/tmaze_passive.py \
    --config_env.env_name 1500 \
    --config_rl configs/rl/dqn_default.py \
    --train_episodes 6700 \
    --config_seq configs/seq_models/gpt_default.py \
    --config_seq.sampled_seq_len -1

To run Active T-Maze with a memory length of 20 with a Transformer-based agent:

python main.py \
    --config_env configs/envs/tmaze_active.py \
    --config_env.env_name 20 \
    --config_rl configs/rl/dqn_default.py \
    --train_episodes 40000 \
    --config_seq configs/seq_models/gpt_default.py \
    --config_seq.sampled_seq_len -1 \
    --config_seq.model.seq_model_config.n_layer 2 \
    --config_seq.model.seq_model_config.n_head 2

To run Passive Visual Match with a memory length of 60 with a Transformer-based agent:

python main.py \
    --config_env configs/envs/visual_match.py \
    --config_env.env_name 60 \
    --config_rl configs/rl/sacd_default.py \
    --shared_encoder --freeze_critic \
    --train_episodes 40000 \
    --config_seq configs/seq_models/gpt_cnn.py \
    --config_seq.sampled_seq_len -1

To run Key-to-Door with a memory length of 120 with an LSTM-based agent:

python main.py \
    --config_env configs/envs/keytodoor.py \
    --config_env.env_name 120 \
    --config_rl configs/rl/sacd_default.py \
    --shared_encoder --freeze_critic \
    --train_episodes 40000 \
    --config_seq configs/seq_models/lstm_cnn.py \
    --config_seq.sampled_seq_len -1 \
    --config_seq.model.seq_model_config.n_layer 2

To run Key-to-Door with a memory length of 250 with a Transformer-based agent:

python main.py \
    --config_env configs/envs/keytodoor.py \
    --config_env.env_name 250 \
    --config_rl configs/rl/sacd_default.py \
    --shared_encoder --freeze_critic \
    --train_episodes 30000 \
    --config_seq configs/seq_models/gpt_cnn.py \
    --config_seq.sampled_seq_len -1 \
    --config_seq.model.seq_model_config.n_layer 2 \
    --config_seq.model.seq_model_config.n_head 2

The train_episodes budget for each task is specified in budget.py.

By default, logging data is stored in the logs/ folder in CSV format. If you use the --debug flag, it is stored in the debug/ folder instead.
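
Since the logs are plain CSV files, they can also be inspected outside the provided notebook. Below is a minimal sketch using pandas; the file path and column names here are placeholders, so check your logs/ directory for the actual run folder and metric names:

    import pandas as pd
    import matplotlib.pyplot as plt

    # Placeholder path and column names -- inspect your own logs/ directory.
    df = pd.read_csv("logs/<run_dir>/progress.csv")
    print(df.columns)  # see which metrics were actually logged

    # Plot one (assumed) metric against (assumed) environment steps.
    df.plot(x="env_steps", y="return", legend=False)
    plt.xlabel("environment steps")
    plt.ylabel("episode return")
    plt.show()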

Logging and Plotting

After the logging data is stored, you can plot the learning curves and aggregation plots (e.g., Figures 3 and 6) using the vis.ipynb Jupyter notebook.

We also share the logging data used in the paper on Google Drive (< 400 MB).

Acknowledgement

The code is largely based on prior works.

Questions

If you have any questions, please raise an issue (preferred) or send an email to Tianwei ([email protected]).

memory-rl's Issues

Movement Penalty in TMaze

Hi, thanks for releasing this! I'm curious about the reward function used in the main TMaze results. Maybe I'm misunderstanding the way the active/passive versions disentangle credit assignment.

It seems like the goal of both environments is to unit-test whether a Transformer can learn to recall information from the first (passive) or third (active, after going backwards first) observation at the final timestep. The penalty is meant to remove the sparse exploration problem of navigating to the goal position. You implement that penalty as rew = float(x < time_step - oracle_length) * penalty, which is active whenever the policy falls behind the pace of the optimal policy. Since the policy has no way to make up that pace, it is penalized from the first timestep it disagrees with the optimal policy onward. Wouldn't we be testing the same recall ability on the final timestep if the penalty were instead rew = float(action != right and time_step > oracle_length) * penalty? Both provide dense signal to move towards the goal, but the second version can provide that signal in any sub-optimal episode, whereas the original version basically seems like kryptonite for epsilon-greedy exploration and forces an unusually low-epsilon schedule?
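
For concreteness, the two formulations discussed above can be written side by side. This is a sketch that mirrors the snippets quoted in the question; the variable names (x for the corridor position, time_step, oracle_length, action, right, and a negative penalty constant) are taken from the issue text, not verified against the repository's implementation:

    # Formulation quoted in the issue: penalize the agent on every step where it has
    # fallen behind the pace of the optimal policy. Since a lost step can never be
    # recovered, the penalty persists from the first sub-optimal step onward.
    def penalty_behind_pace(x, time_step, oracle_length, penalty):
        return float(x < time_step - oracle_length) * penalty

    # Alternative proposed in the issue: penalize only the steps (after the oracle
    # phase) on which the agent does not move right, which still gives a dense
    # signal toward the goal in any sub-optimal episode.
    def penalty_not_moving_right(action, right, time_step, oracle_length, penalty):
        return float(action != right and time_step > oracle_length) * penalty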

Port for Minigrid Environments

Hello and congrats for the nice paper!

I am working on my master's thesis and am using your code as my codebase. However, I have integrated the Minigrid environments and use them instead of the custom mazes you created.

Now, if I understand correctly, the concept of "memory" refers to the number of timesteps of data the policy is fed (for each batch).

If I use a Minigrid environment, should the code theoretically work as is?
If not, what changes should I make?

Thanks in advance!
