
Deep-Q-Learning

Reinforcement learning with PyTorch and OpenAI Gym. This repository contains implementations of the Deep Q-Learning (DQN) algorithm and the Dueling Double Deep Q-Learning (DDDQN) algorithm.

Results

Pong

Breakout

Breakout has multiple levels. The bricks are reset, and the agent continues to collect points in the next level with its remaining lives. This results in a sudden jump in the reward curve.

Install Dependencies

pip install -r requirements.txt

Train

Create folder for checkpoints

mkdir -p Checkpoints/Pong
mkdir -p Checkpoints/Breakout

Set the environment parameter before running.
Run DQN

python deep_q_learning.py

Run DDDQN

python duel_double_deep_q_learning.py

Test

Set MODEL_PATH and environment parameters in play_model.py

python play_model.py

Takeaways

  • DQN is tricky, and it takes numerous attempts and patience to get things working from scratch. For quicker results one can start with OpenAI Baselines; however, building from scratch provides deeper insight into the parameters and their effect on the algorithm's performance.
  • PyTorch specific: the target needs to be detached from the computation graph, or the target Q values need to be computed inside a no_grad scope during model optimization. This ensures that the optimizer does not update the weights of the target network during backpropagation (see the optimization-step sketch after this list).
  • Be aware of tensor data types. Replay memory should store frames as uint8, but they need to be converted to float32 during training. Reward and terminal tensors should also be float32.
  • Gradient clipping and reward clipping are both extremely important for stability. The reward-curve figure in the repository illustrates a failed training run on Breakout, where the reward collapses as soon as the agent earns a very high reward: a sudden large reward leads to unstable gradient updates during backpropagation.

  • Batch normalization is not a good choice for DQN and leads to increased training time.
  • Huber loss is an essential alternative to MSE loss. For simpler problems such as CartPole and Pong, where the planning horizon is short and the focus is on immediate rewards, MSE loss performs fairly well.
  • Parameters such as GAMMA, BATCH_SIZE, LEARNING_RATE, EXPLORATION_FRAMES and MEMORY_BUFFER determine the training time. Small deviations in these parameters should not drastically affect stability.
  • EPSILON scheduling is important during the initial training phase; in the later phase of training, most of the learning happens with EPSILON_END (see the decay-schedule sketch after this list).
  • The Gym Deterministic environments are the ones DeepMind used for their 2013 paper.
  • Model optimization happens once every POLICY_UPDATE_INTERVAL steps, and the target network is updated once every TARGET_UPDATE_INTERVAL steps. In the Deterministic environments each step skips 4 frames, so with POLICY_UPDATE_INTERVAL=4 and TARGET_UPDATE_INTERVAL=100, backpropagation happens once every 16 game frames and a target update once every 400 game frames.
  • The pre-processing can be confusing. I would recommend this article for a clear idea of how the Atari environments work.
  • There are two terminal flags: one is true every time the agent loses a life, and the other is true only at the end of an episode. Replay memory should be filled with the one indicating a lost life (see the environment-step sketch after this list).
  • Random FIRE actions at the start of an episode avoid learning a suboptimal policy (see the reset sketch after this list).
  • Start training only once the memory buffer holds a good number of samples. This also helps avoid learning a suboptimal policy.
  • Don't give up. Deep RL is frustrating to train but it is also an engineering marvel.
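
A minimal sketch of the optimization step described above, assuming hypothetical names policy_net, target_net, and optimizer (not necessarily the identifiers used in deep_q_learning.py). It shows the uint8-to-float32 conversion, the no_grad target computation, Huber loss, and gradient clipping together:

```python
import torch
import torch.nn.functional as F

def optimize_step(policy_net, target_net, optimizer, batch, gamma=0.99):
    # states/next_states: (B, 4, 84, 84) uint8 frames from replay memory;
    # actions: (B,) int64; rewards, dones: (B,) float32
    states, actions, rewards, next_states, dones = batch

    # Frames are stored as uint8 to save memory; convert to float32 for the net
    states = states.float() / 255.0
    next_states = next_states.float() / 255.0

    # Q(s, a) for the actions that were actually taken
    q_values = policy_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)

    # Target Q values are computed under no_grad so backpropagation
    # never touches the target network's weights
    with torch.no_grad():
        next_q = target_net(next_states).max(1)[0]
        targets = rewards + gamma * next_q * (1.0 - dones)

    # Huber (smooth L1) loss is more robust to outlier TD errors than MSE
    loss = F.smooth_l1_loss(q_values, targets)

    optimizer.zero_grad()
    loss.backward()
    # Clip gradients so a single large TD error cannot destabilize the update
    torch.nn.utils.clip_grad_norm_(policy_net.parameters(), max_norm=10.0)
    optimizer.step()
    return loss.item()
```

Norm clipping with max_norm=10.0 is just one choice; element-wise clamping of each gradient to [-1, 1] is another common variant.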
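
The EPSILON schedule mentioned above, as a minimal linear-decay sketch; the constants here are illustrative, not the repository's actual values:

```python
EPSILON_START = 1.0          # fully random actions at the beginning
EPSILON_END = 0.01           # residual exploration for late training
EXPLORATION_FRAMES = 1_000_000

def epsilon_at(step):
    # Decay linearly from EPSILON_START to EPSILON_END over
    # EXPLORATION_FRAMES steps, then hold at EPSILON_END
    fraction = min(step / EXPLORATION_FRAMES, 1.0)
    return EPSILON_START + fraction * (EPSILON_END - EPSILON_START)
```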
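
A sketch of the environment step that produces the two terminal flags and clips the reward. It assumes the classic 4-tuple gym API and the ALE lives counter exposed via info (the exact key, here 'ale.lives', varies across gym/ALE versions):

```python
def play_step(env, action, prev_lives):
    obs, reward, done, info = env.step(action)
    lives = info.get('ale.lives', prev_lives)

    # life_lost is the flag that goes into replay memory;
    # done marks only the true end of the episode
    life_lost = done or lives < prev_lives

    # Clip rewards to [-1, +1] so TD targets stay on the same scale and
    # no single huge reward destabilizes the gradient updates
    clipped_reward = max(-1.0, min(1.0, reward))
    return obs, clipped_reward, done, life_lost, lives
```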
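
And a sketch of the episode reset with random FIRE actions; action index 1 is FIRE in the ALE action set for Breakout and Pong, and the cap of 10 presses is an arbitrary illustrative choice:

```python
import random

FIRE_ACTION = 1  # ALE action meaning 'FIRE' for Breakout/Pong

def reset_with_fire(env, max_fires=10):
    obs = env.reset()
    # A random number of FIRE presses randomizes the initial state so the
    # agent cannot overfit to one deterministic opening sequence
    for _ in range(random.randint(1, max_fires)):
        obs, _, done, _ = env.step(FIRE_ACTION)
        if done:
            obs = env.reset()
    return obs
```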
