Causal-MBRL

Code style: black

cmrl (short for Causal-MBRL) is a toolbox for facilitating the development of causal model-based reinforcement learning (RL) algorithms. It uses Stable-Baselines3 as its model-free engine and allows flexible use of causal models.

cmrl is inspired by MBRL-Lib. Unlike MBRL-Lib, cmrl focuses on the causal characteristics of the model: it supports learning different types of causal models and can run any model-free algorithm on these models. By default, it uses Emei, a re-encapsulation of OpenAI Gym, as its reinforcement learning environment.

Main Features

Thanks to the decoupling between the environment model and the model-free algorithm, cmrl supports all on-policy and off-policy reinforcement learning algorithms in Stable-Baselines3 and SB3-Contrib. cmrl is also compatible with a number of Stable-Baselines3 utilities (e.g. Logger, Replay-buffer, Callback).

Although cmrl supports many model-free algorithms, its focus is on learning causal models. cmrl uses VecFakeEnv to build a fake environment and conducts online reinforcement learning on it. Each VecFakeEnv corresponds to a dynamics model, which is composed of three parts: Transition, Reward-Mech (short for reward mechanism) and Termination-Mech (short for termination mechanism), as shown in the class diagram:

classDiagram

    BaseDynamics o-- BaseTransition
    BaseDynamics o-- BaseRewardMech
    BaseDynamics o-- BaseTerminationMech

    class BaseDynamics {
        + transition: BaseTransition
        + reward_mech: BaseRewardMech
        + termination_mech: BaseTerminationMech
        + transition_graph: BaseGraph
        + reward_mech_graph: BaseGraph
        + termination_mech_graph: BaseGraph
        + learn()
        + save()
        + load()
    }

    class BaseTransition {
        + obs_size: int
        + action_size: int
        + forward()
    }

    class BaseRewardMech {
        + obs_size: int
        + action_size: int
        + forward()
    }

    class BaseTerminationMech {
        + obs_size: int
        + action_size: int
        + forward()
    }
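To make the composition concrete, here is a minimal, dependency-free sketch of a dynamics model driving a batched fake environment. Everything in it is hypothetical and not cmrl's API: `ToyDynamics` and `ToyVecFakeEnv` are illustrative stand-ins for BaseDynamics and VecFakeEnv, the transition is a plain linear map, and `transition_graph` is assumed to act as a 0/1 parent mask over the concatenated observation and action. The auto-reset that stores the final state under `terminal_observation` mirrors the Stable-Baselines3 VecEnv convention.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Sequence, Tuple

Vector = List[float]


@dataclass
class ToyDynamics:
    """Hypothetical stand-in for BaseDynamics: transition + reward-mech + termination-mech."""

    # Assumption: transition_graph[i][j] == 1 means input j (obs entries, then action
    # entries) is a causal parent of next-observation dimension i.
    transition_graph: List[List[int]]
    transition_weights: List[List[float]]
    reward_mech: Callable[[Vector, Vector, Vector], float]
    termination_mech: Callable[[Vector, Vector, Vector], bool]

    def transition(self, obs: Vector, action: Vector) -> Vector:
        inputs = list(obs) + list(action)
        # each output dimension is a linear function of its causal parents only
        return [
            sum(m * w * x for m, w, x in zip(mask_row, weight_row, inputs))
            for mask_row, weight_row in zip(self.transition_graph, self.transition_weights)
        ]

    def step(self, obs: Vector, action: Vector) -> Tuple[Vector, float, bool]:
        next_obs = self.transition(obs, action)
        reward = self.reward_mech(obs, action, next_obs)
        terminated = self.termination_mech(obs, action, next_obs)
        return next_obs, reward, terminated


class ToyVecFakeEnv:
    """Hypothetical batched fake environment driven by a learned dynamics model."""

    def __init__(self, dynamics: ToyDynamics, init_obs: Sequence[Vector]):
        self.dynamics = dynamics
        self.init_obs = [list(o) for o in init_obs]
        self.obs = self.reset()

    def reset(self) -> List[Vector]:
        self.obs = [list(o) for o in self.init_obs]
        return self.obs

    def step(self, actions: Sequence[Vector]):
        obs_out: List[Vector] = []
        rewards: List[float] = []
        dones: List[bool] = []
        infos: List[Dict[str, Vector]] = []
        for i, (obs, action) in enumerate(zip(self.obs, actions)):
            next_obs, reward, done = self.dynamics.step(obs, action)
            info: Dict[str, Vector] = {}
            if done:
                # auto-reset: keep the final state in info, return the reset state
                info["terminal_observation"] = next_obs
                next_obs = list(self.init_obs[i])
            obs_out.append(next_obs)
            rewards.append(reward)
            dones.append(done)
            infos.append(info)
        self.obs = obs_out
        return obs_out, rewards, dones, infos


# 2-D observation, 1-D action; next_obs[0] <- (obs[0], action), next_obs[1] <- obs[1]
dynamics = ToyDynamics(
    transition_graph=[[1, 0, 1], [0, 1, 0]],
    transition_weights=[[1.0, 5.0, 0.5], [0.0, 0.9, 0.0]],  # the 5.0 is masked out
    reward_mech=lambda obs, action, next_obs: -abs(next_obs[0]),
    termination_mech=lambda obs, action, next_obs: abs(next_obs[0]) > 2.0,
)
env = ToyVecFakeEnv(dynamics, init_obs=[[0.0, 1.0], [1.8, 1.0]])
obs, rewards, dones, infos = env.step([[1.0], [1.0]])
print(obs, rewards, dones)
```

In this sketch the consumer only needs `reset()` and `step()`; a real VecFakeEnv would presumably implement the full Stable-Baselines3 VecEnv interface, which is what lets any SB3 algorithm train on the learned model as if it were a real environment.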

cmrl encapsulates the neural networks commonly used in causal model-based RL, such as PlainEnsembleMLP and ExternalMaskEnsembleMLP. Each mechanism in a dynamics model should subclass both one of these MLPs and its corresponding base class. For example, here is the class diagram of PlainTransition:

classDiagram

    EnsembleMLP <|--  PlainEnsembleMLP
    BaseTransition  <|-- PlainTransition
    PlainEnsembleMLP <|-- PlainTransition

    class PlainTransition {
        + obs_size: int
        + action_size: int
        + forward()
    }

    class BaseTransition {
        + obs_size: int
        + action_size: int
        + forward()
    }

    class PlainEnsembleMLP {
        + ensemble_num: int
        + elite_num: int
        + save()
        + load()
    }

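The double inheritance in the diagram can be sketched in plain Python. This is a hypothetical illustration, not cmrl's code: `ToyEnsembleMLP` uses single linear layers instead of real MLPs (cmrl presumably builds its networks in PyTorch), and choosing elites by lowest validation loss is an assumption borrowed from common ensemble-model practice. `ToyPlainTransition` takes its network and `ensemble_num`/`elite_num` from the ensemble parent and its `obs_size`/`action_size`/`forward()` contract from the transition parent.

```python
import random
from abc import ABC, abstractmethod
from typing import List


class ToyEnsembleMLP:
    """Hypothetical ensemble: `ensemble_num` linear members, `elite_num` of them used as elites."""

    def __init__(self, in_size: int, out_size: int, ensemble_num: int = 5, elite_num: int = 3, seed: int = 0):
        rng = random.Random(seed)
        self.ensemble_num = ensemble_num
        self.elite_num = elite_num
        # one weight matrix (out_size x in_size) per ensemble member
        self.weights = [
            [[rng.gauss(0.0, 0.1) for _ in range(in_size)] for _ in range(out_size)]
            for _ in range(ensemble_num)
        ]
        self.elite_indices = list(range(elite_num))

    def member_forward(self, k: int, x: List[float]) -> List[float]:
        return [sum(w * xi for w, xi in zip(row, x)) for row in self.weights[k]]

    def set_elites(self, validation_losses: List[float]) -> None:
        # keep the members with the lowest validation loss as elites
        order = sorted(range(self.ensemble_num), key=lambda k: validation_losses[k])
        self.elite_indices = order[: self.elite_num]


class ToyBaseTransition(ABC):
    """Hypothetical interface: maps (obs, action) to a next-observation prediction."""

    def __init__(self, obs_size: int, action_size: int):
        self.obs_size = obs_size
        self.action_size = action_size

    @abstractmethod
    def forward(self, obs: List[float], action: List[float]) -> List[float]:
        ...


class ToyPlainTransition(ToyEnsembleMLP, ToyBaseTransition):
    """Network from the ensemble parent, interface from the transition parent."""

    def __init__(self, obs_size: int, action_size: int, ensemble_num: int = 5, elite_num: int = 3):
        ToyEnsembleMLP.__init__(self, obs_size + action_size, obs_size, ensemble_num, elite_num)
        ToyBaseTransition.__init__(self, obs_size, action_size)

    def forward(self, obs: List[float], action: List[float]) -> List[float]:
        x = list(obs) + list(action)
        preds = [self.member_forward(k, x) for k in self.elite_indices]
        # average the elite members' predictions dimension by dimension
        return [sum(vals) / len(vals) for vals in zip(*preds)]


transition = ToyPlainTransition(obs_size=3, action_size=1)
transition.set_elites([0.4, 0.1, 0.3, 0.2, 0.5])
next_obs = transition.forward([0.1, 0.2, 0.3], [1.0])
print(transition.elite_indices, next_obs)
```

Keeping the network and the interface in separate parents means a different network (for example a masked ensemble) could be swapped in without touching the transition interface.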

Installation

Install by cloning from GitHub:

# clone the repository
git clone https://github.com/FrankTianTT/causal-mbrl.git
cd causal-mbrl
# create conda env
conda create -n cmrl python=3.8
conda activate cmrl
# install torch
conda install pytorch -c pytorch
# install cmrl and its dependent packages
pip install -e .

For PyTorch, use the command that matches your platform:

# for macOS
conda install pytorch -c pytorch
# for Linux
conda install pytorch pytorch-cuda=11.6 -c pytorch -c nvidia

For KCIT and RCIT (kernel-based and randomized conditional independence tests), install R and the RCIT package:

conda install -c conda-forge r-base
conda install -c conda-forge r-devtools
# start an R session
R
# Install the R libraries RCIT depends on
install.packages("devtools")
install.packages("MASS")
install.packages("momentchi2")

# Install RCIT from GitHub
library(devtools)
install_github("ericstrobl/RCIT")
library(RCIT)

# Test RCIT
RCIT(rnorm(1000), rnorm(1000), rnorm(1000))
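KCIT and RCIT test whether two variables are independent given a third; presumably cmrl uses such tests to decide which edges belong in its causal graphs. As a rough illustration of what a conditional independence test returns, here is a much simpler, swapped-in technique in plain Python: the linear partial-correlation (Fisher-z) test. It is not KCIT or RCIT, which are kernel-based and also detect nonlinear dependence, and `fisher_z_ci_test` is a hypothetical helper.

```python
import math
import random
from typing import List


def _residuals(y: List[float], z: List[float]) -> List[float]:
    """Residuals of an ordinary least-squares fit y ~ a + b * z."""
    n = len(y)
    mean_y, mean_z = sum(y) / n, sum(z) / n
    cov = sum((zi - mean_z) * (yi - mean_y) for zi, yi in zip(z, y))
    var = sum((zi - mean_z) ** 2 for zi in z)
    b = cov / var
    a = mean_y - b * mean_z
    return [yi - (a + b * zi) for yi, zi in zip(y, z)]


def _corr(u: List[float], v: List[float]) -> float:
    n = len(u)
    mu, mv = sum(u) / n, sum(v) / n
    num = sum((a - mu) * (b - mv) for a, b in zip(u, v))
    den = math.sqrt(sum((a - mu) ** 2 for a in u) * sum((b - mv) ** 2 for b in v))
    return num / den


def fisher_z_ci_test(x: List[float], y: List[float], z: List[float]) -> float:
    """p-value for H0: x independent of y given one variable z (linear-Gaussian assumption)."""
    r = _corr(_residuals(x, z), _residuals(y, z))  # partial correlation of x, y given z
    r = max(min(r, 0.999999), -0.999999)  # keep atanh finite
    stat = math.sqrt(len(x) - 4) * math.atanh(r)  # sqrt(n - |Z| - 3) with |Z| = 1
    # two-sided p-value under a standard normal null
    return math.erfc(abs(stat) / math.sqrt(2.0))


rng = random.Random(0)
n = 1000
z = [rng.gauss(0.0, 1.0) for _ in range(n)]
x = [zi + rng.gauss(0.0, 1.0) for zi in z]
y_indep = [zi + rng.gauss(0.0, 1.0) for zi in z]  # x and y only share the cause z
y_dep = [xi + rng.gauss(0.0, 1.0) for xi in x]  # y depends on x directly
p_indep = fisher_z_ci_test(x, y_indep, z)
p_dep = fisher_z_ci_test(x, y_dep, z)
print(f"{p_indep=:.3f} {p_dep=:.2e}")
```

A small p-value rejects conditional independence: the direct x to y dependence gives a p-value near zero, while a large p-value is expected for the pair that only shares the cause z.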

Install using pip

Coming soon.

Usage

python -m cmrl.examples.main

Contributing

See CONTRIBUTING for details.

Contributors

franktiantt, wz139704646
