Llama Gym

Fine-tune LLM agents with online reinforcement learning

"Agents" originated in reinforcement learning, where they learn by interacting with an environment and receiving a reward signal. However, LLM-based agents today do not learn online (i.e. continuously in real time) via reinforcement.

OpenAI created Gym to standardize and simplify RL environments, but if you try dropping an LLM-based agent into a Gym environment for training, you'll find it still takes quite a bit of code to handle LLM conversation context, episode batches, reward assignment, PPO setup, and more.

LlamaGym seeks to simplify fine-tuning LLM agents with RL. Right now, it's a single Agent abstract class that handles all the issues mentioned above, letting you quickly iterate and experiment with agent prompting & hyperparameters across any Gym environment.

Usage

Fine-tuning an LLM-based agent to play in a Gym-style environment with RL has never been easier! Once you install LlamaGym...

pip install llamagym

First, implement 3 abstract methods on the Agent class:

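from llamagym import Agent

class BlackjackAgent(Agent):
    def get_system_prompt(self) -> str:
        # Sent once at the start of every episode's conversation
        return "You are an expert blackjack player."

    def format_observation(self, observation) -> str:
        # observation[0] is the player's current hand total
        return f"Your current total is {observation[0]}"

    def extract_action(self, response: str):
        # Blackjack-v1 actions: 0 = stick ("stay"), 1 = hit
        return 0 if "stay" in response.lower() else 1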
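The substring check in `extract_action` can misfire when the model's reply mentions both options (for example, "I won't stay, I'll hit"). A slightly more defensive pair of helpers might look like the sketch below. This is a standalone illustration, not part of LlamaGym: the names `format_blackjack_observation` and `parse_blackjack_action` and the "Action: ..." reply convention are hypothetical, and it assumes the Blackjack-v1 observation is the tuple `(player_sum, dealer_card, usable_ace)`.

```python
import re

STICK, HIT = 0, 1  # Blackjack-v1 action ids

# Hypothetical reply convention: the model is asked to end with "Action: hit" or "Action: stay".
_ACTION_RE = re.compile(r"action:\s*(hit|stay|stick)", re.IGNORECASE)


def format_blackjack_observation(observation) -> str:
    """Turn a (player_sum, dealer_card, usable_ace) tuple into a prompt string."""
    player_sum, dealer_card, usable_ace = observation
    ace_text = "You have a usable ace." if usable_ace else "You have no usable ace."
    return (
        f"Your current total is {player_sum}. "
        f"The dealer is showing {dealer_card}. {ace_text} "
        "Reply with 'Action: hit' or 'Action: stay'."
    )


def parse_blackjack_action(response: str, default: int = STICK) -> int:
    """Return HIT or STICK from an explicit 'Action: ...' marker, else the default."""
    match = _ACTION_RE.search(response)
    if match is None:
        return default  # unparseable reply: fall back to a fixed action
    return HIT if match.group(1).lower() == "hit" else STICK
```

Inside an `Agent` subclass, `format_observation` and `extract_action` could simply delegate to these two functions. Falling back to a fixed default on unparseable replies is one design choice among several; penalizing such replies through the reward is another.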

Then, define your base LLM (as you would for any fine-tuning job) and instantiate your agent:

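from transformers import AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead

device = "cuda"  # or "cpu"
model = AutoModelForCausalLMWithValueHead.from_pretrained("Llama-2-7b").to(device)
tokenizer = AutoTokenizer.from_pretrained("Llama-2-7b")
agent = BlackjackAgent(model, tokenizer, device)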

Finally, write your RL loop as usual and simply call your agent to act, reward, and terminate:

import gymnasium as gym  # maintained successor to OpenAI Gym
from tqdm import trange

env = gym.make("Blackjack-v1")

for episode in trange(5000):
    observation, info = env.reset()
    done = False

    while not done:
        action = agent.act(observation) # act based on observation
        observation, reward, terminated, truncated, info = env.step(action)
        agent.assign_reward(reward) # provide reward to agent
        done = terminated or truncated

    train_stats = agent.terminate_episode() # trains if batch is full
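To make concrete what the agent has to track behind `act`, `assign_reward`, and `terminate_episode`, here is a toy, model-free sketch of that bookkeeping. It is not LlamaGym's implementation: `ToyEpisodeBuffer`, its `batch_size` parameter, and the choice to credit every turn with the episode's total return are assumptions made for illustration, and the PPO update is replaced by simply returning the full batch.

```python
from dataclasses import dataclass, field


@dataclass
class ToyEpisodeBuffer:
    """Toy stand-in for the per-episode state an LLM agent has to track."""

    system_prompt: str
    batch_size: int = 2  # hypothetical: turns needed before a training step
    messages: list = field(default_factory=list)  # current conversation
    rewards: list = field(default_factory=list)   # rewards for this episode
    batch: list = field(default_factory=list)     # finished turns awaiting training

    def __post_init__(self):
        self._reset_episode()

    def _reset_episode(self):
        # Each episode starts a fresh conversation holding only the system prompt.
        self.messages = [{"role": "system", "content": self.system_prompt}]
        self.rewards = []

    def record_turn(self, observation_text: str, response_text: str):
        # In a real agent the response would come from the LLM inside act().
        self.messages.append({"role": "user", "content": observation_text})
        self.messages.append({"role": "assistant", "content": response_text})

    def assign_reward(self, reward: float):
        self.rewards.append(reward)

    def terminate_episode(self):
        # Assumption: every turn is credited with the episode's total return.
        episode_return = sum(self.rewards)
        queries = [m["content"] for m in self.messages if m["role"] == "user"]
        responses = [m["content"] for m in self.messages if m["role"] == "assistant"]
        for query, response in zip(queries, responses):
            self.batch.append((query, response, episode_return))
        self._reset_episode()
        if len(self.batch) < self.batch_size:
            return None  # batch not full yet, so no training step
        full = self.batch[: self.batch_size]
        self.batch = self.batch[self.batch_size :]
        return full  # a real agent would run a PPO step on this batch
```

Returning `None` until the batch fills mirrors the "trains if batch is full" comment in the loop above: most calls to `terminate_episode` only store data, and only some trigger an update.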

Some reminders:

  • the code snippets above are mildly simplified, but a fully working example is available in examples/blackjack.py
  • getting online RL to converge is notoriously difficult so you'll have to mess with hyperparameters to see improvement
    • your model may also benefit from a supervised fine-tuning stage on sampled trajectories before running RL (we may add this feature in the future)
  • our implementation values simplicity so is not as compute efficient as e.g. Lamorel, but easier to start playing around with
  • LlamaGym is a weekend project and still a WIP, but we love contributions!
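As a rough illustration of the supervised fine-tuning idea mentioned in the reminders, one simple option is to sample trajectories, keep only those whose return clears a threshold, and flatten them into prompt/completion pairs. The sketch below assumes a hypothetical trajectory format (a dict with a `"return"` value and a list of `(prompt, completion)` `"turns"`) and a hypothetical `select_sft_examples` helper; LlamaGym does not currently provide this.

```python
def select_sft_examples(trajectories, min_return=1.0):
    """Keep turns from trajectories whose return is at least min_return.

    Each trajectory is assumed to look like:
        {"return": float, "turns": [(prompt, completion), ...]}
    """
    examples = []
    for trajectory in trajectories:
        if trajectory["return"] < min_return:
            continue  # drop losing or low-reward episodes
        for prompt, completion in trajectory["turns"]:
            examples.append({"prompt": prompt, "completion": completion})
    return examples


# Hypothetical sampled episodes: one win (return 1.0) and one loss (return -1.0).
sampled = [
    {"return": 1.0, "turns": [("Your current total is 19", "stay")]},
    {"return": -1.0, "turns": [("Your current total is 20", "hit")]},
]
sft_examples = select_sft_examples(sampled)
```

The resulting list of prompt/completion dicts could then be fed to whatever supervised fine-tuning setup you already use before starting the RL loop.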

Relevant Work

Citation

@misc{pandey2024llamagym,
  title        = {LlamaGym: Fine-tune LLM agents with Online Reinforcement Learning},
  author       = {Rohan Pandey},
  year         = {2024},
  howpublished = {GitHub},
  url          = {https://github.com/KhoomeiK/LlamaGym}
}
