muzero-pytorch

PyTorch implementation of MuZero ("Mastering Atari, Go, Chess and Shogi by Planning with a Learned Model"), based on the pseudocode provided by the authors.

Note: This implementation has only been tested on CartPole-v1 and would require modifications (in the config folder) for other environments.

Installation

  • Python 3.8, 3.9
  • cd muzero-pytorch
    pip install -r requirements.txt

Usage:

  • Train: python main.py --env CartPole-v1 --case classic_control --opr train --force
  • Test: python main.py --env CartPole-v1 --case classic_control --opr test
  • Visualize results:
    • tensorboard --logdir=<result_dir_path>
    • If --use_wandb was passed, results can also be viewed in wandb.
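
The train and test commands above can also be assembled from a script. The following is a minimal sketch, not part of the repository: `build_command` is a hypothetical helper, and it only uses flags listed in this README (`--env`, `--case`, `--opr`, `--seed`, `--force`). It assumes the command is run from the repository root, where `main.py` lives.

```python
import sys

VALID_CASES = {"atari", "classic_control", "box2d"}
VALID_OPRS = {"train", "test"}


def build_command(env, case, opr, seed=None, force=False):
    """Return the argument list for one main.py invocation (hypothetical helper)."""
    if case not in VALID_CASES:
        raise ValueError(f"unknown case: {case!r}")
    if opr not in VALID_OPRS:
        raise ValueError(f"unknown operation: {opr!r}")
    # sys.executable keeps the call inside the current Python environment.
    cmd = [sys.executable, "main.py", "--env", env, "--case", case, "--opr", opr]
    if seed is not None:
        cmd += ["--seed", str(seed)]
    if force:
        cmd.append("--force")  # overrides past results, per the argument list
    return cmd


# To launch the command from the repository root:
#   import subprocess
#   subprocess.run(build_command("CartPole-v1", "classic_control", "train", force=True), check=True)
```

Passing the arguments as a list avoids shell quoting issues, and `check=True` makes a failed run raise instead of passing silently.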

| Required Arguments | Description |
| --- | --- |
| --env | Name of the environment |
| --case {atari,classic_control,box2d} | Switches between domains (default: None) |
| --opr {train,test} | Selects the operation to perform |

| Optional Arguments | Description |
| --- | --- |
| --value_loss_coeff | Scale for the value loss (default: None) |
| --revisit_policy_search_rate | Rate at which the target policy is re-estimated (default: None) (only valid if --use_target_model is enabled) |
| --use_priority | Uses priority for data sampling in the replay buffer. The priority for new data is calculated from the loss (default: False) |
| --use_max_priority | Forces max priority assignment for new incoming data in the replay buffer (only valid if --use_priority is enabled) (default: False) |
| --use_target_model | Uses a target model for bootstrap value estimation (default: False) |
| --result_dir | Directory path to store results (default: current working directory) |
| --no_cuda | Disables CUDA usage (default: False) |
| --no_mps | Disables MPS (Metal Performance Shaders) usage (default: False) |
| --debug | If enabled, logs additional values (default: False) |
| --render | Renders the environment (default: False) |
| --force | Overrides past results (default: False) |
| --seed | Random seed (default: 0) |
| --num_actors | Number of actors running concurrently (default: 32) |
| --test_episodes | Number of evaluation episodes (default: 10) |
| --use_wandb | Logs console and tensorboard data to wandb (default: False) |
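
To illustrate how --use_priority and --use_max_priority interact, here is a toy sketch of priority-weighted sampling. It is not the repository's replay buffer: `TinyPriorityBuffer` and its methods are hypothetical names, the priority is assumed to be the absolute loss plus a small constant, and sampling is assumed to be proportional to priority.

```python
import random


class TinyPriorityBuffer:
    """Toy buffer for illustration only; not the repository's implementation."""

    def __init__(self, use_priority=False, use_max_priority=False):
        self.use_priority = use_priority
        self.use_max_priority = use_max_priority
        self.items = []
        self.priorities = []

    def add(self, item, loss=None):
        if not self.use_priority:
            priority = 1.0  # equal weights give uniform sampling
        elif self.use_max_priority or loss is None:
            # New data gets the largest priority seen so far (1.0 if empty).
            priority = max(self.priorities, default=1.0)
        else:
            # Assumption: priority derived from the loss magnitude.
            priority = abs(loss) + 1e-6
        self.items.append(item)
        self.priorities.append(priority)

    def sample(self, k, rng=random):
        # Draw k items with probability proportional to their priority.
        return rng.choices(self.items, weights=self.priorities, k=k)
```

In this sketch, --use_max_priority has no effect unless --use_priority is also set, which mirrors the "only valid if" note in the argument description.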

Note: A default of None means the value is loaded from the corresponding config.
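
The fallback described in this note can be sketched as follows. This is an illustration under assumptions, not the repository's code: `make_parser`, `resolve`, and `CONFIG_DEFAULTS` are hypothetical names, and the numbers in `CONFIG_DEFAULTS` are placeholders rather than the values stored in the config folder.

```python
import argparse

# Placeholder values standing in for a per-case config; not the real defaults.
CONFIG_DEFAULTS = {"value_loss_coeff": 1.0, "revisit_policy_search_rate": 0.5}


def make_parser():
    parser = argparse.ArgumentParser()
    # default=None marks "not given on the command line".
    parser.add_argument("--value_loss_coeff", type=float, default=None)
    parser.add_argument("--revisit_policy_search_rate", type=float, default=None)
    return parser


def resolve(args, config_defaults):
    """Start from the config values and override with any CLI value that was set."""
    resolved = dict(config_defaults)
    for name, value in vars(args).items():
        if value is not None:
            resolved[name] = value
    return resolved
```

For example, parsing `--value_loss_coeff 0.25` alone would override that one entry and leave `revisit_policy_search_rate` at its config value.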

Training

CartPole-v1

  • Curves represent model evaluation over 5 episodes at every 100-step training interval.
  • Each curve is the mean score over 5 runs (seeds: [0, 100, 200, 300, 400]).
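
The per-seed averaging described above can be sketched as a point-by-point mean over runs. `mean_curve` is a hypothetical helper, and the scores in the usage comment are made-up placeholders that only show the shape of the data; they are not results from this project.

```python
from statistics import mean


def mean_curve(runs):
    """Average evaluation curves point by point.

    `runs` maps a seed to its list of evaluation scores, one score per
    evaluation point (assumed here to be every 100 training steps).
    """
    lengths = {len(scores) for scores in runs.values()}
    if len(lengths) != 1:
        raise ValueError("all runs must have the same number of evaluation points")
    return [mean(point) for point in zip(*runs.values())]


# Placeholder data, two seeds and two evaluation points:
#   mean_curve({0: [10.0, 20.0], 100: [14.0, 22.0]})
```

Raising on mismatched lengths avoids silently truncating longer runs, which `zip` would otherwise do.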

muzero-pytorch's People

Contributors

bampt, dependabot[bot], grifball, koulanurag, leovandriel
