nikhilbarhate99 / min-decision-transformer

Minimal implementation of Decision Transformer: Reinforcement Learning via Sequence Modeling in PyTorch for MuJoCo control tasks in OpenAI Gym

License: MIT License

reinforcement-learning deep-reinforcement-learning deep-learning offline-reinforcement-learning pytorch pytorch-transformers transformer machine-learning openai-gym mujoco

min-decision-transformer's Introduction

Decision Transformer

Overview

Minimal code for Decision Transformer: Reinforcement Learning via Sequence Modeling for MuJoCo control tasks in OpenAI Gym. Notable differences from the official implementation are:

  • Simple GPT implementation (causal transformer)
  • Uses PyTorch's Dataset and DataLoader classes, and removes redundant computation of returns-to-go and state normalization for efficient training
  • Can be trained, and the results visualized and rendered, on Google Colab with the provided notebook
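The returns-to-go mentioned above are the per-timestep sums of future rewards that Decision Transformer conditions on. An illustrative sketch (not the repo's exact code; Decision Transformer uses undiscounted returns, i.e. gamma = 1):

```python
def returns_to_go(rewards, gamma=1.0):
    """Compute R_t = sum_{k >= t} gamma^(k - t) * r_k for each timestep.

    Decision Transformer conditions on undiscounted returns (gamma = 1),
    so the default reduces to a reverse cumulative sum of the rewards.
    """
    rtg = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        rtg[t] = running
    return rtg
```

Precomputing these once per trajectory, rather than recomputing them every batch, is the redundancy the bullet above refers to.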

Results

Note: these results are the mean and variance over 3 random seeds, obtained after 20k updates (due to time limits on GPU resources on Colab), while the official results are obtained after 100k updates. So these numbers are not directly comparable, but they can be used as rough reference points, along with their corresponding plots, to gauge the learning progress of the model. The variance in returns and scores should decrease as training saturates.

| Dataset | Environment | DT (this repo), 20k updates | DT (official), 100k updates |
|---------|-------------|-----------------------------|------------------------------|
| Medium  | HalfCheetah | 42.18 ± 00.59               | 42.60 ± 00.10                |
| Medium  | Hopper      | 69.43 ± 27.34               | 67.60 ± 01.00                |
| Medium  | Walker      | 75.47 ± 31.08               | 74.00 ± 01.40                |

Instructions

Mujoco-py

Install the mujoco-py library by following the instructions in the mujoco-py repo

D4RL Data

Datasets are expected to be stored in the data directory. Install the D4RL repo. Then save the formatted data in the data directory by running the following script:

python3 data/download_d4rl_datasets.py

Running experiments

  • Example command for training:
python3 scripts/train.py --env halfcheetah --dataset medium --device cuda
  • Example command for testing with a pretrained model:
python3 scripts/test.py --env halfcheetah --dataset medium --device cpu --num_eval_ep 1 --chk_pt_name dt_halfcheetah-medium-v2_model_22-02-13-09-03-10_best.pt

The dataset needs to be specified for testing in order to load the same state normalization statistics (mean and variance) that were used for training. An additional --render flag can be passed to the script to render the test episode.
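To illustrate why the dataset must match at test time, a minimal sketch of state normalization (names and the epsilon guard are illustrative, not the repo's exact code):

```python
import numpy as np

def normalize_state(state, state_mean, state_std):
    # The same per-dimension statistics computed on the training dataset
    # must be applied to observations at evaluation time, otherwise the
    # inputs the model sees are on a different scale than during training.
    # A small epsilon guards against zero-variance dimensions.
    return (state - state_mean) / (state_std + 1e-6)
```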

  • Example command for plotting graphs using logged data from the csv files:
python3 scripts/plot.py --env_d4rl_name halfcheetah-medium-v2 --smoothing_window 5

Additionally, --plot_avg and --save_fig flags can be passed to the script to average all values into one plot and to save the figure.
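Smoothing windows of the kind passed to the plot script are typically simple moving averages; a hedged sketch (not necessarily the exact method used by scripts/plot.py):

```python
import numpy as np

def smooth(values, window=5):
    # Simple moving average: each output point is the mean of `window`
    # consecutive input points, which damps per-evaluation noise in the
    # logged returns before plotting.
    kernel = np.ones(window) / window
    return np.convolve(values, kernel, mode="valid")
```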

Note:

  1. If you find it difficult to install mujoco-py and d4rl, you can refer to their installation steps in the Colab notebook.
  2. Once the dataset is formatted and saved with download_d4rl_datasets.py, the d4rl library is not required for training.
  3. Evaluation is done on the v3 control environments in mujoco-py so that the results are consistent with the Decision Transformer paper.

Citing

Please use this bibtex if you want to cite this repository in your publications:

@misc{minimal_decision_transformer,
    author = {Barhate, Nikhil},
    title = {Minimal Implementation of Decision Transformer},
    year = {2022},
    publisher = {GitHub},
    journal = {GitHub repository},
    howpublished = {\url{https://github.com/nikhilbarhate99/min-decision-transformer}},
}


min-decision-transformer's Issues

The calculation of state_mean and state_std in d4rl_info.py

Hello there, we are currently trying to use the code on some other datasets, like walker2d-random-v2. We would like to kindly ask how state_mean and state_std in d4rl_info.py are calculated, and whether there is any open data for these values. Thank you very much!
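For reference, statistics of this kind are usually computed over all states in the offline dataset. A sketch under the assumption that each trajectory is a dict with an "observations" array, as in the D4RL format (field name assumed):

```python
import numpy as np

def compute_state_stats(trajectories):
    # Stack every observation from every trajectory into one (N, state_dim)
    # array, then take per-dimension mean and standard deviation.
    states = np.concatenate([t["observations"] for t in trajectories], axis=0)
    return states.mean(axis=0), states.std(axis=0)
```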

Debugging a custom gym environment

Hey,
I am trying to train this on my custom gym environment but the model isn't learning at all. Any idea what could be the probable cause?

The position of the `dropout` operation is different from the official repo?

In the official repo's attention:

```python
w = nn.Softmax(dim=-1)(w)
w = self.attn_dropout(w)
# Mask heads if we want to
if head_mask is not None:
    w = w * head_mask

outputs = [torch.matmul(w, v)]
```

The dropout comes directly after the softmax and before the matmul.

On the other hand, in this implementation:

```python
normalized_weights = F.softmax(weights, dim=-1)
# attention (B, N, T, D)
# normalized_weights.shape: (B, N, T, T)
# v.shape: (B, N, T, D)
attention = self.att_drop(normalized_weights @ v)
```

Here the dropout comes last.

In my opinion, they are different. What do you think? :-)
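To make the difference concrete, here is a NumPy sketch of the two placements (illustrative only, using inverted dropout with an explicit keep mask). With p = 0 the two coincide; with p > 0 one zeroes entire key contributions before the matmul, while the other zeroes individual output features after it:

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def attn_dropout_on_weights(w, v, keep_mask, p=0.1):
    # Official GPT-2 style: dropout on the (T, T) attention weights,
    # so whole key contributions can be zeroed before the matmul.
    a = softmax(w) * keep_mask / (1.0 - p)
    return a @ v

def attn_dropout_on_output(w, v, keep_mask, p=0.1):
    # This repo's variant: dropout on the (T, D) attention output,
    # so individual output features are zeroed after the matmul.
    a = softmax(w) @ v
    return a * keep_mask / (1.0 - p)
```

In eval mode both reduce to plain softmax attention, which is why the two implementations produce identical test-time behavior despite differing during training.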

Training is OK, but evaluation fails

Hello👋,

Thank you for open-sourcing the code for the min decision transformer. Your code has been tremendously helpful in helping me understand DT.

However, I am currently facing an issue. During training, the action loss is steadily decreasing, but the test results have consistently been subpar, to the point of showing no discernible learning. I've been grappling with this problem for a while now and can't figure out why it is happening.

(screenshot of training curves, 2023-10-09)

By the way, I haven't tested it on the three environments (halfcheetah, hopper, and walker2d), mainly because I've been struggling to configure d4rl. I'm using the upgraded version of d4rl provided by Farama, specifically on the pointmaze offline dataset.

If you could spare some time to assist me with this, I would be immensely grateful!

Oscillations of eval score

Hi nik!

I have trained walker2d and other environments several times, with settings and hyperparameters following the original DT. I found something strange: in most environments, DT reaches its best score between 10,000 and 20,000 training steps, with no obvious improvement after 20,000 steps. Sometimes a trough also occurs during that period. Would you mind giving me some clues about this?

(plot: walker2d eval scores)
