intuitive-robots / beso

[RSS 2023] Official code for "Goal Conditioned Imitation Learning using Score-based Diffusion Policies"

Home Page: https://intuitive-robots.github.io/beso-website/

License: MIT License

Python 99.54% Shell 0.46%
diffusion-models imitation-learning robotics score-based-generative-models

beso's Introduction

BESO

Paper, Project Page, RSS 2023

Moritz Reuss¹, Maximilian Li¹, Xiaogang Jia¹, Rudolf Lioutikov¹

¹Intuitive Robots Lab, Karlsruhe Institute of Technology

Official code for "Goal Conditioned Imitation Learning using Score-based Diffusion Policies"

Installation Guide

First, create a conda environment using the following command:

sh install.sh

During this process, two additional packages will be installed.

To add the relay_kitchen environment to the PYTHONPATH, run the following commands:

conda develop <path to your relay-policy-learning directory>
conda develop <path to your relay-policy-learning directory>/adept_envs
conda develop <path to your relay-policy-learning directory>/adept_envs/adept_envs
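
To verify that the paths were registered, you can run a quick import check from the active environment (a minimal sketch, assuming the relay-policy-learning checkout exposes the adept_envs package added above):

# Sanity check: should succeed once the conda develop paths are registered.
import adept_envs  # noqa: F401

print("adept_envs found at:", adept_envs.__file__)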

Dataset

To set up the datasets for the Relay Kitchen and Block Push environments, download them from the link below and adjust the data paths in the franka_kitchen_main_config.yaml and block_push_main_config.yaml files, as follows:

  1. Download the dataset: Go to the link from play-to-policy and download the dataset for the Relay Kitchen and Block Push environments.

  2. Unzip the dataset: After downloading, unzip the dataset file and store it in a location of your choice.

  3. Adjust data paths in the configuration files:

Open ./configs/franka_kitchen_main_config.yaml and set the data_path argument to <path_to_dataset>/relay_kitchen.
Open ./configs/block_push_main_config.yaml and set the data_path argument to <path_to_dataset>/multimodal_push_fixed_target.

After adjusting the data paths in both configuration files, you should be ready to use the datasets in the respective environments.
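
An optional sanity check is sketched below (it assumes omegaconf is installed, which Hydra depends on, and that data_path is a top-level key in both configs; adjust the key path if it is nested differently):

from pathlib import Path
from omegaconf import OmegaConf

# Verify that the configured dataset directories actually exist on disk.
for cfg_file in ["configs/franka_kitchen_main_config.yaml",
                 "configs/block_push_main_config.yaml"]:
    cfg = OmegaConf.load(cfg_file)
    data_path = Path(cfg.data_path)
    print(f"{cfg_file}: data_path={data_path} exists={data_path.exists()}")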


Code Overview


All configurations are managed with Hydra, a hierarchical configuration manager. Each configuration component is a separate file in a lower-level folder, and the main configuration file for each environment can be found under the /configs directory. To train and test an agent, run scripts/training.py with the desired config file. The repo uses WandB to log and monitor all runs, so please add your WandB entity and a project name to the config.
Hydra requires absolute paths, so change the default /path/to/beso placeholders in the code to your local path.

We provide two environments with goal-conditioned extensions, Franka Kitchen and Block Push, both based on the code of play-to-policy. To switch the environment, for example from Franka Kitchen to Block Push, replace the following code snippet in scripts/training.py:

@hydra.main(config_path="configs", config_name="franka_kitchen_main_config.yaml")
def main(cfg: DictConfig) -> None:

to

@hydra.main(config_path="configs", config_name="block_push_config.yaml")
def main(cfg: DictConfig) -> None:

An overview of the individual classes used to implement the model is depicted below:

The workspace class manages the environment and dataset related to the task at hand. The agent class encapsulates the model and training algorithm, serving as a wrapper to test the policy on the environment.


Train an agent

The general scripts/training.py file trains a new agent and evaluates its performance after training. A new agent can be trained using the following command:

[beso]$ conda activate play 
(play)[beso]$ python scripts/training.py 

To train the CFG variant of BESO, change the following parameter:

[beso]$ conda activate play 
(play)[beso]$ python scripts/training.py cond_mask_prob=0.1
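
During training, cond_mask_prob controls how often the goal is masked out, so the model also learns an unconditional score. At inference time the two predictions can then be combined with a guidance weight. A minimal sketch of that combination (function and argument names are illustrative, not the repo's exact API):

# Classifier-free guidance: blend unconditional and goal-conditioned predictions.
def cfg_denoise(denoiser, x, sigma, goal, guidance_scale):
    d_uncond = denoiser(x, sigma, goal=None)   # goal masked out
    d_cond = denoiser(x, sigma, goal=goal)     # goal-conditioned prediction
    return d_uncond + guidance_scale * (d_cond - d_uncond)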

We can easily train the agent on 10 seeds sequentially by using:

[beso]$ conda activate play 
(play)[beso]$ python scripts/training.py --multirun seed=1,2,3,4,5,6,7,8,9,10

Please note that we use WandB to log the training runs in this repo, so you need to set the wandb variable in the main config file to your WandB entity and project name.


Understanding BESO

[Figure: overview of the BESO action generation process]

An overview of the action generation process of BESO is visualized above. New actions are generated during rollouts with the predict method of the beso_agent. Inside predict we call sample_loop to start the action denoising process. Depending on the chosen sampler, the required function from beso/agents/diffusion_agents/k_diffusion/gc_sampling.py is called; this process is visualized in the left part of the figure. The pre-conditioning wrapper of the denoising model, shown in the middle, is implemented in the GCDenoiser class in beso/agents/diffusion_agents/k_diffusion/score_wrappers.py. The score transformer is then called from beso/agents/diffusion_agents/k_diffusion/score_gpts.py, as shown in the right part of the figure.
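
Since BESO builds on the Karras et al. (2022) formulation, the pre-conditioning wrapper follows the usual c_skip / c_out / c_in / c_noise scheme. A generic sketch of that wrapper is shown below (illustrative only; the repo's GCDenoiser additionally handles goal conditioning and may differ in details):

import torch

def karras_precondition(inner_model, x, sigma, goal, sigma_data=0.5):
    # Standard Karras et al. (2022) preconditioning; sigma is a torch tensor
    # holding the current noise level.
    c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
    c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
    c_in = 1.0 / (sigma**2 + sigma_data**2) ** 0.5
    c_noise = torch.log(sigma) / 4
    # The inner network's prediction is rescaled and mixed with the noisy input.
    return c_skip * x + c_out * inner_model(c_in * x, c_noise, goal)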

Evaluation

We provide several pre-trained models for testing under trained_models. If you want to evaluate a model and change its inference parameters you can run the following script:

python scripts/evaluate.py

To change parameters for diffusion sampling, check out configs/evaluate_kitchen and configs/evaluate_blocks, where there is a detailed overview of all inference parameters. Below we provide an overview of the most important parameters and the available implementations for each.

BESO sampling customizations

BESO is based on the continuous-time score diffusion model of Karras et al. 2022, which allows several hyperparameters to be adapted for fast sampling. Below is an overview of the parameters that can be tuned per task for further performance improvement:

Number of Denoising Steps

We can control the number of denoising steps by adapting the parameter: n_timesteps.

Sampler

One can easily swap the sampler used by BESO by setting the config parameter agents.beso_agent.sampler_type to one of the following options:

  • 'ddim' sample_ddim
  • 'heun' sample_heun
  • 'euler_ancestral' sample_euler_ancestral
  • 'euler' sample_euler
  • 'dpm' sample_dpm_2
  • 'ancestral' sample_dpm_2_ancestral
  • 'dpmpp_2s' sample_dpmpp_2s
  • 'dpmpp_2s_ancestral' sample_dpmpp_2s_ancestral
  • 'dpmpp_2m' sample_dpmpp_2m
  • 'dpmpp_2m_sde' sample_dpmpp_2m_sde

The most robust sampler in our experiments has been the DDIM sampler. We found that Euler Ancestral excels in the kitchen environment. To start with, we recommend using the DDIM sampler.
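
For intuition, here is a minimal Euler sampler in the Karras formulation (illustrative only; the samplers actually used by BESO live in gc_sampling.py and include goal conditioning and further options):

import torch

@torch.no_grad()
def sample_euler(denoiser, x, sigmas, goal):
    # sigmas: decreasing noise levels, e.g. from the schedules described below.
    for sigma, sigma_next in zip(sigmas[:-1], sigmas[1:]):
        denoised = denoiser(x, sigma, goal)      # D(x, sigma): predicted clean action
        d = (x - denoised) / sigma               # derivative dx/dsigma
        x = x + d * (sigma_next - sigma)         # Euler step towards the next noise level
    return x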

Time steps

Several implementations of time-step schedulers from common diffusion frameworks are available.

The exponential or linear scheduler worked best in our experiments. However, it is worth trying out all options on a new environment to get the best performance.
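
As a reference, the two schedule shapes mentioned above can be sketched as follows (a minimal illustration with assumed default values; see the k-diffusion-style schedulers in the repo for the exact versions; n corresponds to the number of denoising steps, i.e. n_timesteps):

import math
import torch

def get_sigmas_exponential(n, sigma_min=0.01, sigma_max=80.0):
    # Noise levels spaced evenly in log-space from sigma_max down to sigma_min.
    return torch.exp(torch.linspace(math.log(sigma_max), math.log(sigma_min), n))

def get_sigmas_linear(n, sigma_min=0.01, sigma_max=80.0):
    # Noise levels spaced evenly in linear space.
    return torch.linspace(sigma_max, sigma_min, n)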


Acknowledgements

This repo relies on the following existing codebases:

  • The goal-conditioned variants of the environments are based on play-to-policy.
  • The initial environments are adapted from Relay Policy Learning, IBC and BET.
  • The continuous-time diffusion model is adapted from k-diffusion, together with all sampler implementations.
  • The score_gpt class is adapted from miniGPT.
  • A few samplers have been imported from dpm-solver.

Citation

@inproceedings{
    reuss2023goal,
    title={Goal Conditioned Imitation Learning using Score-based Diffusion Policies},
    author={Reuss, Moritz and Li, Maximilian and Jia, Xiaogang and Lioutikov, Rudolf},
    booktitle={Robotics: Science and Systems},
    year={2023}
}


beso's Issues

one question about the goal of inferring

Thank you for your excellent work.
I have a question about the 'goal'. During the training process, we have offline training data, so we can easily obtain the goal. However, during the inference process, how do we obtain the goal for the inference task? Does the simulation environment itself provide the goal, or are there other ways to generate the goal?

Questions about the implementation of classifier-free guidance

Hi,

Thanks for the nicely organized codebase!
I have a question about your implementation of classifier-free guidance.
In line 366 of score_gpts.py, it seems that when learning the unconditional policy, the goal is only masked out partly because the Bernoulli distribution applies to (bs, t, d) instead of (bs,).
However, during inference, when calculating the unconditional probability, the goal would be a completely zero tensor according to line 302 of the same file.
I'm wondering if this is the actual implementation you use for the paper results, and if so, what would be the intuition that masking partly during training can still make the diffusion model learn the unconditional policy.

Question about data scaling

Hi, I've been playing around with this repo and noticed that, at least for my datasets, using the regular scaler completely breaks the model, even if sigma data is adjusted to 1. Additionally, the sigma data produced by the min max scaler is closer to 0.33 but you have chosen 0.5. I was wondering if you could explain your thought process behind the choice of the scaler and sigma data value?
