deep_q_network's Introduction

Deep Q-Learning with Atari Pong

This project trains a Deep Q-Learner to play Atari Pong.

  • Note that this would normally be a larger task: designing an architecture, choosing hyperparameters, testing, and repeating until you get a model that works. Fortunately, this has been done for us, and a working architecture and set of hyperparameters have already been selected.
  • Use a GPU while training; otherwise, training will take very long (probably days rather than hours). Use a high-RAM session if that option is available.

This program contains two files:

  • Deep_Q_Network.ipynb: This is where the code is written.
  • model_pretrain.pth: A pre-trained model that isn't quite there yet. You'll need to train for about 1M more frames to get this working correctly. 

This program fulfills the following requirements:

  1. Exploitation policy:
    • Recall the exploration vs exploitation tradeoff of reinforcement learning. Below you can see that the exploration component is already taken care of. You will need to implement the exploitation part. Pass the state into the Q-learner and get the action your learner thinks is best.
    • Tip #1: You can pass a state into the learner via self(state), and you can detach the output via .detach().cpu().numpy() so you can work with it.
  2. Compute loss:
    • Recall the loss for deep Q-Learning is given by:
      • (f(state, actions) - (reward + gamma * max(f_target(next_state))))^2
    • Tip #1: You can access the max element of each instance in the batch via:
      • <tensor>.detach().max(1)[0]
    • Tip #2: Remember that you do not want to include future states if the “done” flag is true. You can zero out the future term for finished transitions by multiplying the output of the target model by (1-done).
    • Tip #3: PyTorch has an MSE loss function already implemented. Simply use:
      • torch.nn.MSELoss(reduction='sum')(output, target)
    • Tip #4: You can quickly index the actions via the <tensor>.gather() command.
  3. Sample from the replay buffer:
    • In order to have diverse training instances for each update, sample batch_size frames and return a tuple of state, action, reward, next_state, done.
    • Tip #1: random.sample from the random module is useful here, as is the zip command.
  4. Periodically save the model while training:
    • The model is saved automatically after training, but it’s a good idea to also save intermediate results in case your session crashes, you lose internet, etc. Don’t do this every frame; save at a fixed interval instead.
    • Tip #1: There is code for saving the model at the end of the training script.
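
For requirement 1 (exploitation policy), the following is a minimal sketch of an epsilon-greedy act method. TinyQNet, its layer sizes, and the act signature are hypothetical stand-ins and are not taken from the notebook; the real network operates on Atari frames. The sketch takes the argmax directly on the detached tensor, though converting to NumPy first as in Tip #1 would work just as well.

```python
import random

import torch
import torch.nn as nn


class TinyQNet(nn.Module):
    """Hypothetical stand-in for the notebook's Q-network (not the real architecture)."""

    def __init__(self, state_dim: int, num_actions: int):
        super().__init__()
        self.num_actions = num_actions
        self.layers = nn.Sequential(
            nn.Linear(state_dim, 32), nn.ReLU(), nn.Linear(32, num_actions)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layers(x)

    def act(self, state, epsilon: float) -> int:
        if random.random() < epsilon:
            # Exploration: pick a uniformly random action.
            return random.randrange(self.num_actions)
        # Exploitation: add a batch dimension, run the network via self(...),
        # detach the Q-values, and return the index of the largest one.
        state_t = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
        q_values = self(state_t).detach()
        return int(q_values.argmax(dim=1).item())
```

With epsilon=0.0 the method always exploits, which makes it easy to check against a direct forward pass.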
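
For requirement 2 (compute loss), the tips above can be combined roughly as follows. This is a sketch under assumptions: the function name compute_td_loss and the batch layout (float tensors for reward and done, an int64 tensor for action) are illustrative and may differ from the notebook.

```python
import torch
import torch.nn as nn


def compute_td_loss(model, target_model, batch, gamma: float) -> torch.Tensor:
    """Summed squared TD error for one batch (hypothetical helper)."""
    state, action, reward, next_state, done = batch

    # Q(s, a) for the actions actually taken, picked out with gather (Tip #4).
    q_values = model(state)
    q_taken = q_values.gather(1, action.unsqueeze(1)).squeeze(1)

    # max over actions of the target network's output, detached so that
    # no gradient flows into the target network (Tip #1).
    next_q_max = target_model(next_state).detach().max(1)[0]

    # Zero out the future term where the episode ended (Tip #2).
    target = reward + gamma * next_q_max * (1 - done)

    # Built-in MSE loss with sum reduction (Tip #3).
    return nn.MSELoss(reduction="sum")(q_taken, target)
```

Calling .backward() on the returned loss updates gradients only for model, since the target network's output is detached.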
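
For requirement 3 (sampling from the replay buffer), a minimal sketch using random.sample and zip might look like this. The ReplayBuffer class, its capacity argument, and the push method are assumptions for illustration; the notebook's buffer may store or convert transitions differently.

```python
import random
from collections import deque


class ReplayBuffer:
    """Hypothetical fixed-size buffer of (state, action, reward, next_state, done)."""

    def __init__(self, capacity: int):
        # A deque with maxlen drops the oldest transition once it is full.
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size: int):
        # Draw batch_size distinct transitions uniformly at random.
        batch = random.sample(self.buffer, batch_size)
        # zip(*batch) turns a list of 5-tuples into 5 tuples of length batch_size.
        state, action, reward, next_state, done = zip(*batch)
        return state, action, reward, next_state, done

    def __len__(self):
        return len(self.buffer)
```

Each returned element is a tuple of length batch_size, and the i-th entries of all five belong to the same transition.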
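
For requirement 4 (periodic saving), one possible approach is to save the weights at a fixed frame interval inside the training loop. The helper name maybe_save, the save_every interval, and the checkpoint path are hypothetical; the notebook's own saving code at the end of the training script can be reused in the same way.

```python
import torch
import torch.nn as nn


def maybe_save(model: nn.Module, frame_idx: int, save_every: int, path: str) -> bool:
    """Save the model weights every `save_every` frames.

    Returns True if a checkpoint was written on this call.
    """
    if frame_idx > 0 and frame_idx % save_every == 0:
        # Saving only the state_dict keeps the checkpoint independent of the class code.
        torch.save(model.state_dict(), path)
        return True
    return False
```

Inside the training loop this would be called once per frame, for example maybe_save(model, frame_idx, 10000, "checkpoint.pth"), where both the interval and the file name are placeholders to adjust.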

Contributors

  • tinanemati
