
vin_tensorflow

This is a TensorFlow implementation of Value Iteration Networks (VIN) aimed at reproducing the results of the original paper (a PyTorch version is also available).

Architecture of Value Iteration Network

Key ideas

  • A fully differentiable neural network with a 'planning' sub-module.
  • Value Iteration = Conv Layer + Channel-wise Max Pooling
  • Generalizes better than reactive policies to new, unseen tasks.
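Written out, one VI iteration looks as follows. This is our reading of the VIN formulation, not code from this repository: $R$ is the reward image, $\bar{V}^{(k)}$ the value image after iteration $k$, $l$ indexes the two stacked input channels, $W^{\bar{a}}$ is the 3x3 filter of q channel $\bar{a}$, and $K$ is the `--k` flag. Starting from a zero value image is an assumption that makes the first iteration a convolution of the reward image alone.

```latex
\begin{aligned}
\bar{R}^{(k)} &= \big[\, R \,;\, \bar{V}^{(k-1)} \,\big], \qquad \bar{V}^{(0)} = 0, \\
\bar{Q}^{(k)}_{\bar{a},i',j'} &= \sum_{l,i,j} W^{\bar{a}}_{l,i,j}\, \bar{R}^{(k)}_{l,\,i'-i,\,j'-j}, \\
\bar{V}^{(k)}_{i',j'} &= \max_{\bar{a}} \bar{Q}^{(k)}_{\bar{a},i',j'}, \qquad k = 1, \dots, K.
\end{aligned}
```

The second line is the "Conv Layer" and the third is the "Channel-wise Max Pooling" of the key idea above.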

Learned Reward Image and Its Value Images for each VI Iteration

[Figure: for each grid world (8x8, 16x16, 28x28), the learned reward image and the value images after each VI iteration.]

Dependencies

This repository requires the following packages:

  • Python >= 3.6
  • NumPy >= 1.12.1
  • TensorFlow >= 1.0
  • SciPy >= 0.19.0

Datasets

Each data sample consists of the (x, y) coordinates of the current state in the grid world, followed by an obstacle image and a goal image.
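As an illustration of this layout, the sketch below builds one sample in memory using the encoding described in the FAQ (obstacle image: 0 free, 1 obstacle; goal image: 0 free, 10 goal). `GridSample` and `make_sample` are hypothetical names, reading (x, y) as (row, column) is an assumption, and the arrays in the actual `.npz` files may be laid out differently.

```python
from dataclasses import dataclass
from typing import List, Sequence, Tuple

Image = List[List[float]]  # one H x W channel


@dataclass
class GridSample:
    """Hypothetical in-memory form of one data sample (not the .npz layout)."""
    x: int            # current state, assumed to be the row index
    y: int            # current state, assumed to be the column index
    obstacles: Image  # 0: free, 1: obstacle
    goal: Image       # 0: free, 10: goal

    def observation(self) -> List[Image]:
        """Stack into the 2-channel network input [obstacle image, goal image], shape [2, H, W]."""
        return [self.obstacles, self.goal]


def make_sample(size: int, state: Tuple[int, int],
                obstacle_cells: Sequence[Tuple[int, int]],
                goal_cell: Tuple[int, int]) -> GridSample:
    """Build a sample on a size x size grid from obstacle and goal coordinates."""
    obstacles = [[0.0] * size for _ in range(size)]
    goal = [[0.0] * size for _ in range(size)]
    for r, c in obstacle_cells:
        obstacles[r][c] = 1.0
    goal[goal_cell[0]][goal_cell[1]] = 10.0
    return GridSample(state[0], state[1], obstacles, goal)


sample = make_sample(8, state=(0, 0), obstacle_cells=[(2, 3), (5, 5)], goal_cell=(7, 7))
obs = sample.observation()
print(len(obs), len(obs[0]), len(obs[0][0]))  # 2 8 8
```

Stacking 128 such observations gives the [128, 2, 8, 8] input tensor mentioned in the FAQ.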

Dataset size (samples)   8x8      16x16     28x28
Train set                77760    776440    4510695
Test set                 12960    129440    751905
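A quick check of the split sizes: the plain-Python snippet below, with counts copied from the table, computes the train/test ratio and the share of samples held out for testing.

```python
from typing import Tuple

# (train, test) sample counts copied from the dataset table above.
DATASETS = {
    "8x8": (77760, 12960),
    "16x16": (776440, 129440),
    "28x28": (4510695, 751905),
}


def split_stats(train: int, test: int) -> Tuple[float, float]:
    """Return (train/test ratio, fraction of all samples that are in the test set)."""
    total = train + test
    return train / test, test / total


for name, (train, test) in DATASETS.items():
    ratio, test_frac = split_stats(train, test)
    print(f"{name}: train/test = {ratio:.3f}, test share = {test_frac:.2%}")
```

All three grid sizes come out at about a 6:1 split, i.e. roughly one seventh (about 14.3%) of the samples are in the test set.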

Running Experiments: Training

Grid world 8x8

python run.py --datafile data/gridworld_8x8.npz --imsize 8 --lr 0.005 --epochs 30 --k 10 --batch_size 128

Grid world 16x16

python run.py --datafile data/gridworld_16x16.npz --imsize 16 --lr 0.008 --epochs 30 --k 20 --batch_size 128

Grid world 28x28

python run.py --datafile data/gridworld_28x28.npz --imsize 28 --lr 0.003 --epochs 30 --k 36 --batch_size 128

Flags:

  • datafile: Path to the data file.
  • imsize: Size of the input images. One of: 8, 16, 28.
  • lr: Learning rate for the RMSProp optimizer. Recommended: [0.01, 0.005, 0.002, 0.001]
  • epochs: Number of training epochs. Default: 30
  • k: Number of value iterations. Recommended: 10 for 8x8, 20 for 16x16, 36 for 28x28.
  • ch_i: Number of channels in the input layer. Default: 2, i.e. the obstacle image and the goal image.
  • ch_h: Number of channels in the first convolutional layer. Default: 150, as described in the paper.
  • ch_q: Number of channels in the q layer of the VI module (each channel plays the role of an action). Default: 10, as described in the paper.
  • batch_size: Batch size. Default: 128
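The flags above could be declared with Python's standard `argparse` roughly as follows. This is a hypothetical sketch, not the parser in the repository's `run.py`: `--datafile`, `--imsize` and `--lr` are made required because no defaults are listed for them, and filling in `--k` from the recommended values when it is omitted is our own addition.

```python
import argparse

# Recommended number of VI iterations per grid size, from the flag list above.
RECOMMENDED_K = {8: 10, 16: 20, 28: 36}


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the documented flags (not the repository's run.py)."""
    p = argparse.ArgumentParser(description="Train a Value Iteration Network (sketch)")
    p.add_argument("--datafile", required=True, help="path to the .npz data file")
    p.add_argument("--imsize", type=int, required=True, choices=[8, 16, 28],
                   help="grid world size")
    p.add_argument("--lr", type=float, required=True, help="RMSProp learning rate")
    p.add_argument("--epochs", type=int, default=30, help="number of training epochs")
    p.add_argument("--k", type=int, default=None,
                   help="number of VI iterations (default: recommended value for imsize)")
    p.add_argument("--ch_i", type=int, default=2, help="input channels")
    p.add_argument("--ch_h", type=int, default=150, help="first conv layer channels")
    p.add_argument("--ch_q", type=int, default=10, help="q layer channels")
    p.add_argument("--batch_size", type=int, default=128, help="batch size")
    return p


def parse_args(argv=None) -> argparse.Namespace:
    args = build_parser().parse_args(argv)
    if args.k is None:  # hypothetical convenience: use the recommended k for this grid size
        args.k = RECOMMENDED_K[args.imsize]
    return args


if __name__ == "__main__":
    print(parse_args(["--datafile", "data/gridworld_8x8.npz", "--imsize", "8", "--lr", "0.005"]))
```

Keeping the recommended k values in one table means the three training commands differ only in `--datafile`, `--imsize` and `--lr` under this sketch.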

Benchmarks

GPU: TITAN X

Performance: Test Accuracy

NOTE: These numbers are accuracies on the test set. They differ from the table in the paper, which reports the success rate of rollouts of the learned policy in the environment.

Test accuracy   8x8       16x16     28x28
TensorFlow      99.03%    90.2%     82%
PyTorch         99.16%    92.44%    88.20%
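To make the NOTE above concrete, here is a sketch of the two metrics side by side. `action_accuracy` and `success_rate` are hypothetical helpers, not functions from this repository, and the toy inputs are made up.

```python
from typing import Sequence


def action_accuracy(predicted: Sequence[int], labels: Sequence[int]) -> float:
    """Per-state metric: fraction of test states whose predicted action equals the labelled action."""
    if not labels or len(predicted) != len(labels):
        raise ValueError("need two non-empty sequences of equal length")
    return sum(p == y for p, y in zip(predicted, labels)) / len(labels)


def success_rate(reached_goal: Sequence[bool]) -> float:
    """Per-episode metric: fraction of rollouts of the policy that reach the goal."""
    if not reached_goal:
        raise ValueError("need at least one rollout")
    return sum(reached_goal) / len(reached_goal)


# Toy numbers: 3 of 4 states match their labels; 2 of 4 hypothetical rollouts succeed.
print(action_accuracy([0, 1, 2, 3], [0, 1, 2, 4]))  # 0.75
print(success_rate([True, False, False, True]))     # 0.5
```

Because one metric counts states and the other counts whole episodes, they need not agree: a policy may pick an unlabelled action that still lies on another shortest path, while a single wrong step can make a whole rollout fail.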

Speed with GPU

Time per epoch   8x8    16x16   28x28
TensorFlow       4 s    25 s    165 s
PyTorch          3 s    15 s    100 s
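The per-epoch times translate into total training time by simple multiplication. The sketch below assumes every epoch takes exactly the listed time and uses the default of 30 epochs.

```python
# Seconds per epoch on a TITAN X, copied from the table above.
SECONDS_PER_EPOCH = {
    "TensorFlow": {"8x8": 4, "16x16": 25, "28x28": 165},
    "PyTorch": {"8x8": 3, "16x16": 15, "28x28": 100},
}


def total_minutes(seconds_per_epoch: float, epochs: int = 30) -> float:
    """Estimated training time in minutes, assuming every epoch takes the listed time."""
    return seconds_per_epoch * epochs / 60


for framework, rows in SECONDS_PER_EPOCH.items():
    for size, sec in rows.items():
        print(f"{framework:<10} {size:>5}: {total_minutes(sec):6.1f} min for 30 epochs")
```

With 30 epochs this gives 2, 12.5 and 82.5 minutes for TensorFlow and 1.5, 7.5 and 50 minutes for PyTorch.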

Frequently Asked Questions

  • Q: How is the reward image obtained from the observation?

    • A: The observation image has 2 channels. The first channel is the obstacle image (0: free, 1: obstacle) and the second channel is the goal image (0: free, 10: goal). For example, in the 8x8 grid world with a batch size of 128, the input tensor has shape [128, 2, 8, 8]. It is fed into a convolutional layer with [3, 3] filters and 150 feature maps, followed by another convolutional layer with a [3, 3] filter and 1 feature map. The output tensor has shape [128, 1, 8, 8]; this is the reward image.
  • Q: What exactly is the transition model, and how does the VI module obtain the value image from the reward image?

    • A: Assume a batch size of 128 and an 8x8 grid world. Once we have the reward image with shape [128, 1, 8, 8], we apply a convolutional layer to produce the q layers of the VI module. The [3, 3] filters represent the transition probabilities. There is a set of 10 filters, each generating one feature map of the q layers, and each feature map corresponds to an "action". Note that this is more than the 8 actions actually available. We then apply channel-wise max pooling to obtain the value image with shape [128, 1, 8, 8]. Finally, we stack this value image with the reward image as the input to the next VI iteration.
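A pure-Python sketch of the VI module described in this answer, for a single sample without the batch dimension and in the channel-first layout used above. It is illustrative rather than the repository's code: the q-layer filters are set by hand instead of learned (channel a < 8 adds the reward at the current cell to a discount `gamma` times the value of the neighbor reached by move a, i.e. a deterministic transition), the two spare channels of `ch_q = 10` are zero filters, the reward image is a hypothetical one with 1 at a goal cell, and `gamma = 0.9`, the move ordering and all helper names are our own assumptions. Convolutions are computed as cross-correlations with zero padding so the 8x8 size is preserved.

```python
from typing import List

Image = List[List[float]]  # one H x W channel

# The eight real moves as (row, col) offsets; ch_q = 10 leaves two spare channels.
MOVES = [(-1, -1), (-1, 0), (-1, 1), (0, -1), (0, 1), (1, -1), (1, 0), (1, 1)]


def conv3x3_same(x: List[Image], w: List[List[Image]]) -> List[Image]:
    """3x3 cross-correlation, stride 1, zero padding: [C_in][H][W] -> [C_out][H][W]."""
    h, wd = len(x[0]), len(x[0][0])
    out = []
    for kernels in w:  # one output channel per entry of w
        fmap = [[0.0] * wd for _ in range(h)]
        for i in range(h):
            for j in range(wd):
                s = 0.0
                for c, ker in enumerate(kernels):
                    for di in range(3):
                        for dj in range(3):
                            ii, jj = i + di - 1, j + dj - 1
                            if 0 <= ii < h and 0 <= jj < wd:
                                s += ker[di][dj] * x[c][ii][jj]
                fmap[i][j] = s
        out.append(fmap)
    return out


def hand_set_q_kernels(gamma: float, ch_q: int = 10) -> List[List[Image]]:
    """Hypothetical hand-set q-layer filters over the stacked [reward, value] input.
    Channel a < 8 computes Q_a(s) = R(s) + gamma * V(s + MOVES[a]);
    the remaining channels are all-zero filters (Q = 0). A real VIN learns these weights."""
    kernels = []
    for a in range(ch_q):
        r_k = [[0.0] * 3 for _ in range(3)]
        v_k = [[0.0] * 3 for _ in range(3)]
        if a < len(MOVES):
            dr, dc = MOVES[a]
            r_k[1][1] = 1.0              # reward at the current cell
            v_k[1 + dr][1 + dc] = gamma  # discounted value of the neighbor reached by move a
        kernels.append([r_k, v_k])
    return kernels


def value_iteration(reward: Image, kernels: List[List[Image]], k: int) -> Image:
    """k VI iterations: stack [reward, value], convolve into q channels, take the channel-wise max."""
    h, w = len(reward), len(reward[0])
    value = [[0.0] * w for _ in range(h)]           # V^(0) = 0
    for _ in range(k):
        q = conv3x3_same([reward, value], kernels)  # [ch_q][H][W]
        value = [[max(qa[i][j] for qa in q) for j in range(w)] for i in range(h)]
    return value


size, gamma = 8, 0.9
reward = [[0.0] * size for _ in range(size)]
reward[0][0] = 1.0  # hypothetical goal in the top-left corner
value = value_iteration(reward, hand_set_q_kernels(gamma), k=10)
for row in value:
    print(" ".join(f"{v:5.2f}" for v in row))
```

The printed value image decreases with the king-move (Chebyshev) distance to the goal, as a shortest-path planner on an obstacle-free grid should. The zero spare channels act like an extra action worth 0, so this particular sketch never lets V drop below 0; in a trained VIN those channels carry learned filters instead.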

Contributors

zuoxingdong

Issues

Grid world 28x28 can't run successfully.

I ran the following command in the terminal:

python run.py --datafile data/gridworld_16x16.npz --imsize 16 --lr 0.008 --epochs 30 --k 20 --batch_size 128

but it shows the following error:
Traceback (most recent call last):
  File "/home/cr/PycharmProjects/VIN_TensorFlow/run.py", line 118, in <module>
    trainset = Dataset(args.datafile, mode='train', imsize=args.imsize)
  File "/home/cr/PycharmProjects/VIN_TensorFlow/dataset.py", line 6, in __init__
    data = np.load(filepath).items()[0][1][0]
  File "/usr/local/lib/python2.7/dist-packages/numpy/lib/npyio.py", line 249, in items
    return [(f, self[f]) for f in self.files]
  File "/usr/local/lib/python2.7/dist-packages/numpy/lib/npyio.py", line 235, in __getitem__
    pickle_kwargs=self.pickle_kwargs)
  File "/usr/local/lib/python2.7/dist-packages/numpy/lib/format.py", line 650, in read_array
    array = pickle.load(fp, **pickle_kwargs)
cPickle.UnpicklingError: BINUNICODE pickle has negative byte count

Why does this error appear?
Best regards.

Number of actions

Hi, Dong. I really appreciate your work on VIN.
But I wonder why the number of Q layers, ch_q, is set to 10.
In my view, each position in the map has 9 potential next positions depending on the agent's action, so there should be 9 actions and 9 Q layers.
