
marcometer / episodic-transformer-memory-ppo

Clean baseline implementation of PPO using an episodic TransformerXL memory

License: MIT License

Python 100.00%
pytorch deep-reinforcement-learning episodic-memory ppo transformer proximal-policy-optimization on-policy policy-gradient pomdp actor-critic transformer-xl gtrxl gated-transformer-xl trxl memory-gym

episodic-transformer-memory-ppo's Introduction

TransformerXL as Episodic Memory in Proximal Policy Optimization

This repository features a PyTorch-based implementation of PPO using TransformerXL (TrXL). It is intended as a clean baseline/reference implementation showing how to successfully employ memory-based agents using Transformers and PPO.

Features

  • Episodic Transformer Memory
    • TransformerXL (TrXL)
    • Gated TransformerXL (GTrXL)
  • Environments
    • Proof-of-concept Memory Task (PocMemoryEnv)
    • CartPole
      • Masked velocity
    • Minigrid Memory
      • Visual Observation Space 3x84x84
      • Egocentric Agent View Size 3x3 (default 7x7)
      • Action Space: forward, rotate left, rotate right
    • MemoryGym
      • Mortar Mayhem
      • Mystery Path
      • Searing Spotlights (WIP)
  • Tensorboard
  • Enjoy (watch a trained agent play)

Citing this work

@article{pleines2023trxlppo,
  title = {TransformerXL as Episodic Memory in Proximal Policy Optimization},
  author = {Pleines, Marco and Pallasch, Matthias and Zimmer, Frank and Preuss, Mike},
  journal= {Github Repository},
  year = {2023},
  url = {https://github.com/MarcoMeter/episodic-transformer-memory-ppo}
}

Installation

Install PyTorch 1.12.1 for your platform. We recommend using Anaconda.

Create Anaconda environment:

conda create -n transformer-ppo python=3.7 --yes
conda activate transformer-ppo

CPU:

conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cpuonly -c pytorch

CUDA:

conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.6 -c pytorch -c conda-forge

Install the remaining requirements and you are good to go:

pip install -r requirements.txt

Train a model

Training is launched via train.py. --config specifies the path to the YAML config file containing the hyperparameters, and --run-id distinguishes training runs. After training, the trained model is saved to ./models/$run-id$.nn.

python train.py --config configs/minigrid.yaml --run-id=my-trxl-training

Enjoy a model

To watch an agent exploit its trained model, run enjoy.py. Some pre-trained models can be found in ./models/. The model to be enjoyed is specified using the --model flag.

python enjoy.py --model=models/mortar_mayhem_grid_trxl.nn

Episodic Transformer Memory Concept

[Figure: TransformerXL model (transformer-xl-model)]

Hyperparameters

Episodic Transformer Memory

Hyperparameter Description
num_blocks Number of transformer blocks
embed_dim Embedding size of every layer inside a transformer block
num_heads Number of heads used in the transformer's multi-head attention mechanism
memory_length Length of the sliding episodic memory window
positional_encoding Relative and learned positional encodings can be used
layer_norm Whether to apply layer normalization before or after every transformer component. Pre layer normalization refers to the identity map re-ordering.
gtrxl Whether to use Gated TransformerXL
gtrxl_bias Initial value for GTrXL's bias weight
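To illustrate what gtrxl_bias controls, here is a plain-Python sketch of the GRU-type gate described by Parisotto et al. (2019), in which the bias b_g is subtracted inside the update gate z, so a positive value initializes each gate close to the identity (the skip connection). The weight names are hypothetical, and the repository's implementation may differ in detail.

```python
import math


def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


def matvec(w, v):
    # Dense matrix-vector product; w is a list of rows.
    return [sum(wij * vj for wij, vj in zip(row, v)) for row in w]


def gru_gate(x, y, params, gtrxl_bias):
    """GRU-type gate g(x, y) as described by Parisotto et al. (2019).

    x: residual-stream input (the skip connection); y: sub-module output.
    params: hypothetical square weight matrices W_r, U_r, W_z, U_z, W_g, U_g.
    gtrxl_bias: b_g, subtracted inside the update gate z.
    """
    r = [sigmoid(a + b) for a, b in zip(matvec(params["W_r"], y), matvec(params["U_r"], x))]
    z = [sigmoid(a + b - gtrxl_bias)
         for a, b in zip(matvec(params["W_z"], y), matvec(params["U_z"], x))]
    rx = [ri * xi for ri, xi in zip(r, x)]
    h = [math.tanh(a + b) for a, b in zip(matvec(params["W_g"], y), matvec(params["U_g"], rx))]
    # Interpolate between the skip connection x and the candidate h.
    return [(1.0 - zi) * xi + zi * hi for zi, xi, hi in zip(z, x, h)]


# With all weights at zero, z = sigmoid(-gtrxl_bias), so a positive bias keeps
# the output close to x (about 88% of x for gtrxl_bias = 2).
dim = 2
zeros = [[0.0] * dim for _ in range(dim)]
params = {name: zeros for name in ("W_r", "U_r", "W_z", "U_z", "W_g", "U_g")}
x, y = [1.0, -2.0], [0.5, 0.5]
out = gru_gate(x, y, params, gtrxl_bias=2.0)
print(out)
```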

General

gamma Discount factor
lamda GAE parameter λ, which trades off bias and variance when calculating the Generalized Advantage Estimation (GAE)
updates Number of cycles for which the entire PPO algorithm is executed
n_workers Number of environments that are used to sample training data
worker_steps Number of steps an agent samples data in each environment (batch_size = n_workers * worker_steps)
epochs Number of times that the whole batch of data is used for optimization using PPO
n_mini_batch Number of mini-batches trained on during one epoch
value_loss_coefficient Multiplier of the value function loss to constrain it
hidden_layer_size Number of hidden units in each linear hidden layer
max_grad_norm Gradients are clipped by the specified max norm
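The gamma and lamda entries feed the advantage computation, while n_workers, worker_steps and n_mini_batch fix the batch layout. Below is a minimal sketch of standard GAE in plain Python, assuming per-worker lists of rewards, values and done flags; it is not the repository's buffer code, and the function names, the divisibility check and the example numbers are hypothetical.

```python
def compute_gae(rewards, values, dones, last_value, gamma, lamda):
    """Standard GAE over one worker's trajectory of length worker_steps.

    rewards, values, dones: per-step lists; dones[t] is True if the episode
    ended at step t. last_value: value estimate of the state after the final step.
    Returns (advantages, returns) with returns = advantages + values.
    """
    advantages = [0.0] * len(rewards)
    last_advantage = 0.0
    next_value = last_value
    for t in reversed(range(len(rewards))):
        mask = 0.0 if dones[t] else 1.0  # no bootstrapping across episode boundaries
        delta = rewards[t] + gamma * next_value * mask - values[t]  # TD error
        last_advantage = delta + gamma * lamda * mask * last_advantage
        advantages[t] = last_advantage
        next_value = values[t]
    returns = [a + v for a, v in zip(advantages, values)]
    return advantages, returns


def mini_batch_size(n_workers, worker_steps, n_mini_batch):
    """batch_size = n_workers * worker_steps, split evenly into n_mini_batch parts."""
    batch_size = n_workers * worker_steps
    if batch_size % n_mini_batch != 0:
        raise ValueError("batch_size should be divisible by n_mini_batch")
    return batch_size // n_mini_batch


# Toy trajectory: three steps, the episode ends on the last one (illustrative numbers only).
advantages, returns = compute_gae(
    rewards=[1.0, 1.0, 1.0], values=[0.0, 0.0, 0.0], dones=[False, False, True],
    last_value=5.0, gamma=0.5, lamda=1.0,
)
print(advantages, returns)
print(mini_batch_size(n_workers=16, worker_steps=512, n_mini_batch=8))
```

Because the episode terminates at the last step, last_value is masked out and never bootstrapped.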

Schedules

These schedules can be used to polynomially decay the learning rate, the entropy bonus coefficient and the clip range.

learning_rate_schedule The learning rate used by the AdamW optimizer
beta_schedule Beta is the entropy bonus coefficient that is used to encourage exploration
clip_range_schedule Clip range used by PPO to clip its losses
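As a rough illustration of polynomial decay, the sketch below assumes that each schedule is described by the hypothetical keys initial, final, power and max_decay_steps; check the yaml configs for the actual key names and values.

```python
def polynomial_decay(initial, final, max_decay_steps, power, current_step):
    """Decay from `initial` to `final` over `max_decay_steps`, then hold at `final`."""
    if current_step > max_decay_steps:
        return final
    frac = 1.0 - current_step / max_decay_steps
    return (initial - final) * frac ** power + final


# Hypothetical learning rate schedule (power = 1.0 gives a linear decay).
lr_schedule = {"initial": 3.0e-4, "final": 3.0e-5, "power": 1.0, "max_decay_steps": 1000}
lrs = [
    polynomial_decay(lr_schedule["initial"], lr_schedule["final"],
                     lr_schedule["max_decay_steps"], lr_schedule["power"], step)
    for step in (0, 500, 1000, 2000)
]
print(lrs)
```

The same function can drive beta_schedule and clip_range_schedule; a power above 1 front-loads the decay.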

Add Environment

Follow these steps to train another environment:

  1. Implement a wrapper for your desired environment. It needs the properties observation_space, action_space and max_episode_steps, and the functions render(), reset() and step().
  2. Extend the create_env() function in utils.py by adding another if-statement that queries the environment's "type".
  3. Adjust the "type" and "name" keys inside the environment's yaml config.

Note that only environments with visual or vector observations are supported. The environment's action space can be either discrete or multi-discrete.
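A toy wrapper following the steps above might look like the sketch below. Everything in it is hypothetical: ToyMemoryEnv, the "ToyMemory" type string, the namedtuple stand-ins for gym's Box and Discrete spaces, and the 4-tuple (obs, reward, done, info) returned by step(), which assumes the classic gym API. Compare it with an existing wrapper before relying on it, since the trainer may expect specific info contents.

```python
import random
from collections import namedtuple

# Minimal stand-ins for gym spaces so the sketch runs without gym installed.
Box = namedtuple("Box", ["shape"])
Discrete = namedtuple("Discrete", ["n"])


class ToyMemoryEnv:
    """Hypothetical memory task: a cue is visible only on the first step,
    and the agent is rewarded for outputting it on the last step."""

    def __init__(self, episode_length=8, seed=None):
        self._episode_length = episode_length
        self._rng = random.Random(seed)
        self._cue = 0
        self._t = 0

    @property
    def observation_space(self):
        return Box(shape=(2,))  # one-hot cue, all zeros after the first step

    @property
    def action_space(self):
        return Discrete(n=2)

    @property
    def max_episode_steps(self):
        return self._episode_length

    def _obs(self):
        if self._t == 0:
            return [1.0 if i == self._cue else 0.0 for i in range(2)]
        return [0.0, 0.0]

    def reset(self):
        self._cue = self._rng.randrange(2)
        self._t = 0
        return self._obs()

    def step(self, action):
        self._t += 1
        done = self._t >= self._episode_length
        reward = 1.0 if done and action == self._cue else 0.0
        return self._obs(), reward, done, {}

    def render(self):
        print(f"t={self._t} cue={self._cue}")


def create_env(config):
    """Hypothetical dispatch in the spirit of create_env(): branch on config["type"]."""
    if config["type"] == "ToyMemory":
        return ToyMemoryEnv(episode_length=config.get("episode_length", 8))
    raise ValueError(f"Unknown environment type: {config['type']}")


# Roll out one episode with an "agent" that remembers the first observation.
env = create_env({"type": "ToyMemory", "name": "ToyMemory-v0"})
first_obs = env.reset()
done, steps, total_reward = False, 0, 0.0
while not done:
    action = first_obs.index(1.0) if steps == env.max_episode_steps - 1 else 0
    obs, reward, done, info = env.step(action)
    total_reward += reward
    steps += 1
print(steps, total_reward)
```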

Tensorboard

During training, tensorboard summaries are saved to summaries/run-id/timestamp.

Run tensorboard --logdir=summaries to watch the training statistics in your browser at http://localhost:6006/.

Results

Every experiment is repeated on 5 random seeds. Each model checkpoint is evaluated on 50 unknown environment seeds, each repeated 5 times. Hence, one data point aggregates 1250 (5x5x50) episodes. rliable is used to compute the interquartile mean and the bootstrapped confidence interval. Training is conducted using the more sophisticated DRL framework neroRL. The clean GRU-PPO baseline can be found here.
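To make the aggregation concrete: the interquartile mean discards the lowest and highest 25% of scores and averages the rest. The sketch below computes it in plain Python together with a simple percentile bootstrap. rliable itself uses a stratified bootstrap, so treat this as an approximation of the idea rather than a reproduction of the reported numbers; the scores are synthetic placeholders.

```python
import random


def interquartile_mean(scores):
    """Mean of the middle 50% of scores (25% trimmed from each end)."""
    s = sorted(scores)
    cut = int(0.25 * len(s))
    middle = s[cut:len(s) - cut]
    return sum(middle) / len(middle)


def bootstrap_ci(scores, n_resamples=2000, alpha=0.05, seed=0):
    """Percentile bootstrap confidence interval for the IQM (not stratified)."""
    rng = random.Random(seed)
    n = len(scores)
    stats = sorted(
        interquartile_mean([scores[rng.randrange(n)] for _ in range(n)])
        for _ in range(n_resamples)
    )
    low = stats[int((alpha / 2) * n_resamples)]
    high = stats[int((1 - alpha / 2) * n_resamples) - 1]
    return low, high


# One data point aggregates 5 training seeds x 5 repetitions x 50 environment seeds.
n_episodes = 5 * 5 * 50
scores = [i / 19 for i in range(20)]  # synthetic placeholder scores
ci_low, ci_high = bootstrap_ci(scores)
print(n_episodes, interquartile_mean(scores), ci_low, ci_high)
```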

Mystery Path Grid (Goal & Origin Hidden)

[Figure: Mystery Path Grid results (mpg_off_results)]

TrXL and GTrXL have identical performance. See Issue #7.

Mortar Mayhem Grid (10 commands)

[Figure: Mortar Mayhem Grid results (mmg_10_results)]

episodic-transformer-memory-ppo's People

Contributors

marcometer, reytuag

episodic-transformer-memory-ppo's Issues

Question context window

Hi @MarcoMeter,

I'm working on implementing PPO TrXL in JAX and have a question regarding the way you handle the context window.

If I understood correctly, your code computes the attention using only the embedding of the last step as the query, while the embeddings of all previous timesteps are cached (without gradient).

This is perfect for collecting rollouts ("evaluation"), but for the learning part it seems to differ from the original TransformerXL paper (https://arxiv.org/abs/1901.02860), where there is a cached segment (as in your code) but the current segment still computes embeddings for more than one timestep.

The AdA paper (Human-Timescale Adaptation in an Open-Ended Task Space, https://arxiv.org/abs/2301.07608) also mentions a context window for learning of size 80 (which I assume corresponds to the window through which gradients propagate) and a cached memory of size 300.

I made a small, rough sketch based on the figure from the TransformerXL paper to illustrate what this might change. For example, when part of the context window is not cached (left), the gradient (red arrows) from the loss at step 6 propagates back and carries information on how $h_5^1$ should be computed (I didn't draw all gradient paths). In your case (right), step 6 only uses the cached version of $h_5^1$, so no such information propagates. Consequently, the embedding at time t only receives gradient from the loss at timestep t. An uncached context window of size > 1 lets the embedding of timestep t also receive gradient from the losses at timesteps t+1 and later, which might help the model learn embeddings at earlier steps that are useful to attend to later.

I hope this makes sense; feel free to ask me to clarify. This might not matter in practice, or might even hurt performance (it is also more costly). But the learning context window of 80 and the cached memory of 300 in the AdA paper made me question things. Maybe this is the standard way of using TransformerXL in RL.

Best regards

Atari env

Hello, I attempted to set up an Atari environment and train using your current code, but unfortunately, I was unable to learn anything during training. Could you suggest any possible explanations for this and recommend specific hyperparameters that I could experiment with for Pong?

Does this repo implement a WIP paper?

Hi,

I want to ask if this repo served as code for a paper of yours or if it implements some paper. The architecture seems different from a normal recurrent agent.

help: Regarding support issues for multi-agent environments

Hello, I am very interested in your work and want to try it in my experiments, but I don't know how to make it support multi-agent environments. I am a beginner, so my explanation may not be precise. Please forgive me.

In my environment, each step returns observations and rewards for multiple agents. The number of agents is fixed during training, but can change arbitrarily during testing. So each step effectively returns a batch of observations.

How should I modify the code to integrate it with your TransformerXL implementation?

Extending the code to support image + vector observations

Hi, thank you so much for developing this repository.
I am looking to apply the code to my custom environment, which returns a dict observation (following the gym specification) containing a vector observation and an RGB observation.

If I modify the code to append the vector observation to the image embedding produced by the CNN feature extractor, and make the corresponding changes to the buffer, would that be enough to support an environment with image + vector observations?

I would appreciate any insights you might have on this.

Cheers!

Question about attention when mask is all zero

My project was inspired by this repo, thanks for your work!

However, I have a question about the case where the mask is all zeros, e.g., at the timestep when the environment resets. Then all energies are -inf (-1e20), so the attention weights are all equal after the softmax. How should this be handled? In other words, when there is no history yet, how should the history be represented?

Looking forward to your reply!
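A plain-Python sketch of the situation described above: with every position masked to -1e20, a numerically stable softmax subtracts the (identical) maximum, so all weights become equal and the output is simply the average of the padded values. One possible workaround, shown in the hypothetical attend() helper, is to return a zero context vector when nothing is visible; this is an assumption for illustration, not necessarily how this repository handles it. Another option is to let each step attend to its own embedding, so that at least one position is always unmasked.

```python
import math


def masked_softmax(energies, mask, fill=-1e20):
    """Softmax where positions with mask[i] == False are filled with `fill`."""
    masked = [e if m else fill for e, m in zip(energies, mask)]
    peak = max(masked)  # subtract the max for numerical stability
    exps = [math.exp(e - peak) for e in masked]
    total = sum(exps)
    return [v / total for v in exps]


def attend(values, energies, mask):
    """Weighted sum of value vectors; hypothetical zero-context convention
    when no position is visible (e.g. right after an environment reset)."""
    if not any(mask):
        return [0.0] * len(values[0])
    weights = masked_softmax(energies, mask)
    return [sum(w * v[d] for w, v in zip(weights, values)) for d in range(len(values[0]))]


# All positions masked: the softmax alone yields uniform weights.
print(masked_softmax([1.0, 2.0, 3.0], [False, False, False]))
print(attend([[1.0, 0.0], [0.0, 1.0]], [0.5, 9.0], [True, False]))
```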

Question: Number of layers of the position wise MLP in transformer block.

Hi,
Thanks for the well-written code!
I was wondering whether you've explored the impact of the number of layers in the position-wise MLP of the transformer block. If I'm not mistaken, most implementations I've seen (like https://github.com/kimiyoung/transformer-xl/tree/master, which is cited in Stabilizing Transformers for RL: https://arxiv.org/abs/1910.06764), and even the original Transformer paper, use an MLP with 2 layers and a ReLU between them.

So is your choice of a single layer followed by a ReLU (in TransformerBlock: self.fc = nn.Sequential(nn.Linear(embed_dim, embed_dim), nn.ReLU())) based on empirical tests you've done, or on work I'm not aware of?
I'm not aware of any work studying the impact of the position-wise MLP's architecture in the transformer block. I guess this might be hard to do properly since, for example, adding a layer changes the total number of parameters.
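Since the question hinges on parameter count, here is a small arithmetic sketch comparing the single-layer MLP quoted above with the common two-layer variant. The 4x hidden expansion (as in the original Transformer, 512 to 2048) and the example embed_dim of 384 are assumptions chosen for illustration.

```python
def linear_params(n_in, n_out, bias=True):
    """Parameter count of a fully connected layer (weights plus optional bias)."""
    return n_in * n_out + (n_out if bias else 0)


def single_layer_mlp_params(embed_dim):
    # Linear(embed_dim, embed_dim) followed by ReLU; ReLU has no parameters.
    return linear_params(embed_dim, embed_dim)


def two_layer_mlp_params(embed_dim, hidden_dim):
    # Linear(embed_dim, hidden_dim) -> ReLU -> Linear(hidden_dim, embed_dim)
    return linear_params(embed_dim, hidden_dim) + linear_params(hidden_dim, embed_dim)


embed_dim = 384  # hypothetical example value
print(single_layer_mlp_params(embed_dim))
print(two_layer_mlp_params(embed_dim, 4 * embed_dim))
```

With a 4x expansion, the two-layer MLP has roughly eight times as many parameters per block, which is why a fair comparison would need to control for model size.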

Memory indices in enjoy.py

Hi,
Thanks for the clear implementation, this repo is very useful for my project.

However, after seeing different mean rewards in evaluation (based on enjoy.py) and in training, I dug a little into the enjoy.py code.
It seems that the memory indices created in init_transformer_memory are wrong.

First, the shape of memory_indices seems wrong: for a memory_length of 16 and a max_episode_steps of 256, I get memory_indices of shape [496, 16] instead of [256, 16], if I'm not mistaken. Looking more closely, the "repetitions" variable in this case has shape (255, 16) instead of (15, 16), i.e. (max_episode_steps - 1, memory_length) instead of (memory_length - 1, memory_length). This makes the transformer attend to the very first steps for the whole evaluation, hence the worse performance.

I opened a pull request with a fix that produces the right memory indices, and it seems to work fine:

I replaced : repetitions = torch.repeat_interleave(torch.arange(0, trxl_conf["memory_length"]).unsqueeze(0), max_episode_steps - 1, dim = 0).long()
By : repetitions = torch.repeat_interleave(torch.arange(0, trxl_conf["memory_length"]).unsqueeze(0), trxl_conf["memory_length"] - 1, dim = 0).long()
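For reference, the layout the issue describes can be reproduced in plain Python: the first memory_length - 1 steps reuse the window [0, memory_length) because the memory has not filled yet, and every later step t uses the sliding window [t - memory_length + 1, t]. This is a list-based sketch assuming max_episode_steps >= memory_length, not the repository's torch code; build_memory_indices is a hypothetical name.

```python
def build_memory_indices(memory_length, max_episode_steps):
    """Row t lists the memory positions used at episode step t."""
    # Steps 0 .. memory_length - 2: the memory is not full yet, reuse the first window.
    repetitions = [list(range(memory_length)) for _ in range(memory_length - 1)]
    # From step memory_length - 1 on: a window sliding by one step.
    sliding = [list(range(i, i + memory_length))
               for i in range(max_episode_steps - memory_length + 1)]
    return repetitions + sliding


memory_length, max_episode_steps = 16, 256
indices = build_memory_indices(memory_length, max_episode_steps)
# Row count when the first window is repeated max_episode_steps - 1 times (the reported bug).
buggy_rows = (max_episode_steps - 1) + (max_episode_steps - memory_length + 1)
print(len(indices), buggy_rows)
```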

ReLU in residual connections?

Hi,

I am using part of your code for a particular transformer architecture I need for my master thesis research in RL. I noticed that the original paper (Parisotto et al., 2019) re-orders the LayerNorms so that they are placed at the input of both the multi-head attention and the feed-forward sub-modules. I saw that you also implement this in your code via the config["layer_norm"] setting. However, the paper also states: "Because the layer norm reordering causes a path where two linear layers are applied in sequence, we apply a ReLU activation to each sub-module output before the residual connection (see Appendix C for equations).". Indeed, in those equations a ReLU is applied to the outputs of both the multi-head attention and the feed-forward sub-modules before the residual connection. I did not see this step in your code (just the standard residual connection), so I wonder whether there is a particular reason for that, or whether I am missing something (I'm still quite new to these implementations). In any case, congratulations on your great work; it is helping me a lot to understand the inner workings of such architectures. Thanks!
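For reference, the reordered block the issue refers to can be written as follows. This is a paraphrase of the TrXL-I equations in Parisotto et al. (2019), Appendix C, where M is the cached memory, E the layer input and f the position-wise MLP; dropping both ReLU terms gives the plain pre-LayerNorm residual block the issue describes.

```latex
\begin{aligned}
\bar{Y}^{(l)} &= \mathrm{RelativeMultiHeadAttention}\big(\mathrm{LayerNorm}([\mathrm{StopGrad}(M^{(l-1)}), E^{(l-1)}])\big) \\
Y^{(l)} &= E^{(l-1)} + \mathrm{ReLU}(\bar{Y}^{(l)}) \\
\bar{E}^{(l)} &= f^{(l)}\big(\mathrm{LayerNorm}(Y^{(l)})\big) \\
E^{(l)} &= Y^{(l)} + \mathrm{ReLU}(\bar{E}^{(l)})
\end{aligned}
```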

Inquiry on Research Publication

First of all, I want to express my sincere gratitude for sharing your code, which has been immensely helpful to me, particularly the concept of integrating Transformers with reinforcement learning. During my exploration of your work, I encountered some architectures that I found difficult to fully comprehend. I would like to ask if you have published any papers on this subject. I am interested in conducting more in-depth research and discussing these issues with you.

Questions about the implementation

Hi Marco,

I understand the need for the memory for the first transformer block, but I would like some help understanding the need for later memory blocks.

Given the first block containing past episodes, you can sample a bunch of episodes and use them as the input to update the policy and the value function without remembering intermediate results (such as the input for the second transformer block of those episodes in the past).

Another thing is for the first memory block, why don't you store raw observations instead of the embedded observations? A potential issue is that when updating the policy and the value function, the embedding layers will be updated (which can be very different from those used during data collection). This might make the off-policyness worse.
