
PyDreamer

A reimplementation of the DreamerV2 model-based RL algorithm in PyTorch.

The official DreamerV2 implementation can be found here.

This is a research project with no guarantees of stability or support. Breaking changes are to be expected!

Features


50-step long "dream" sequences generated by the model from an initial state.

PyDreamer implements most of the features of DreamerV2, but is not an exact copy and there are some subtle differences. Here is a summary of features and differences.

Feature                                         DreamerV2    PyDreamer
Env - Discrete actions                          ✓            ✓
Env - Continuous actions                        ✓            ✓
Env - Multiple workers                          ✓            ✓
Model - Categorical latents                     ✓            ✓
Model - Policy entropy                          ✓            ✓
Model - Layer normalization                     ✓            ✓
Training - KL balancing                         ✓            ✓
Training - Reinforce policy gradients           ✓            ✓
Training - Dynamics policy gradients            ✓            ✓
Training - Multistep value target (TD-λ GAE)    ✓            ✓
Training - State persistence (TBTT)             ✓            ✓
Training - Mixed precision                      ✓            ✓
Training - Offline RL                           ✓            ✓
Exploration - Plan2Explore                      ✓            ✓
Data - Replay buffer                            In-memory    Disk or cloud
Data - Batch sampling                           Random       Full episodes
Metrics - Format                                Tensorboard  Mlflow

PyDreamer also has some experimental features:

  • Multi-sample variational bound (IWAE)
  • Categorical reward decoder
  • Probe head for global map prediction

Environments

PyDreamer is set up to run out-of-the-box with a number of environments. The recommended way is to use the Dockerfile, which has all the dependencies set up, and then pass --configs defaults {env} to select one of the predefined configurations inside config/defaults.yaml.

Results

Atari benchmarks

Here is a comparison between PyDreamer and the official DreamerV2 scores on a few Atari environments:

The results seem comparable, though there are some differences. These are most likely due to different default hyperparameters and different buffer sampling (random batches vs. whole episodes).

Parameter                   DreamerV2            PyDreamer
gamma                       0.999                0.99
train_every                 16                   ~42 (1 worker)
lr (model, actor, critic)   (2e-4, 4e-5, 1e-4)   (3e-4, 1e-4, 1e-4)
grayscale                   true                 false
buffer_size                 2e6                  10e6

Trainer vs worker speed

PyDreamer uses separate processes for environment workers, so the trainer and workers do not block each other: the trainer can fully utilize the GPU while the workers run on CPU. That means, however, that there is no train_every parameter, and the ratio of gradient updates to environment steps depends on the hardware used.

To give a rough idea, here is what I'm getting on an NVIDIA T4 machine:

  • 1.4 gradient steps / sec
  • 60 agent steps / sec (single worker)
  • 240 env steps / sec (x4 action repeat)
  • 42 train_every (= agent steps / gradient steps)
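
The effective train_every follows from these rates by simple arithmetic; a small sketch using the T4 numbers quoted above:

```python
# Rough throughput numbers from the NVIDIA T4 measurements above.
grad_steps_per_sec = 1.4
agent_steps_per_sec = 60          # single worker
action_repeat = 4

env_steps_per_sec = agent_steps_per_sec * action_repeat    # raw env frames/sec
train_every = agent_steps_per_sec / grad_steps_per_sec     # agent steps per gradient step

print(f"env steps/sec: {env_steps_per_sec}")               # 240
print(f"effective train_every: {train_every:.1f}")         # ~42.9
```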

On a V100 you should see ~3 gradient steps/sec, so the effective train_every would be ~20. In that case it is probably best to increase the number of workers (generator_workers) to accelerate training, unless you are aiming for maximal sample efficiency.

Running

Running locally

Install dependencies

pip install -r requirements.txt

If you want to use the Atari environments, you need to install the Atari ROMs:

pip install atari-py==0.2.9
wget -L -nv http://www.atarimania.com/roms/Roms.rar
apt-get install unrar                                   # brew install unar (Mac)
unrar x Roms.rar                                        # unar -D Roms.rar  (Mac)
unzip ROMS.zip
python -m atari_py.import_roms ROMS
rm -rf Roms.rar ROMS.zip ROMS

Run training (debug CPU mode)

python launch.py --configs defaults atari debug --env_id Atari-Pong

Run training (full GPU mode)

python launch.py --configs defaults atari atari_pong

Running with Docker

docker build . -f Dockerfile -t pydreamer
docker run -it pydreamer --configs defaults atari atari_pong

Running on Kubernetes

See scripts/kubernetes/run_pydreamer.sh

Configuration

All configuration is done via YAML files stored in config/*.yaml. PyDreamer automatically loads all YAML files it finds there, and when you specify --configs {section1} {section2} ..., it takes the union of the sections with the given names.

The typical usage is to specify --configs defaults {env_config} {experiment}, where {env_config} is one of the predefined environment sections (e.g. atari) and {experiment} is an experiment-specific section with additional overrides (e.g. atari_pong).

You can also override individual parameters with command line arguments, e.g.

python launch.py --configs defaults atari --env_id Atari-Pong --gamma 0.995
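
The section-union behaviour can be sketched as follows (a simplified illustration, not the actual launch.py code; the section contents are hypothetical):

```python
def merge_configs(all_sections, names, overrides=None):
    """Union of the named config sections, later sections and CLI
    overrides taking precedence -- mimics --configs handling."""
    merged = {}
    for name in names:
        merged.update(all_sections[name])
    merged.update(overrides or {})
    return merged

# Hypothetical sections, standing in for the contents of config/*.yaml:
sections = {
    "defaults": {"gamma": 0.999, "batch_size": 50},
    "atari": {"gamma": 0.99, "env_id": "Atari-Breakout"},
}
conf = merge_configs(sections, ["defaults", "atari"], {"env_id": "Atari-Pong"})
print(conf)  # {'gamma': 0.99, 'batch_size': 50, 'env_id': 'Atari-Pong'}
```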

Mlflow Tracking

PyDreamer relies quite heavily on Mlflow tracking to log metrics and images, store model checkpoints, and even hold the replay buffer.

That does NOT mean you need to have an Mlflow tracking server installed. By default, mlflow is just a pip package that stores all metrics and files locally under the ./mlruns directory.

That said, if you are running experiments on the cloud, it is very convenient to set up a persistent Mlflow tracking server. In that case, just set the MLFLOW_TRACKING_URI environment variable, and all metrics will be sent to the server instead of being written to the local directory.
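
For example, to log to a remote tracking server instead of ./mlruns (the URI below is a placeholder — point it at your own Mlflow instance):

```shell
# Hypothetical server address -- substitute your own Mlflow tracking server.
export MLFLOW_TRACKING_URI=http://mlflow.example.internal:5000
python launch.py --configs defaults atari atari_pong
```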

Note that the replay buffer is just a directory of mlflow artifacts in *.npz format, so if you set up an S3 or GCS mlflow artifact store, the replay buffer will actually be stored in the cloud and replayed from there! This makes it easy to persist data across container restarts, but be careful to keep the data in the same cloud region as the training containers, to avoid data transfer charges.
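
To make the episode format concrete, here is a hedged sketch of inspecting one such *.npz episode artifact; the key names (image, action, reward) and shapes below are typical examples, not guaranteed — the actual keys depend on the environment and config:

```python
import io
import numpy as np

def episode_summary(path_or_buffer):
    """Return {key: shape} for the arrays stored in one .npz episode."""
    with np.load(path_or_buffer) as episode:
        return {key: episode[key].shape for key in episode.files}

# Build a tiny fake episode in memory to show the shape of the data:
buf = io.BytesIO()
np.savez_compressed(buf,
                    image=np.zeros((100, 64, 64, 3), dtype=np.uint8),
                    action=np.zeros((100, 6)),
                    reward=np.zeros(100))
buf.seek(0)
print(episode_summary(buf))
```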

pydreamer's People

Contributors

jurgisp


pydreamer's Issues

Docker image build fail

Greetings.

Thank you very much for your reproduction of the original Dreamer-v2 code in Pytorch, as well as open sourcing it on Github.
I have been looking through various Torch ports of the Dreamer-v2 algorithm for some experiments in my research, and your implementation definitely caught my attention, being very complete.

While I managed to get the PyDreamer agent to run locally using the provided instructions, I encountered a problem during the build of the Docker image I thought you might want to know about.

Namely, after executing docker build . -f Dockerfile -t pydreamer, it returns the following error:

(pydreamer) d055@akira:~/random/rl/pydreamer$ docker build . -f Dockerfile -t pydreamer
Sending build context to Docker daemon  330.2kB
Step 1/32 : ARG ENV=standard
Step 2/32 : FROM pytorch/pytorch:1.8.1-cuda10.2-cudnn7-devel AS base
 ---> c7e20104018e
Step 3/32 : RUN apt-get update && apt-get install -y     git xvfb python3.7-dev python3-setuptools     libglu1-mesa libglu1-mesa-dev libgl1-mesa-dev libosmesa6-dev mesa-utils freeglut3 freeglut3-dev     libglew2.0 libglfw3 libglfw3-dev zlib1g zlib1g-dev libsdl2-dev libjpeg-dev lua5.1 liblua5.1-0-dev libffi-dev     build-essential cmake g++-4.8 pkg-config software-properties-common gettext     ffmpeg patchelf swig unrar unzip zip curl wget tmux     && rm -rf /var/lib/apt/lists/*
 ---> Using cache
 ---> 2f7651b27698
Step 4/32 : FROM base AS standard-env
 ---> 2f7651b27698
Step 5/32 : RUN pip3 install atari-py==0.2.9
 ---> Using cache
 ---> f7bc1331fcbb
Step 6/32 : RUN wget -L -nv http://www.atarimania.com/roms/Roms.rar &&     unrar x Roms.rar &&     unzip ROMS.zip &&     python3 -m atari_py.import_roms ROMS &&     rm -rf Roms.rar ROMS.zip ROMS
 ---> Using cache
 ---> 5cd00f9e095d
Step 7/32 : RUN mkdir -p /root/.mujoco &&     cd /root/.mujoco &&     wget -nv https://mujoco.org/download/mujoco210-linux-x86_64.tar.gz -O mujoco.tar.gz &&     tar -xf mujoco.tar.gz &&     rm mujoco.tar.gz
 ---> Using cache
 ---> b1decf2d8fd8
Step 8/32 : RUN pip3 install dm_control
 ---> Using cache
 ---> ab0179162e2a
Step 9/32 : FROM base AS dmlab-env
 ---> 2f7651b27698
Step 10/32 : RUN echo "deb [arch=amd64] http://storage.googleapis.com/bazel-apt stable jdk1.8" |     tee /etc/apt/sources.list.d/bazel.list &&     curl https://bazel.build/bazel-release.pub.gpg |     apt-key add - &&     apt-get update && apt-get install -y bazel
 ---> Using cache
 ---> 9851f7e81633
Step 11/32 : RUN git clone https://github.com/deepmind/lab.git /dmlab
 ---> Using cache
 ---> e8a16a62e893
Step 12/32 : WORKDIR /dmlab
 ---> Using cache
 ---> 56da3d82b379
Step 13/32 : RUN git checkout "937d53eecf7b46fbfc56c62e8fc2257862b907f2"
 ---> Using cache
 ---> cf2e70fb4e1a
Step 14/32 : RUN ln -s '/opt/conda/lib/python3.7/site-packages/numpy/core/include/numpy' /usr/include/numpy &&     sed -i '[email protected]@python3.7@g' python.BUILD &&     sed -i 's@glob(\[@glob(["include/numpy/\*\*/*.h", @g' python.BUILD &&     sed -i 's@: \[@: ["include/numpy", @g' python.BUILD &&     sed -i 's@650250979303a649e21f87b5ccd02672af1ea6954b911342ea491f351ceb7122@1e9793e1c6ba66e7e0b6e5fe7fd0f9e935cc697854d5737adec54d93e5b3f730@g' WORKSPACE &&     sed -i 's@rules_cc-master@rules_cc-main@g' WORKSPACE &&     sed -i 's@rules_cc/archive/master@rules_cc/archive/main@g' WORKSPACE &&     bazel build -c opt python/pip_package:build_pip_package --incompatible_remove_legacy_whole_archive=0
 ---> Running in 189631cf8219
Extracting Bazel installation...
Starting local Bazel server and connecting to it...
Loading: 
Loading: 0 packages loaded
Analyzing: target //python/pip_package:build_pip_package (1 packages loaded, 0 targets configured)
Analyzing: target //python/pip_package:build_pip_package (7 packages loaded, 15 targets configured)
INFO: SHA256 (https://github.com/bazelbuild/rules_cc/archive/main.zip) = 3839996049629e6377abdfd04681ddeeb0cc3db13b9d2ff81bf46700cb4529f7
DEBUG: Rule 'rules_cc' indicated that a canonical reproducible form can be obtained by modifying arguments sha256 = "3839996049629e6377abdfd04681ddeeb0cc3db13b9d2ff81bf46700cb4529f7"
DEBUG: Repository rules_cc instantiated at:
  /dmlab/WORKSPACE:11:13: in <toplevel>
Repository rule http_archive defined at:
  /root/.cache/bazel/_bazel_root/1526964810f57c8028e6760b30faecdd/external/bazel_tools/tools/build_defs/repo/http.bzl:336:31: in <toplevel>
Analyzing: target //python/pip_package:build_pip_package (22 packages loaded, 213 targets configured)
INFO: SHA256 (https://github.com/abseil/abseil-cpp/archive/master.zip) = 6d33798883560650cb9484a915e5085d251b61c14d8937ad714448577786c0fa
DEBUG: Rule 'com_google_absl' indicated that a canonical reproducible form can be obtained by modifying arguments sha256 = "6d33798883560650cb9484a915e5085d251b61c14d8937ad714448577786c0fa"
DEBUG: Repository com_google_absl instantiated at:
  /dmlab/WORKSPACE:17:13: in <toplevel>
Repository rule http_archive defined at:
  /root/.cache/bazel/_bazel_root/1526964810f57c8028e6760b30faecdd/external/bazel_tools/tools/build_defs/repo/http.bzl:336:31: in <toplevel>
INFO: Repository jpeg_archive instantiated at:
  /dmlab/WORKSPACE:45:13: in <toplevel>
Repository rule http_archive defined at:
  /root/.cache/bazel/_bazel_root/1526964810f57c8028e6760b30faecdd/external/bazel_tools/tools/build_defs/repo/http.bzl:336:31: in <toplevel>
WARNING: Download from http://www.ijg.org/files/jpegsrc.v9c.tar.gz failed: class com.google.devtools.build.lib.bazel.repository.downloader.UnrecoverableHttpException Checksum was 682aee469c3ca857c4c38c37a6edadbfca4b04d42e56613b11590ec6aa4a278d but wanted 1e9793e1c6ba66e7e0b6e5fe7fd0f9e935cc697854d5737adec54d93e5b3f730
ERROR: An error occurred during the fetch of repository 'jpeg_archive':
   Traceback (most recent call last):
        File "/root/.cache/bazel/_bazel_root/1526964810f57c8028e6760b30faecdd/external/bazel_tools/tools/build_defs/repo/http.bzl", line 111, column 45, in _http_archive_impl
                download_info = ctx.download_and_extract(
Error in download_and_extract: java.io.IOException: Error downloading [http://www.ijg.org/files/jpegsrc.v9c.tar.gz] to /root/.cache/bazel/_bazel_root/1526964810f57c8028e6760b30faecdd/external/jpeg_archive/temp13980778680031230739/jpegsrc.v9c.tar.gz: Checksum was 682aee469c3ca857c4c38c37a6edadbfca4b04d42e56613b11590ec6aa4a278d but wanted 1e9793e1c6ba66e7e0b6e5fe7fd0f9e935cc697854d5737adec54d93e5b3f730
ERROR: Error fetching repository: Traceback (most recent call last):
        File "/root/.cache/bazel/_bazel_root/1526964810f57c8028e6760b30faecdd/external/bazel_tools/tools/build_defs/repo/http.bzl", line 111, column 45, in _http_archive_impl
                download_info = ctx.download_and_extract(
Error in download_and_extract: java.io.IOException: Error downloading [http://www.ijg.org/files/jpegsrc.v9c.tar.gz] to /root/.cache/bazel/_bazel_root/1526964810f57c8028e6760b30faecdd/external/jpeg_archive/temp13980778680031230739/jpegsrc.v9c.tar.gz: Checksum was 682aee469c3ca857c4c38c37a6edadbfca4b04d42e56613b11590ec6aa4a278d but wanted 1e9793e1c6ba66e7e0b6e5fe7fd0f9e935cc697854d5737adec54d93e5b3f730
Analyzing: target //python/pip_package:build_pip_package (31 packages loaded, 2049 targets configured)
INFO: Repository glib_archive instantiated at:
  /dmlab/WORKSPACE:34:13: in <toplevel>
Repository rule http_archive defined at:
  /root/.cache/bazel/_bazel_root/1526964810f57c8028e6760b30faecdd/external/bazel_tools/tools/build_defs/repo/http.bzl:336:31: in <toplevel>
INFO: Repository png_archive instantiated at:
  /dmlab/WORKSPACE:64:13: in <toplevel>
Repository rule http_archive defined at:
  /root/.cache/bazel/_bazel_root/1526964810f57c8028e6760b30faecdd/external/bazel_tools/tools/build_defs/repo/http.bzl:336:31: in <toplevel>
ERROR: /dmlab/q3map2/BUILD:54:10: //q3map2:q3map2 depends on @jpeg_archive//:jpeg in repository @jpeg_archive which failed to fetch. no such package '@jpeg_archive//': java.io.IOException: Error downloading [http://www.ijg.org/files/jpegsrc.v9c.tar.gz] to /root/.cache/bazel/_bazel_root/1526964810f57c8028e6760b30faecdd/external/jpeg_archive/temp13980778680031230739/jpegsrc.v9c.tar.gz: Checksum was 682aee469c3ca857c4c38c37a6edadbfca4b04d42e56613b11590ec6aa4a278d but wanted 1e9793e1c6ba66e7e0b6e5fe7fd0f9e935cc697854d5737adec54d93e5b3f730
ERROR: Analysis of target '//python/pip_package:build_pip_package' failed; build aborted: Analysis failed
INFO: Elapsed time: 7.022s
INFO: 0 processes.
FAILED: Build did NOT complete successfully (31 packages loaded, 2049 targets configured)
FAILED: Build did NOT complete successfully (31 packages loaded, 2049 targets configured)
The command '/bin/sh -c ln -s '/opt/conda/lib/python3.7/site-packages/numpy/core/include/numpy' /usr/include/numpy &&     sed -i '[email protected]@python3.7@g' python.BUILD &&     sed -i 's@glob(\[@glob(["include/numpy/\*\*/*.h", @g' python.BUILD &&     sed -i 's@: \[@: ["include/numpy", @g' python.BUILD &&     sed -i 's@650250979303a649e21f87b5ccd02672af1ea6954b911342ea491f351ceb7122@1e9793e1c6ba66e7e0b6e5fe7fd0f9e935cc697854d5737adec54d93e5b3f730@g' WORKSPACE &&     sed -i 's@rules_cc-master@rules_cc-main@g' WORKSPACE &&     sed -i 's@rules_cc/archive/master@rules_cc/archive/main@g' WORKSPACE &&     bazel build -c opt python/pip_package:build_pip_package --incompatible_remove_legacy_whole_archive=0' returned a non-zero code: 1

I am not very familiar with bazel, but it seems that one of the dependencies, jpeg_archive, which bazel is in charge of fetching, cannot be downloaded from the remote repository (the downloaded archive fails its checksum check).

Did you happen to encounter this problem during your experiments?

Looking forward to hearing back from you.
Best regards.

Minigrid environments not working

Thanks for this Pytorch implementation of Dreamer-v2! I'm trying to get the code working with the Minigrid environments. However, I'm encountering the following error:

Traceback (most recent call last):
  File "/home/steph/projects/pydreamer/train.py", line 612, in <module>
    run(conf)
  File "/home/steph/projects/pydreamer/train.py", line 227, in run
    model.training_step(obs,
  File "/home/steph/projects/pydreamer/pydreamer/models/dreamer.py", line 122, in training_step
    loss_probe, metrics_probe, tensors_probe = self.probe_model.training_step(features.detach(), obs)
  File "/home/steph/projects/pydreamer/pydreamer/models/probes.py", line 38, in training_step
    map_coord = insert_dim(obs['map_coord'], 2, I)
KeyError: 'map_coord'

Here's the command I ran:
xvfb-run -a -s "-screen 0 1400x900x24" python train.py --config defaults minigrid debug --env_id MiniGrid-Empty-8x8-v0 --device cuda

Dataset in test environment

Hi! Thanks for your work on this pytorch implementation of dreamerv2. It helps me a lot.

I am just curious about the reason for building two separate datasets, one for the train environment and one for the test environment.
It does not seem crucial for training, so why do we need a dataset in the test env?
Please let me know if this is obviously necessary and I am missing something.

require_grad is False when actor_grad is dynamic

When running the code on DMC, actor_grad is dynamics, so loss_policy is -value_target. Since value_target does not depend on the actor's policy distribution, no gradient flows through loss_policy with respect to the actor's parameters.
The assertion then evaluates as assert (False and True) or not True, since loss_policy does not require gradients, so the assertion fails.
How can we fix it?

Does dreamer need env fixed initialization?

I'm doing some robot manipulation tasks with dreamer, but the results so far are not very good (worse than SAC, and not converging). I suspect it's because my env.reset() leads to random initial states. At the beginning of each episode, the RNN assumes that the previous state is all zeros, which seems unreasonable.

I found that the PlaNet paper mentions that the initial state of an episode is fixed, and dreamer uses its RSSM. So I wonder if dreamer can handle randomized starting states in practice.

Generator checkpoint not found

When trying to train, despite the latest.pt file being in the checkpoints folder, it always finishes with "Generator checkpoint not found". It seems the generator checkpoint lookup returns None in model_step. Also, checkpoints are only saved in debug CPU mode, not in the normal GPU mode.
Also, why does the model not learn even when I make the prefill steps really large? Once the prefill steps are completed, the program just exits, again seemingly because the checkpoint is not found. These do not appear to be dependency issues.

Thank you

Batchnorm: Expected more than 1 value per channel

When the generator tries to run a NetworkPolicy, it has a batch size of 1, which doesn't work well with BatchNorm layers (see the error in the title). The error occurs because the mean and var of the BatchNorm are still updated even under torch.no_grad, and the error for batch size 1 is probably there to indicate that this is not the intended behavior.

Do you think it is okay to run the NetworkPolicy in eval mode?
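
For context, a minimal illustration of the eval-mode workaround the question proposes (a generic torch module standing in for the actual NetworkPolicy): in eval mode, BatchNorm uses its frozen running statistics, so a batch of 1 is fine.

```python
import torch
import torch.nn as nn

# Generic stand-in for a policy network containing BatchNorm.
# In train mode, BatchNorm1d raises "Expected more than 1 value per channel"
# for batch size 1; in eval mode it uses (and stops updating) running stats.
net = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.ReLU())

x = torch.randn(1, 4)  # batch size 1, as in the generator

net.eval()
with torch.no_grad():
    y = net(x)
print(y.shape)  # torch.Size([1, 8])
```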

procgen env

Sorry, I'm sort of new to this field. Can I run this on the procgen environment?

About sharing training logs

Is it possible to share the training logs? That would allow us to better understand the experiments, e.g. the convergence of each loss at the end of training. Thanks!

V3 Implementation

I'm just wondering if/when DreamerV3 will be pushed to this repo.

about the generators

Hi!

I don't fully understand how the code manages the available hardware resources, and I could use some advice on how to accelerate training. E.g., in an environment with multiple GPUs and multiple CPUs, what changes should I make to ensure I use these resources?

Thank you very much!
