

Diffusion Policies for Offline RL: Official PyTorch Implementation

Diffusion Policies as an Expressive Policy Class for Offline Reinforcement Learning
Zhendong Wang, Jonathan J Hunt and Mingyuan Zhou
https://arxiv.org/abs/2208.06193

Abstract: Offline reinforcement learning (RL), which aims to learn an optimal policy using a previously collected static dataset, is an important paradigm of RL. Standard RL methods often perform poorly at this task due to the function approximation errors on out-of-distribution actions. While a variety of regularization methods have been proposed to mitigate this issue, they are often constrained by policy classes with limited expressiveness that can lead to highly suboptimal solutions. In this paper, we propose representing the policy as a diffusion model, a recent class of highly-expressive deep generative models. We introduce Diffusion Q-learning (Diffusion-QL) that utilizes a conditional diffusion model for behavior cloning and policy regularization. In our approach, we learn an action-value function and we add a term maximizing action-values into the training loss of the conditional diffusion model, which results in a loss that seeks optimal actions that are near the behavior policy. We show the expressiveness of the diffusion model-based policy, and the coupling of the behavior cloning and policy improvement under the diffusion model both contribute to the outstanding performance of Diffusion-QL. We illustrate the superiority of our method compared to prior works in a simple 2D bandit example with a multimodal behavior policy. We further show that our method can achieve state-of-the-art performance on the majority of the D4RL benchmark tasks for offline RL.
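The objective described in the abstract adds a term that rewards high action-values to the denoising (behavior-cloning) loss of the conditional diffusion model. The pure-Python sketch below shows that combination for a single transition. It is an illustration under stated assumptions, not the code in this repository. `eps_model` and `q_fn` are hypothetical stand-ins for the noise-prediction network and the learned critic. The linear beta schedule is an assumed choice. The weight alpha = eta / E|Q| is a normalization we believe the paper uses; check the paper and main.py for the exact form. In a real PyTorch implementation, `sampled_action` would come from running the reverse diffusion chain, and gradients of the Q term would flow back through it into the policy.

```python
import math
import random
from typing import Callable, List, Sequence

Vector = List[float]


def linear_alpha_bars(n_steps: int, beta_start: float = 1e-4, beta_end: float = 0.1) -> Vector:
    """Cumulative products alpha_bar_i of (1 - beta_i) for an ASSUMED linear beta schedule."""
    alpha_bars, prod = [], 1.0
    for i in range(n_steps):
        beta = beta_start + (beta_end - beta_start) * i / max(n_steps - 1, 1)
        prod *= 1.0 - beta
        alpha_bars.append(prod)
    return alpha_bars


def diffusion_ql_loss(
    state: Vector,
    action: Vector,  # dataset action a for this state
    eps_model: Callable[[Vector, Vector, int], Vector],  # hypothetical eps_theta(a_i, s, i)
    sampled_action: Vector,  # a^0 sampled from the current policy at `state`
    q_fn: Callable[[Vector, Vector], float],  # hypothetical critic Q_phi(s, a)
    q_abs_mean: float,  # batch estimate of E|Q(s, a)|
    eta: float,  # trade-off hyperparameter
    alpha_bars: Sequence[float],
    rng: random.Random,
) -> float:
    # Behavior-cloning term: noise the dataset action to a random step i,
    # then penalize the squared error of the predicted noise.
    i = rng.randrange(len(alpha_bars))
    eps = [rng.gauss(0.0, 1.0) for _ in action]
    ab = alpha_bars[i]
    noisy = [math.sqrt(ab) * a + math.sqrt(1.0 - ab) * e for a, e in zip(action, eps)]
    pred = eps_model(noisy, state, i)
    bc_loss = sum((e - p) ** 2 for e, p in zip(eps, pred)) / len(eps)

    # Policy-improvement term: maximizing Q is written as minimizing -alpha * Q,
    # with alpha = eta / E|Q| (assumed normalization) so the raw scale of Q does not dominate.
    alpha = eta / q_abs_mean
    q_loss = -alpha * q_fn(state, sampled_action)
    return bc_loss + q_loss


if __name__ == "__main__":
    rng = random.Random(0)
    zero_eps = lambda a, s, i: [0.0] * len(a)  # untrained stand-in for the noise network
    loss = diffusion_ql_loss([0.0], [0.5, -0.5], zero_eps, [0.1, 0.1],
                             lambda s, a: 2.0, 4.0, 1.0, linear_alpha_bars(5), rng)
    print(f"loss = {loss:.4f}")
```

Because the Q term is divided by the typical magnitude of Q, `eta` controls the trade-off between staying close to the behavior data and seeking high-value actions roughly independently of the reward scale.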

Experiments

Requirements

PyTorch, MuJoCo, and D4RL must be installed. See requirements.txt for details on setting up the environment.
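Before running main.py, it can help to confirm that the dependencies are importable. The sketch below checks this without actually importing the packages (importing mujoco_py can trigger a compilation step). The module names are assumptions: D4RL has historically relied on `mujoco_py`, but requirements.txt is the authoritative list, so adjust the tuple to match it.

```python
import importlib.util
import sys
from typing import Dict, Iterable

# ASSUMED top-level module names; edit to match requirements.txt.
REQUIRED_MODULES = ("torch", "mujoco_py", "d4rl")


def check_modules(names: Iterable[str] = REQUIRED_MODULES) -> Dict[str, bool]:
    """Map each module name to whether it can be found, without importing it."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    status = check_modules()
    for name, found in status.items():
        print(f"{name}: {'ok' if found else 'MISSING'}")
    # Non-zero exit code if anything is missing, so the check can gate a script.
    sys.exit(0 if all(status.values()) else 1)
```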

Running

Running experiments with our code is straightforward; below we use the walker2d-medium-expert-v2 dataset as an example.

To reproduce the best results, we recommend running with 'online model selection' as follows. The best score is written to best_score_online.txt.

python main.py --env_name walker2d-medium-expert-v2 --device 0 --ms online --lr_decay

To run with 'offline model selection', use the command below. The best score is written to best_score_offline.txt.

python main.py --env_name walker2d-medium-expert-v2 --device 0 --ms offline --lr_decay --early_stop
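To run several datasets, a small launcher can assemble the same commands. This is a sketch that uses only the flags shown above; the second dataset name, hopper-medium-v2, is an assumed example of another D4RL dataset that main.py accepts, and printing the commands (rather than launching them) is just one possible design.

```python
import shlex
from typing import Iterable, List


def build_command(env_name: str, device: int = 0, ms: str = "online") -> List[str]:
    """Assemble a main.py invocation using only the flags shown in this README."""
    if ms not in ("online", "offline"):
        raise ValueError("ms must be 'online' or 'offline'")
    cmd = ["python", "main.py", "--env_name", env_name,
           "--device", str(device), "--ms", ms, "--lr_decay"]
    if ms == "offline":
        cmd.append("--early_stop")  # offline model selection is run with early stopping
    return cmd


def build_sweep(env_names: Iterable[str], device: int = 0, ms: str = "online") -> List[str]:
    """Return one shell-ready command line per dataset."""
    return [shlex.join(build_command(name, device, ms)) for name in env_names]


if __name__ == "__main__":
    # Print the commands; pass build_command(...) to subprocess.run(cmd, check=True) to launch one.
    for line in build_sweep(["walker2d-medium-expert-v2", "hopper-medium-v2"], ms="offline"):
        print(line)
```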

Hyperparameters for Diffusion-QL are hard-coded in main.py so that our reported results are easy to reproduce. Better hyperparameter settings may well exist; feel free to modify them.

Citation

If you find this open-source release useful, please cite it in your paper:

@article{wang2022diffusion,
  title={Diffusion Policies as an Expressive Policy Class for Offline Reinforcement Learning},
  author={Wang, Zhendong and Hunt, Jonathan J and Zhou, Mingyuan},
  journal={arXiv preprint arXiv:2208.06193},
  year={2022}
}

Contributors

daihuiao, zhendong-wang
