
tcbegley / rl


This project forked from pytorch/rl


A modular, primitive-first, python-first PyTorch library for Reinforcement Learning.

License: MIT License

Languages: Shell 3.30%, C++ 0.58%, Python 95.82%, Batchfile 0.26%, PowerShell 0.04%

rl's Issues

[nanoChatGPT]

Print accuracy in reward training loop evaluation
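For pairwise preference data, one natural accuracy metric is the fraction of pairs where the reward model scores the chosen answer above the rejected one. The sketch below assumes the evaluation loop already has per-example score tensors for both answers; the helper name `pairwise_accuracy` is hypothetical.

```python
import torch


def pairwise_accuracy(chosen_scores: torch.Tensor, rejected_scores: torch.Tensor) -> float:
    """Fraction of pairs where the chosen answer outscores the rejected one (hypothetical helper)."""
    return (chosen_scores > rejected_scores).float().mean().item()


acc = pairwise_accuracy(torch.tensor([1.0, 2.0, 3.0]), torch.tensor([0.0, 3.0, 1.0]))
# 2 of the 3 pairs are ranked correctly, so acc is roughly 0.667
```

This could be reported next to the validation loss at each evaluation step.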

[nanoChatGPT] Don't shuffle val set?

See here

This is probably fine: we don't perform a full pass over the validation set when calculating validation metrics; we sample batches instead, so we likely do want some randomness to avoid always validating on the same subset of the validation data.
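A minimal sketch of the sampling argument, assuming validation batches are drawn with random indices (the function name and arguments are hypothetical, not the script's actual code). Because each batch is drawn with `torch.randint`, successive evaluations see different random subsets even though the validation set itself is never shuffled.

```python
import torch


@torch.no_grad()
def estimate_val_loss(loss_fn, val_data, eval_iters=20, batch_size=8, generator=None):
    """Average loss_fn over eval_iters randomly sampled validation batches (hypothetical helper)."""
    losses = torch.empty(eval_iters)
    for k in range(eval_iters):
        # Random indices: no full pass and no pre-shuffling of val_data needed.
        idx = torch.randint(len(val_data), (batch_size,), generator=generator)
        losses[k] = loss_fn(val_data[idx])
    return losses.mean().item()


# Toy usage: the "loss" is just the mean of the sampled values.
val_data = torch.arange(1000, dtype=torch.float32)
g = torch.Generator().manual_seed(0)
first = estimate_val_loss(lambda b: b.mean(), val_data, generator=g)
second = estimate_val_loss(lambda b: b.mean(), val_data, generator=g)
```

With a fixed, unshuffled slice instead of random indices, `first` and `second` would always be computed on identical data.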

[nanoChatGPT] How to represent reward model

The reward model is trained on proposed answers to a prompt which come in pairs, one marked as chosen, the other as rejected. The reward model should output a high score on the chosen answers, and a low score on the rejected.
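A common way to formalize this objective (the pairwise, Bradley-Terry style loss used in much RLHF work; not necessarily the exact loss used in this project) is shown below, where $x$ is the prompt, $y_c$ and $y_r$ are the chosen and rejected answers, $r_\theta$ is the reward model and $\sigma$ is the logistic sigmoid:

```latex
\mathcal{L}(\theta) = -\,\mathbb{E}_{(x,\,y_c,\,y_r)}\Big[\log \sigma\big(r_\theta(x, y_c) - r_\theta(x, y_r)\big)\Big]
```

Minimizing $\mathcal{L}$ pushes $r_\theta(x, y_c)$ above $r_\theta(x, y_r)$, which is exactly the "high score on chosen, low score on rejected" behaviour described above.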

It seems tricky to come up with a clean programming pattern for this using tensorclasses. Ideally, we would represent the data using a tensorclass and use TensorDictModule to perform a single forward pass on the data.

We have a tensorclass roughly of the form

```python
import torch
from tensordict import tensorclass

@tensorclass
class Data:
    prompt: torch.Tensor
    chosen: torch.Tensor
    rejected: torch.Tensor
```

We need to do two forward passes, subtract the results and backpropagate, so we end up doing something roughly like this:

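```python
import torch.nn.functional as F

chosen_score = model(batch.prompt, batch.chosen)  # one scalar reward per example
rejected_score = model(batch.prompt, batch.rejected)
loss = -F.logsigmoid(chosen_score - rejected_score).mean()  # pairwise ranking loss
```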

which doesn't make use of TensorDictModule. One possibility would be to do something like

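```python
import torch.nn.functional as F
from tensordict.nn import TensorDictModule

# Same model wrapped twice; assumes `batch` can store the two output keys.
chosen_model = TensorDictModule(model, ["prompt", "chosen"], ["chosen_score"])
rejected_model = TensorDictModule(model, ["prompt", "rejected"], ["rejected_score"])
chosen_model(batch)
rejected_model(batch)
loss = -F.logsigmoid(batch.chosen_score - batch.rejected_score).mean()
```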

We could even combine these into a single call with TensorDictSequential. The only problem is that this feels more complicated and harder to follow.

Similarly, we could combine the forward passes of the chosen and rejected examples into a single forward pass by adding a flag that indicates the sign to use for each example when aggregating the scores, but this also becomes more complex and harder to follow.
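A sketch of the single-pass idea, assuming chosen and rejected answers are padded to the same length so they can be stacked along the batch dimension. `pairwise_loss_single_pass` and `toy_model` are hypothetical names; the sign vector plays the role of the per-example flag.

```python
import torch
import torch.nn.functional as F


def pairwise_loss_single_pass(model, prompt, chosen, rejected):
    """Score chosen and rejected answers with ONE forward pass (hypothetical helper)."""
    n = prompt.shape[0]
    prompts = torch.cat([prompt, prompt], dim=0)
    responses = torch.cat([chosen, rejected], dim=0)  # requires equal answer lengths
    scores = model(prompts, responses)                # shape (2n,)
    # Sign flag: +1 for chosen rows, -1 for rejected rows.
    sign = torch.cat([torch.ones(n), -torch.ones(n)]).to(scores)
    # Summing the signed scores per pair gives r_chosen - r_rejected.
    margin = (sign * scores).view(2, n).sum(dim=0)
    return -F.logsigmoid(margin).mean()


# Toy scorer standing in for the reward model.
toy_model = lambda p, r: p.float().mean(-1) - r.float().mean(-1)
prompt = torch.randint(0, 10, (3, 4))
chosen = torch.randint(0, 10, (3, 5))
rejected = torch.randint(0, 10, (3, 5))
single = pairwise_loss_single_pass(toy_model, prompt, chosen, rejected)
two_pass = -F.logsigmoid(toy_model(prompt, chosen) - toy_model(prompt, rejected)).mean()
```

`single` and `two_pass` agree, but the indexing and sign bookkeeping are easy to get wrong, which is the trade-off described above.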

[nanoChatGPT] Configure logging

Scripts currently use print statements for logging; it would be nice to use the logging module so that logs can be redirected and configured more easily.
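A minimal sketch of what this could look like with the standard library `logging` module. The logger name, `setup_logging` helper and format string are assumptions for illustration; callers could choose the level and an optional log file, for example from a config option.

```python
import logging
import sys
from typing import Optional

# Hypothetical module-level logger for the nanoChatGPT scripts.
logger = logging.getLogger("nanochatgpt")


def setup_logging(level: int = logging.INFO, log_file: Optional[str] = None) -> None:
    """Send log records to stderr and, optionally, to a file."""
    handlers: list = [logging.StreamHandler(sys.stderr)]
    if log_file is not None:
        handlers.append(logging.FileHandler(log_file))
    logging.basicConfig(
        level=level,
        format="%(asctime)s %(name)s %(levelname)s: %(message)s",
        handlers=handlers,
        force=True,  # replace handlers configured earlier (Python 3.8+)
    )


setup_logging()
# Instead of: print(f"iter {it}: loss {loss:.4f}")
logger.info("iter %d: loss %.4f", 100, 2.3456)
```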

Don't persist iteration number in model checkpoints

Currently, training the reward model starts from the iteration number stored in the transformer checkpoint, which is weird. We should start counting iterations from 0 in the reward model training loop, regardless of how many iterations the transformer was trained for. This assumes the reward model is being trained from scratch; if we load a reward model from a checkpoint, we can continue counting from that checkpoint's iteration number.
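The proposed rule can be sketched as a small helper. The function name is hypothetical, and the `"iter_num"` checkpoint key is an assumption (nanoGPT-style checkpoints store one under that name).

```python
from typing import Any, Dict, Optional


def reward_start_iter(reward_ckpt: Optional[Dict[str, Any]] = None) -> int:
    """Iteration at which reward-model training should start (hypothetical helper).

    The transformer checkpoint only supplies initial weights, so its
    iteration count is deliberately never consulted here.
    """
    if reward_ckpt is None:
        return 0  # reward model trained from scratch
    return int(reward_ckpt["iter_num"])  # resuming a reward-model run


# Hypothetical checkpoint dicts for illustration.
transformer_ckpt = {"model": {}, "iter_num": 5000}  # its iter_num is ignored
fresh = reward_start_iter()                                    # 0
resumed = reward_start_iter({"model": {}, "iter_num": 1200})   # 1200
```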

[nanoChatGPT] Clean up config

At the moment we have lots of options that are either redundant or unused; we should streamline them and update the comments accordingly.

Example: dataset choice should be deprecated, as should model choice.

Single config for entire pipeline?
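One possible shape for a single pipeline config is a nested dataclass: shared options at the top level, stage-specific options in per-stage sections, and no dataset or model choice fields. All field names and default values below are placeholders, not the project's actual settings.

```python
from dataclasses import asdict, dataclass, field


@dataclass
class StageConfig:
    """Settings for one training stage (placeholder values)."""
    out_dir: str
    max_iters: int = 1000
    learning_rate: float = 1e-4
    eval_interval: int = 100


@dataclass
class PipelineConfig:
    """Single config for the whole pipeline (hypothetical layout)."""
    block_size: int = 256
    device: str = "cpu"
    transformer: StageConfig = field(default_factory=lambda: StageConfig(out_dir="out_transformer"))
    reward: StageConfig = field(default_factory=lambda: StageConfig(out_dir="out_reward"))


cfg = PipelineConfig()
cfg.reward.max_iters = 500  # override one stage without touching the other
flat = asdict(cfg)          # nested dict, e.g. for logging or saving with a checkpoint
```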
