charanhu / reinforcement_learning

Reinforcement Learning with Stable Baselines3: Train and evaluate a CartPole agent using Stable Baselines3 library. Includes code for training, saving, and testing the model, along with a GIF visualization of the trained agent.

Reinforcement Learning with Stable Baselines3

This is a simple example of using Stable Baselines3, a library for reinforcement learning, to train an agent on the CartPole-v0 environment.

(GIF: CartPole agent)

Dependencies

Make sure you have the following dependencies installed:

  • stable-baselines3
  • gym
  • pyglet

You can install them using pip:

pip install "stable-baselines3[extra]"
pip install pyglet==1.5.27

Load Environment

First, we import the necessary dependencies and create an instance of the CartPole-v0 environment:

import os
import gym
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.evaluation import evaluate_policy

environment_name = 'CartPole-v0'
env = gym.make(environment_name)

Training

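To train the agent, we wrap the environment in a DummyVecEnv, initialize the PPO algorithm with it, and call the learn method to start training:

log_path = os.path.join('Training', 'Logs')
# Build the environment inside the factory so the lambda does not refer to
# the name `env`, which is rebound to the vectorized wrapper on this line.
env = DummyVecEnv([lambda: gym.make(environment_name)])
model = PPO('MlpPolicy', env, verbose=1, tensorboard_log=log_path)
model.learn(total_timesteps=20000)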

Training Logs

------------------------------------------
| time/                   |              |
|    fps                  | 581          |
|    iterations           | 10           |
|    time_elapsed         | 35           |
|    total_timesteps      | 20480        |
| train/                  |              |
|    approx_kl            | 0.0065331194 |
|    clip_fraction        | 0.0254       |
|    clip_range           | 0.2          |
|    entropy_loss         | -0.57        |
|    explained_variance   | 0.651        |
|    learning_rate        | 0.0003       |
|    loss                 | 7.36         |
|    n_updates            | 90           |
|    policy_gradient_loss | -0.0054      |
|    value_loss           | 23.8         |
------------------------------------------
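
The log reports 20480 total timesteps although learn was called with 20000. A likely explanation is that training stops only at a rollout boundary. The sketch below reproduces the arithmetic, assuming PPO's default rollout length of 2048 steps and a single environment (the example above does not set n_steps, so both values are assumptions):

```python
import math


def timesteps_actually_run(requested, n_steps=2048, n_envs=1):
    """Round a timestep budget up to a whole number of rollouts.

    n_steps and n_envs are assumed defaults, not values taken from the log.
    """
    rollout_size = n_steps * n_envs
    iterations = math.ceil(requested / rollout_size)
    return iterations, iterations * rollout_size


iterations, total = timesteps_actually_run(20000)
print(iterations, total)  # prints: 10 20480
```

Both numbers match the iterations and total_timesteps rows in the log above.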

Save and Load the Model

You can save the trained model to a file and load it later for evaluation or further training:

PPO_Path = os.path.join('Training', 'Saved Models', 'PPO_Model_Cartpole')
model.save(PPO_Path)
model = PPO.load(PPO_Path, env=env)
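
To confirm that the save worked, you can look for the file on disk. Stable Baselines3 stores the model as a zip archive, and this sketch assumes a .zip suffix is appended when the path has no extension. saved_model_file is a hypothetical helper, not part of the library:

```python
import os


def saved_model_file(path):
    """Return the file name expected on disk for a model saved at `path`.

    Assumption: '.zip' is added only when `path` has no extension.
    """
    _, ext = os.path.splitext(path)
    return path if ext else path + '.zip'


PPO_Path = os.path.join('Training', 'Saved Models', 'PPO_Model_Cartpole')
model_file = saved_model_file(PPO_Path)
# True only after model.save(PPO_Path) has been run in this directory.
print(model_file, os.path.exists(model_file))
```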

Evaluation

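To evaluate the trained agent, use the evaluate_policy function, which returns the mean and standard deviation of the total reward per episode:

mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=10, render=True)
print(mean_reward, std_reward)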
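
The summary that evaluate_policy reports, a mean and a standard deviation over episodes, can be reproduced by hand on a small list of scores. The scores below are made up for illustration, and using the population standard deviation is an assumption about how the spread is computed:

```python
from statistics import fmean, pstdev


def summarize_returns(episode_returns):
    """Return (mean, population standard deviation) of episode totals."""
    return fmean(episode_returns), pstdev(episode_returns)


example_returns = [200.0, 200.0, 180.0, 200.0]  # illustrative values only
mean_reward, std_reward = summarize_returns(example_returns)
print(f'mean={mean_reward:.1f} std={std_reward:.2f}')  # prints: mean=195.0 std=8.66
```

A small standard deviation next to a high mean suggests the agent performs consistently rather than succeeding in only some episodes.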

Testing the Model

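You can test the trained model by running episodes and observing its behavior. Because env is vectorized, step returns one-element arrays for the reward and the done flag:

episodes = 5
for episode in range(1, episodes + 1):
    obs = env.reset()
    done = False
    score = 0

    while not done:
        env.render()
        action, _ = model.predict(obs)
        # The vectorized env returns arrays; index 0 is the single environment.
        obs, rewards, dones, info = env.step(action)
        score += rewards[0]
        done = dones[0]

    print('Episode: {} Score: {}'.format(episode, score))

env.close()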
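
The same loop can be wrapped in a function so that scores are collected rather than only printed. This is a sketch, not code from the repository: run_episodes is a hypothetical helper, and StubVecEnv and StubModel are stand-ins so the example runs without gym installed. It assumes the vectorized convention in which step returns one-element sequences for the reward and the done flag:

```python
def run_episodes(model, env, episodes=5):
    """Play `episodes` episodes and return the total reward of each."""
    scores = []
    for _ in range(episodes):
        obs = env.reset()
        done, score = False, 0.0
        while not done:
            action, _ = model.predict(obs)
            obs, rewards, dones, info = env.step(action)
            score += float(rewards[0])  # index 0: the single wrapped env
            done = bool(dones[0])
        scores.append(score)
    return scores


class StubVecEnv:
    """Hypothetical stand-in: every episode lasts exactly `length` steps."""

    def __init__(self, length=3):
        self.length, self.t = length, 0

    def reset(self):
        self.t = 0
        return [[0.0] * 4]

    def step(self, action):
        self.t += 1
        done = self.t >= self.length
        if done:
            self.t = 0  # mimic the automatic reset of a vectorized env
        return [[0.0] * 4], [1.0], [done], [{}]


class StubModel:
    """Hypothetical stand-in that always picks action 0."""

    def predict(self, obs):
        return [0], None


print(run_episodes(StubModel(), StubVecEnv(), episodes=2))  # prints: [3.0, 3.0]
```

With the real objects, the call would be run_episodes(model, env), with env.render() added inside the loop if you want to watch the agent.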

Viewing Training Logs

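You can visualize the training progress using TensorBoard. First, specify the log directory and start TensorBoard (the leading ! runs the command from a notebook cell; in a terminal, drop it and pass the path directly):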

training_log_path = os.path.join(log_path, 'PPO_1')
!tensorboard --logdir={training_log_path}

Then, open localhost:6006 in your browser to view the training logs.
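
The PPO_1 directory name assumes this is the first run logged to that folder; later runs are numbered PPO_2, PPO_3, and so on. A minimal sketch for finding the most recent run follows. latest_run_dir is a hypothetical helper, and the default prefix 'PPO' is an assumption based on the directory name used above:

```python
import os
import re


def latest_run_dir(log_path, prefix='PPO'):
    """Return the '<prefix>_<n>' subdirectory with the largest n, or None."""
    pattern = re.compile(rf'^{re.escape(prefix)}_(\d+)$')
    runs = []
    for name in os.listdir(log_path):
        match = pattern.match(name)
        if match and os.path.isdir(os.path.join(log_path, name)):
            runs.append((int(match.group(1)), name))
    if not runs:
        return None
    # Compare run numbers as integers so that PPO_10 sorts after PPO_2.
    return os.path.join(log_path, max(runs)[1])


log_path = os.path.join('Training', 'Logs')
if os.path.isdir(log_path):
    print(latest_run_dir(log_path))
```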

Conclusion

This example covers the full workflow with Stable Baselines3 on the CartPole-v0 environment: creating the environment, training a PPO agent, saving and reloading the model, evaluating it, and viewing the training logs in TensorBoard.
