learn2learn's Introduction


learn2learn is a software library for meta-learning research.

learn2learn builds on top of PyTorch to accelerate two aspects of the meta-learning research cycle:

  • fast prototyping, essential in letting researchers quickly try new ideas, and
  • correct reproducibility, ensuring that these ideas are evaluated fairly.

learn2learn provides low-level utilities and a unified interface to create new algorithms and domains, together with high-quality implementations of existing algorithms and standardized benchmarks. It retains compatibility with torchvision, torchaudio, torchtext, cherry, and any other PyTorch-based library you might be using.

Overview

  • learn2learn.data: TaskDataset and transforms to create few-shot tasks from any PyTorch dataset.
  • learn2learn.vision: Models, datasets, and benchmarks for computer vision and few-shot learning.
  • learn2learn.gym: Environment and utilities for meta-reinforcement learning.
  • learn2learn.algorithms: High-level wrappers for existing meta-learning algorithms.
  • learn2learn.optim: Utilities and algorithms for differentiable optimization and meta-descent.

Installation

pip install learn2learn

Snippets & Examples

The following snippets provide a sneak peek at the functionalities of learn2learn.

High-level Wrappers

Few-Shot Learning with MAML

For more algorithms (ProtoNets, ANIL, Meta-SGD, Reptile, Meta-Curvature, KFO), refer to the examples folder. Most of them can be implemented with the GBML wrapper (see the documentation).

import torch
import learn2learn as l2l

# `model` is any torch.nn.Module; `compute_loss` is a user-defined function returning a scalar loss.
maml = l2l.algorithms.MAML(model, lr=0.1)
opt = torch.optim.SGD(maml.parameters(), lr=0.001)
for iteration in range(10):
    opt.zero_grad()
    task_model = maml.clone()  # torch.clone() for nn.Modules
    adaptation_loss = compute_loss(task_model)
    task_model.adapt(adaptation_loss)  # computes gradients and updates task_model in-place
    evaluation_loss = compute_loss(task_model)
    evaluation_loss.backward()  # gradients w.r.t. maml.parameters()
    opt.step()
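
To make the clone / adapt / evaluate / backward pattern concrete, here is a minimal pure-Python sketch of one-step MAML on a toy problem where every gradient can be written by hand. It is an illustration under stated assumptions, not learn2learn's implementation: the scalar model `w`, the per-task loss `(w - target)**2`, and the names `task_loss_grad`, `maml_meta_gradient`, and `meta_train` are all hypothetical.

```python
def task_loss_grad(w, target):
    """Gradient of the toy task loss (w - target)**2 with respect to w."""
    return 2.0 * (w - target)


def maml_meta_gradient(w, target, inner_lr):
    """Meta-gradient of the post-adaptation loss with respect to the initial w.

    Inner step:  w_adapted = w - inner_lr * dL/dw
    Outer loss:  (w_adapted - target)**2
    The chain rule gives d(outer)/dw = dL/dw(w_adapted) * d(w_adapted)/dw,
    where d(w_adapted)/dw = 1 - 2 * inner_lr comes from differentiating
    through the adaptation step itself.
    """
    w_adapted = w - inner_lr * task_loss_grad(w, target)
    d_adapted_d_w = 1.0 - 2.0 * inner_lr
    return task_loss_grad(w_adapted, target) * d_adapted_d_w


def meta_train(targets, w=0.0, inner_lr=0.1, meta_lr=0.05, iterations=500):
    """Plain gradient descent on the average meta-gradient over all tasks."""
    for _ in range(iterations):
        meta_grad = sum(maml_meta_gradient(w, t, inner_lr) for t in targets)
        w -= meta_lr * meta_grad / len(targets)
    return w


# Hypothetical tasks: each one wants w to equal a different target value.
w_init = meta_train(targets=[1.0, 2.0, 6.0])
```

For this quadratic family the learned initialization settles at the mean of the task targets (3.0 here), the point from which a single inner step helps every task equally.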

Meta-Descent with Hypergradient

Learn any kind of optimization algorithm with the LearnableOptimizer (see the example and documentation).

linear = nn.Linear(784, 10)
transform = l2l.optim.ModuleTransform(l2l.nn.Scale)
metaopt = l2l.optim.LearnableOptimizer(linear, transform, lr=0.01)  # metaopt has .step()
opt = torch.optim.SGD(metaopt.parameters(), lr=0.001)  # metaopt also has .parameters()

metaopt.zero_grad()
opt.zero_grad()
error = loss(linear(X), y)
error.backward()
opt.step()  # update metaopt
metaopt.step()  # update linear
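
The idea of updating the optimizer itself by gradient descent can be worked through by hand for a single learnable learning rate. The sketch below is pure Python and makes several assumptions: the toy objective `f(w) = (w - 3)**2`, the names `grad` and `hypergradient_descent`, and the hyperparameter values are all hypothetical, and it does not use or mirror `LearnableOptimizer`.

```python
def grad(w):
    """Gradient of the toy objective f(w) = (w - 3)**2."""
    return 2.0 * (w - 3.0)


def hypergradient_descent(w=0.0, lr=0.01, hyper_lr=0.001, steps=100):
    """Gradient descent on w while also adapting lr by gradient descent.

    Since w_t = w_{t-1} - lr * g_{t-1}, the derivative of f(w_t) with respect
    to lr is -g_t * g_{t-1}. Descending along it gives
    lr += hyper_lr * g_t * g_{t-1}.
    """
    prev_grad = 0.0
    lr_history = []
    for _ in range(steps):
        g = grad(w)
        lr += hyper_lr * g * prev_grad  # update the learning rate (the "optimizer")
        w -= lr * g                     # update the parameter with the new lr
        prev_grad = g
        lr_history.append(lr)
    return w, lr_history


w_final, lr_history = hypergradient_descent()
```

Consecutive gradients point the same way while `w` approaches the minimum from one side, so the learning rate grows from its small initial value and convergence speeds up.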

Learning Domains

Custom Few-Shot Dataset

Many standardized datasets (Omniglot, mini-/tiered-ImageNet, FC100, CIFAR-FS) are readily available in learn2learn.vision.datasets (see the documentation).

dataset = l2l.data.MetaDataset(MyDataset())  # any PyTorch dataset
transforms = [  # Easy to define your own transform
    l2l.data.transforms.NWays(dataset, n=5),
    l2l.data.transforms.KShots(dataset, k=1),
    l2l.data.transforms.LoadData(dataset),
]
taskset = l2l.data.TaskDataset(dataset, transforms, num_tasks=20000)
for task in taskset:
    X, y = task
    # Meta-train on the task
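
The transform pipeline can be read as three steps: pick n classes, pick k examples from each, then load them. The following pure-Python sketch shows that sampling logic on a toy dataset. It is an assumption-laden illustration: `sample_task` and `toy_dataset` are hypothetical names, and the code does not reflect how learn2learn implements its transforms.

```python
import random
from collections import defaultdict


def sample_task(dataset, n_ways, k_shots, rng):
    """Sample one few-shot task from a list of (example, label) pairs.

    Returns n_ways * k_shots (example, label) pairs covering n_ways classes.
    """
    # Index examples by class so that classes and shots can be drawn separately.
    by_label = defaultdict(list)
    for example, label in dataset:
        by_label[label].append(example)

    classes = rng.sample(sorted(by_label), n_ways)      # the "N ways"
    task = []
    for cls in classes:
        shots = rng.sample(by_label[cls], k_shots)      # the "K shots"
        task.extend((example, cls) for example in shots)
    return task


# Toy dataset: 10 classes with 20 examples each; examples are just strings.
toy_dataset = [(f"class{c}_item{i}", c) for c in range(10) for i in range(20)]
task = sample_task(toy_dataset, n_ways=5, k_shots=1, rng=random.Random(0))
```

Passing an explicit `random.Random` instance keeps task sampling reproducible, which matters when comparing algorithms on the same set of tasks.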

Environments and Utilities for Meta-RL

Parallelize your own meta-environments with AsyncVectorEnv, or use the standardized ones (see the documentation).

def make_env():
    env = l2l.gym.HalfCheetahForwardBackwardEnv()
    env = cherry.envs.ActionSpaceScaler(env)
    return env

env = l2l.gym.AsyncVectorEnv([make_env for _ in range(16)])  # uses 16 threads
for task_config in env.sample_tasks(20):
    env.set_task(task_config)  # all threads receive the same task
    state = env.reset()  # use standard Gym API
    action = my_policy(state)
    env.step(action)
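
The loop relies on two task-level calls, `sample_tasks` and `set_task`, on top of the usual Gym-style `reset` and `step`. The sketch below shows a toy environment exposing that pattern in pure Python. Everything in it is hypothetical: the class `ToyDirectionEnv`, its dynamics, and its reward are invented for illustration, and it does not subclass anything from learn2learn or Gym.

```python
import random


class ToyDirectionEnv:
    """1-D point that is rewarded for moving in a task-specific direction."""

    def __init__(self, seed=0):
        self.rng = random.Random(seed)
        self.direction = 1.0   # current task: +1.0 (forward) or -1.0 (backward)
        self.position = 0.0

    def sample_tasks(self, num_tasks):
        """Return a list of task configurations (plain dicts here)."""
        return [{"direction": self.rng.choice([-1.0, 1.0])} for _ in range(num_tasks)]

    def set_task(self, task):
        """Switch the environment to the given task configuration."""
        self.direction = task["direction"]

    def reset(self):
        self.position = 0.0
        return self.position

    def step(self, action):
        """Move by `action`; reward is the progress made in the task direction."""
        self.position += action
        reward = self.direction * action
        done = abs(self.position) >= 5.0
        return self.position, reward, done, {}


env = ToyDirectionEnv()
returns = []
for task_config in env.sample_tasks(4):
    env.set_task(task_config)
    state, done, total = env.reset(), False, 0.0
    while not done:
        action = 1.0  # fixed "always move forward" policy, ignores the state
        state, reward, done, _ = env.step(action)
        total += reward
    returns.append((task_config["direction"], total))
```

The fixed policy collects a return of +5 on forward tasks and -5 on backward tasks, which is the kind of task-dependent behavior a meta-RL agent has to adapt to.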

Low-Level Utilities

Differentiable Optimization

Learn and differentiate through updates of PyTorch Modules (see the documentation).

model = MyModel()
transform = l2l.optim.KroneckerTransform(l2l.nn.KroneckerLinear)
learned_update = l2l.optim.ParameterUpdate(  # learnable update function
        model.parameters(), transform)
clone = l2l.clone_module(model)  # torch.clone() for nn.Modules
error = loss(clone(X), y)
updates = learned_update(  # similar API as torch.autograd.grad
    error,
    clone.parameters(),
    create_graph=True,
)
l2l.update_module(clone, updates=updates)
loss(clone(X), y).backward()  # gradients w.r.t. model.parameters() and learned_update.parameters()
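
Differentiating through an update yields two sets of gradients: one for the original parameters and one for the parameters of the update rule. On a scalar problem both can be derived with the chain rule and checked numerically. The sketch below assumes a toy loss `(w - 3)**2` and a learned scalar step size `s`; all function names are hypothetical, and it stands in for, rather than reproduces, what autograd does in the snippet.

```python
def loss(w):
    return (w - 3.0) ** 2


def loss_grad(w):
    return 2.0 * (w - 3.0)


def updated_loss(w, s):
    """Loss after one learned update w_new = w - s * dL/dw."""
    return loss(w - s * loss_grad(w))


def grads_through_update(w, s):
    """Analytic gradients of updated_loss with respect to w and s."""
    w_new = w - s * loss_grad(w)
    outer = loss_grad(w_new)          # dL/dw_new
    d_w = outer * (1.0 - 2.0 * s)     # d(w_new)/dw = 1 - s * L''(w), with L'' = 2
    d_s = outer * (-loss_grad(w))     # d(w_new)/ds = -dL/dw
    return d_w, d_s


def finite_difference(f, x, eps=1e-6):
    """Central finite-difference estimate of df/dx."""
    return (f(x + eps) - f(x - eps)) / (2.0 * eps)


w, s = 0.5, 0.1
d_w, d_s = grads_through_update(w, s)
fd_w = finite_difference(lambda v: updated_loss(v, s), w)
fd_s = finite_difference(lambda v: updated_loss(w, v), s)
```

The gradient with respect to `w` plays the role of the gradient for the model parameters, and the gradient with respect to `s` plays the role of the gradient for the learned update.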

Changelog

A human-readable changelog is available in the CHANGELOG.md file.

Citation

To cite the learn2learn repository in your academic publications, please use the following reference.

Sebastien M.R. Arnold, Praateek Mahajan, Debajyoti Datta, Ian Bunner. "learn2learn". https://github.com/learnables/learn2learn, 2019.

You can also use the following Bibtex entry.

@misc{learn2learn2019,
    author       = {Sebastien M.R. Arnold and Praateek Mahajan and Debajyoti Datta and Ian Bunner},
    title        = {learn2learn},
    month        = sep,
    year         = 2019,
    url          = {https://github.com/learnables/learn2learn}
}

Acknowledgements & Friends

  1. The RL environments are adapted from Tristan Deleu's implementations and from the ProMP repository. Both shared with permission, under the MIT License.
  2. TorchMeta is a similar library, with a focus on datasets for supervised meta-learning.
  3. higher is a PyTorch library that enables differentiating through optimization inner-loops. While it monkey-patches nn.Module to be stateless, learn2learn retains the stateful PyTorch look-and-feel. For more information, refer to the higher arXiv paper.
