itomigna2 / muesli-lunarlander

Muesli RL algorithm implementation (PyTorch) (LunarLander-v2)

License: MIT License

Jupyter Notebook 87.61% Dockerfile 0.58% Python 11.81%
colab reinforcement-learning deep-learning lunarlander-v2 model-based-rl muesli muzero

muesli-lunarlander's Issues

Inefficient dtype conversion

While optimizing for efficiency, I may have found a critical bug that slows down the code.

There are some careless dtype conversions:

output_p , output_v, _ = target.representation_network(torch.from_numpy(stacked_state).float().unsqueeze(0).to(device))

x = x.div(params['norm_factor']).float()

After removing .float() from the lines above, the experiment runs about 5x faster than before.

I will commit the change soon, after checking that it does not introduce any bugs.

Replay buffer is storing multiple copies of the most recent episode

Hi @Itomigna2,

This is very cool work. Thanks for sharing it.

I was testing out your implementation here and noticed a bug. For every new episode you add an entry to the replay buffers for state, action, policy, and reward. But because you reuse the same list objects over and over, the replay buffer does not hold unique episodes. Instead it holds references to the same list, so every entry ends up being the exact same most recent episode.

You can check this by inserting the code below. It sums the first episode in the replay buffer and then tests whether that sum is exactly the same for all of the other episodes in the buffer. You will see that it is.

for i in range(episode_nums):
    #...
    print("Size of replay buffer: ", len(agent.action_replay))
    print("All actions same in replay buffer?: ", all([sum(agent.action_replay[0]) == sum(r) for r in agent.action_replay]))
    print("All rewards same in replay buffer? ", all([sum(agent.r_replay[0]) == sum(r) for r in agent.r_replay]))
    agent.update_weights_mu(target)
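
A stricter variant of the check above compares object identity instead of sums, since two different episodes could in principle have equal sums. This is only a minimal sketch: count_distinct_episodes and the two toy buffers are hypothetical names used for illustration, not part of the repository.

```python
def count_distinct_episodes(replay):
    """Return how many distinct list objects the buffer holds."""
    # id() is unique among objects that are alive at the same time,
    # so aliased entries collapse to a single id.
    return len({id(episode) for episode in replay})


# Hypothetical stand-ins for something like agent.action_replay.
shared = [0, 1, 2]
aliased_replay = [shared, shared, shared]              # same list three times
independent_replay = [list(shared) for _ in range(3)]  # three separate copies

print(count_distinct_episodes(aliased_replay))      # 1
print(count_distinct_episodes(independent_replay))  # 3
```

If the count is smaller than len(replay), at least two entries are the same list object.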

The fix is straightforward: replace each list.clear() call with an assignment of a new list. Currently the clear() calls are at the end of update_weights_mu. I would suggest moving the resets to the beginning of self_play_mu.

        # remove these 
        # self.state_traj.clear()
        # self.action_traj.clear()
        # self.P_traj.clear()
        # self.r_traj.clear()

        self.state_traj = []
        self.action_traj = []
        self.P_traj = []
        self.r_traj = []
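
The difference between clearing and rebinding can be reproduced with a small toy model. The sketch below is not the repository's agent: ToyAgent, self_play, update_weights, and reset_with_clear are hypothetical names, and only the reward buffer is modelled.

```python
class ToyAgent:
    """Hypothetical stand-in that models only the trajectory/replay lists."""

    def __init__(self, reset_with_clear):
        self.reset_with_clear = reset_with_clear
        self.r_traj = []
        self.r_replay = []

    def self_play(self, rewards):
        if not self.reset_with_clear:
            # Fixed variant: bind a brand-new list for every episode.
            self.r_traj = []
        self.r_traj.extend(rewards)
        # The buffer stores a reference to the list, not a copy.
        self.r_replay.append(self.r_traj)

    def update_weights(self):
        if self.reset_with_clear:
            # Buggy variant: empties the one list every buffer entry points to.
            self.r_traj.clear()


def run(agent, episodes):
    """Play all episodes; return the buffer as seen before the last update."""
    seen = []
    for rewards in episodes:
        agent.self_play(rewards)
        seen = [list(episode) for episode in agent.r_replay]
        agent.update_weights()
    return seen


episodes = [[1, 2], [3, 4], [5]]
print(run(ToyAgent(reset_with_clear=True), episodes))   # [[5], [5], [5]]
print(run(ToyAgent(reset_with_clear=False), episodes))  # [[1, 2], [3, 4], [5]]
```

With clear(), every buffer entry is the same object, so each one shows the most recent episode. With a new list per episode, earlier entries are left untouched.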

Using your Muesli code for large scale training efforts.

Hello, my name is Connor. I'm a postdoctoral researcher working at MILA. We are looking for a codebase for Muesli to use in our video pretraining project (a joint venture between MILA and the Farama organization). The project's aim is to use internet-scale data to pretrain on video-action pairs in order to develop a foundation model to assist with downstream RL tasks.

Would you be available to discuss potential collaborations on this project? In particular, we are hoping that you could assist us with getting the code set up for our particular use cases.

Incorrect HPO params viewed by nni

I found that the NNI experiment params are displayed incorrectly on the NNI page.

The trial number from the (NNIManager) submitTrialJob is what is displayed,
but the real values used in the experiment come from the (LocalV3.local) Trial parameter.

It looks similar to this issue:
microsoft/nni#5726

I will check the nni configuration of this repo.
