jason-cky / rl_vae-pytorch

Pipeline to generate images and to train a Variational Autoencoder (VAE) for use in Deep RL environments

Python 100.00%
vae autoencoders pytorch-implementation reinforcement-learning reinforcement-learning-environments

rl_vae-pytorch's Introduction

Deep RL policies on Pybullet Environments

This repo is a PyTorch implementation of variational autoencoder (VAE) training, written to produce a VAE for use in an RL environment. It also contains code to generate training images by random exploration of various RL environments from PyBullet and RLBench.

A Beta-VAE can also be trained by setting the beta value (the --beta option of train_vae.py) to any value greater than 1.
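
The README does not show the loss itself, so the following is only a minimal sketch of the standard Beta-VAE objective (reconstruction error plus beta times the KL divergence to a standard normal prior), not the repo's actual code. It is written in plain Python so the arithmetic is visible; the function names `kl_to_standard_normal` and `beta_vae_loss` are hypothetical, and a real training loop would compute the same quantities on PyTorch tensors.

```python
import math


def kl_to_standard_normal(mu, logvar):
    """KL( N(mu, diag(exp(logvar))) || N(0, I) ) for a single latent vector.

    mu and logvar are equal-length sequences of floats (the encoder outputs).
    """
    return -0.5 * sum(
        1.0 + lv - m * m - math.exp(lv) for m, lv in zip(mu, logvar)
    )


def beta_vae_loss(recon_error, mu, logvar, beta=1.0):
    """Reconstruction error plus the beta-weighted KL term.

    beta == 1 gives the ordinary VAE objective; beta > 1 penalises the
    KL term more heavily, which is the Beta-VAE setting.
    """
    return recon_error + beta * kl_to_standard_normal(mu, logvar)


# Assumed toy values: one latent dimension, mean 1.0, unit variance.
plain = beta_vae_loss(recon_error=3.0, mu=[1.0], logvar=[0.0], beta=1.0)
weighted = beta_vae_loss(recon_error=3.0, mu=[1.0], logvar=[0.0], beta=4.0)
print(plain, weighted)
```

With the same encoder outputs, raising beta only scales the KL contribution; the reconstruction term is unchanged.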

Dependencies:

  • CUDA >= 10.2
  • RLBench (only needed if you want to train the VAE on RLBench environments)

How to use

  • Clone this repo
  • pip install -r requirements.txt

Generating data from an OpenAI Gym environment

python generate_data.py
usage: generate_data.py [-h] --env ENV --num_samples NUM_SAMPLES
                        [--max_ep_len MAX_EP_LEN] [--seed SEED] [--rlbench]
                        [--view {wrist_rgb,front_rgb,left_shoulder_rgb,right_shoulder_rgb}]

optional arguments:
  -h, --help            show this help message and exit
  --env ENV             environment_id
  --num_samples NUM_SAMPLES
                        specify number of image samples to generate
  --max_ep_len MAX_EP_LEN
                        Maximum length of an episode
  --seed SEED           seed number for reproducibility
  --rlbench             if true, use rlbench environment wrappers
  --view {wrist_rgb,front_rgb,left_shoulder_rgb,right_shoulder_rgb}
                        choose the type of camera view to generate image (only
                        for RLBench envs)
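
For readers who want to see how the options above fit together, here is a sketch of an argparse parser that would produce a matching interface. It is a reconstruction from the help text, not the script's source: the integer types and the absence of defaults are assumptions, and the environment ID in the example call is only illustrative.

```python
import argparse

# Camera views listed in the help text (only meaningful for RLBench envs).
VIEWS = ["wrist_rgb", "front_rgb", "left_shoulder_rgb", "right_shoulder_rgb"]


def build_parser():
    parser = argparse.ArgumentParser(prog="generate_data.py")
    # --env and --num_samples appear without brackets in the usage line,
    # so they are treated as required here.
    parser.add_argument("--env", required=True, help="environment_id")
    parser.add_argument("--num_samples", type=int, required=True,
                        help="specify number of image samples to generate")
    # Types below are assumed; the help text does not state them.
    parser.add_argument("--max_ep_len", type=int,
                        help="Maximum length of an episode")
    parser.add_argument("--seed", type=int,
                        help="seed number for reproducibility")
    parser.add_argument("--rlbench", action="store_true",
                        help="if true, use rlbench environment wrappers")
    parser.add_argument("--view", choices=VIEWS,
                        help="choose the type of camera view to generate "
                             "image (only for RLBench envs)")
    return parser


# Parse an explicit argument list instead of sys.argv so the sketch is
# self-contained.
args = build_parser().parse_args(
    ["--env", "HalfCheetahBulletEnv-v0", "--num_samples", "1000"]
)
print(args)
```

Note that --rlbench is a bare switch: it takes no value and is off unless present.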

Training VAE

usage: train_vae.py [-h] [--dir DIR] [--seed SEED] [--num_workers NUM_WORKERS]
                    [--batch_size BATCH_SIZE] [--epochs EPOCHS] [--beta BETA]
                    [--lr LR] [--ngpu NGPU] [--save_freq SAVE_FREQ]
                    [--log_freq LOG_FREQ] [--save_dir SAVE_DIR]

optional arguments:
  -h, --help            show this help message and exit
  --dir DIR             path to image folders
  --seed SEED           seed number for reproducibility
  --num_workers NUM_WORKERS
                        number of workers for dataloaders
  --batch_size BATCH_SIZE
                        batch size
  --epochs EPOCHS       Number of epochs
  --beta BETA           Weighing value for KLD in B-VAE
  --lr LR               Learning Rate
  --ngpu NGPU           number of gpus to use
  --save_freq SAVE_FREQ
                        save weights every <x> iterations
  --log_freq LOG_FREQ   log losses every <x> iterations
  --save_dir SAVE_DIR   path to save weights and logs
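
The two scripts are meant to be run in sequence: generate images, then train on them. The sketch below assembles one possible pair of command lines from the flags documented above. The helper `build_command` is hypothetical, and the environment ID, the image folder `data/` and the output folder `runs/beta4` are placeholder values; the README does not say where generate_data.py writes its images, so point --dir at wherever they actually end up.

```python
import shlex


def build_command(script, **flags):
    """Build an argv list for a script from keyword flags.

    True becomes a bare switch, None and False are skipped, and every
    other value is passed as '--name value'.
    """
    argv = ["python", script]
    for name, value in flags.items():
        if value is None or value is False:
            continue
        argv.append(f"--{name}")
        if value is not True:
            argv.append(str(value))
    return argv


# Step 1: collect images by random exploration (placeholder values).
generate_cmd = build_command(
    "generate_data.py", env="HalfCheetahBulletEnv-v0", num_samples=10000, seed=0
)
# Step 2: train a Beta-VAE (beta > 1) on the generated images.
train_cmd = build_command(
    "train_vae.py", dir="data/", beta=4.0, epochs=50, save_dir="runs/beta4"
)

print(shlex.join(generate_cmd))
print(shlex.join(train_cmd))
```

The printed lines can be pasted into a shell, or the argv lists can be handed to subprocess.run directly.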
