d-tiapkin / gflownet-rl

Repository for "Generative Flow Networks as Entropy-Regularized RL" (AISTATS-2024, Oral)

Home Page: https://arxiv.org/abs/2310.12934

License: MIT License

Python 100.00%
deep-learning gflownet pytorch reinforcement-learning

gflownet-rl's Introduction

Generative Flow Networks as Entropy-Regularized RL

Official code for the paper Generative Flow Networks as Entropy-Regularized RL.

Daniil Tiapkin*, Nikita Morozov*, Alexey Naumov, Dmitry Vetrov.

Installation

  • Create conda environment:
conda create -n gflownet-rl python=3.10
conda activate gflownet-rl
  • Install PyTorch with CUDA. For our experiments we used the following versions:
conda install pytorch==2.0.0 torchvision==0.15.0 pytorch-cuda=11.8 -c pytorch -c nvidia

You can replace pytorch-cuda=11.8 with pytorch-cuda=XX.X to match your CUDA version.

  • Install core dependencies:
pip install -r requirements.txt

  • (Optional) Install dependencies for molecule experiments:

pip install -r requirements_mols.txt

You can adapt requirements_mols.txt to your CUDA version by replacing cu118 with cuXXX.

Hypergrids

The code for this part relies heavily on the torchgfn library (https://github.com/GFNOrg/torchgfn).

Paths to the configuration files (which use the ml-collections library):

  • General configuration: hypergrid/experiments/config/general.py
  • Algorithm: hypergrid/experiments/config/algo.py
  • Environment: hypergrid/experiments/config/hypergrid.py

List of available algorithms:

  • Baselines: db, tb, subtb from the torchgfn library;
  • Soft RL algorithms: soft_dqn, munchausen_dqn, sac.

Example: running the soft_dqn algorithm with seed 3 on the environment with height=20, ndim=4, and standard rewards:

    python run_hypergrid_exp.py --general experiments/config/general.py:3 --env experiments/config/hypergrid.py:standard --algo experiments/config/algo.py:soft_dqn --env.height 20 --env.ndim 4

To activate the learnable backward policy for this setting:

    python run_hypergrid_exp.py --general experiments/config/general.py:3 --env experiments/config/hypergrid.py:standard --algo experiments/config/algo.py:soft_dqn --env.height 20 --env.ndim 4 --algo.tied True --algo.uniform_pb False
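
The two commands above differ only in the trailing backward-policy flags, so runs over several seeds can be generated programmatically. The following is a minimal sketch, not part of the repository: the helper name `hypergrid_command` and its parameters are hypothetical, and it assumes that the value after `general.py:` is the seed, as the example above suggests. Only flags that appear in the commands above are used.

```python
import shlex


def hypergrid_command(seed, algo, height, ndim, reward="standard", learn_pb=False):
    """Build the argument list for one run_hypergrid_exp.py invocation.

    Hypothetical helper: it only reassembles the flags shown in the README.
    """
    cmd = [
        "python", "run_hypergrid_exp.py",
        "--general", f"experiments/config/general.py:{seed}",  # assumed: value after ':' is the seed
        "--env", f"experiments/config/hypergrid.py:{reward}",
        "--algo", f"experiments/config/algo.py:{algo}",
        "--env.height", str(height),
        "--env.ndim", str(ndim),
    ]
    if learn_pb:
        # Flags from the README example that enables the learnable backward policy.
        cmd += ["--algo.tied", "True", "--algo.uniform_pb", "False"]
    return cmd


if __name__ == "__main__":
    # Print one command per seed; run from the hypergrid directory or pass to subprocess.run.
    for seed in (1, 2, 3):
        print(shlex.join(hypergrid_command(seed, "soft_dqn", height=20, ndim=4)))
```

Returning a list rather than a string keeps the result directly usable with `subprocess.run(cmd)` without shell quoting issues.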

Molecules

The presented experiments actively reuse the existing codebase for molecule generation experiments with GFlowNets (https://github.com/GFNOrg/gflownet/tree/subtb/mols).

Additional requirements for molecule experiments:

  • pandas, rdkit, torch_geometric, h5py, ray, hydra (all installable via requirements_mols.txt)

Paths to the MunchausenDQN configuration files (which use the hydra library):

  • General configuration: mols/configs/soft_dqn.yaml
  • Algorithm: mols/configs/algorithm/soft_dqn.yaml
  • Environment: mols/configs/environment/block_mol.yaml

To run MunchausenDQN with the configurations listed above, use:

    python soft_dqn.py

To reproduce the baselines, run gflownet.py with the required parameters; see the original repository https://github.com/GFNOrg/gflownet for additional details.

Bit sequences

Examples of running TB, DB and SubTB baselines for word length k=8:

python bitseq/run.py --objective tb --k 8 --learning_rate 0.002
python bitseq/run.py --objective db --k 8 --learning_rate 0.002
python bitseq/run.py --objective subtb --k 8 --learning_rate 0.002 --subtb_lambda 1.9

Example of running SoftDQN:

python bitseq/run.py --objective softdqn --m_alpha 0.0 --k 8 --learning_rate 0.002 --leaf_coeff 2.0 

Example of running MunchausenDQN:

python bitseq/run.py --objective softdqn --m_alpha 0.15 --k 8 --learning_rate 0.002 --leaf_coeff 2.0 
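
The two commands above differ only in `--m_alpha`. The sketch below shows why a zero coefficient would recover SoftDQN, using the generic Munchausen-DQN target from the soft RL literature. It is an illustration under stated assumptions, not the code in bitseq/run.py: it assumes `--m_alpha` plays the role of the Munchausen coefficient, the function names are hypothetical, `tau=1.0` and `gamma=1.0` are illustrative defaults, and the log-policy clipping used in the original Munchausen-DQN formulation is omitted.

```python
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v))) over a list of floats."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def munchausen_target(reward, q_s, action, q_next, m_alpha, tau=1.0, gamma=1.0):
    """Generic Munchausen-DQN regression target for one transition (sketch).

    q_s and q_next are lists of Q-values for the current and next state.
    """
    # log pi(a|s) under the softmax policy pi = softmax(Q / tau).
    log_pi_a = q_s[action] / tau - logsumexp([q / tau for q in q_s])
    # Soft value of the next state: tau * logsumexp(Q(s', .) / tau).
    soft_v_next = tau * logsumexp([q / tau for q in q_next])
    # The Munchausen bonus m_alpha * tau * log pi(a|s) vanishes when m_alpha == 0.
    return reward + m_alpha * tau * log_pi_a + gamma * soft_v_next


if __name__ == "__main__":
    q_s, q_next = [1.0, 2.0, 0.5], [0.3, -0.2]
    print(munchausen_target(0.0, q_s, 1, q_next, m_alpha=0.0))   # plain soft-DQN target
    print(munchausen_target(0.0, q_s, 1, q_next, m_alpha=0.15))  # with Munchausen bonus
```

Because log pi(a|s) is never positive, the bonus can only lower the target relative to the `m_alpha=0` case.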

Citation

@inproceedings{tiapkin2024generative,
  title={Generative flow networks as entropy-regularized rl},
  author={Tiapkin, Daniil and Morozov, Nikita and Naumov, Alexey and Vetrov, Dmitry P},
  booktitle={International Conference on Artificial Intelligence and Statistics},
  pages={4213--4221},
  year={2024},
  organization={PMLR}
}

gflownet-rl's Issues

Implementation correctness

Dear Authors,

First of all, I am very thankful for your repository. I am confused about the correctness of one part of the implementation. In soft_dqn.py, the variable valid_v_target_next is multiplied by policy_sn inside a torch.sum call. According to my derivation, there should be no such multiplication. Could you please explain how policy_sn enters the equation? Thanks.
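
One possible reading of this question is offered below. It is an assumption: it is not confirmed by the authors and was not checked against soft_dqn.py. In entropy-regularized RL the soft value of a state can be written either as tau * logsumexp(Q / tau) or as an expectation under the softmax policy, sum_a pi(a) * (Q(a) - tau * log pi(a)). The two forms agree when pi = softmax(Q / tau), and the second contains a policy-weighted sum like the one described. The sketch checks the identity numerically with the standard library only; all function names are hypothetical.

```python
import math


def softmax(q_values, tau=1.0):
    """Softmax policy pi(a) proportional to exp(Q(a) / tau)."""
    m = max(q_values)
    exps = [math.exp((q - m) / tau) for q in q_values]
    total = sum(exps)
    return [e / total for e in exps]


def soft_value_expectation(q_values, tau=1.0):
    """Soft value as an expectation: sum_a pi(a) * (Q(a) - tau * log pi(a))."""
    pi = softmax(q_values, tau)
    return sum(p * (q - tau * math.log(p)) for p, q in zip(pi, q_values))


def soft_value_logsumexp(q_values, tau=1.0):
    """Soft value in closed form: tau * logsumexp(Q / tau), computed stably."""
    m = max(q_values)
    return m + tau * math.log(sum(math.exp((q - m) / tau) for q in q_values))


if __name__ == "__main__":
    q = [0.5, -1.0, 2.0]
    # The two printed values should agree up to floating-point error.
    print(soft_value_expectation(q), soft_value_logsumexp(q))
```

The identity holds because Q(a) - tau * log pi(a) takes the same value for every action when pi = softmax(Q / tau), so weighting by pi and summing leaves it unchanged.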
