gabrielhuang / reptile-pytorch

A PyTorch implementation of OpenAI's REPTILE algorithm

License: BSD 2-Clause "Simplified" License

Python 16.72% Jupyter Notebook 83.28%
pytorch metalearning deep-learning deep-neural-networks few-shot supervised-learning openai


Reptile

PyTorch implementation of OpenAI's Reptile algorithm for supervised learning.

Currently, it runs on Omniglot but not yet on MiniImagenet.

The code has not been tested extensively. Contributions and feedback are more than welcome!

Omniglot meta-learning dataset

There is already an Omniglot dataset class in torchvision; however, it seems better suited to standard supervised learning than to few-shot learning.

omniglot.py provides a way to sample K-shot N-way base-tasks from Omniglot, as well as utilities for splitting the meta-training set and the base-tasks.
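For illustration, the sampling described above might look like the following sketch. The function name and dataset layout here are hypothetical, not the actual omniglot.py API:

```python
import random

def sample_task(dataset, num_classes=5, num_shots=1):
    """Sketch of K-shot N-way base-task sampling (hypothetical API).

    `dataset` maps each character class to its list of images;
    a base-task draws N classes, then K examples from each,
    relabeling the chosen classes 0..N-1 for this episode.
    """
    classes = random.sample(list(dataset), num_classes)
    task = []
    for label, cls in enumerate(classes):
        shots = random.sample(dataset[cls], num_shots)
        task.extend((img, label) for img in shots)
    return task

# Toy usage with placeholder "images": a 5-way 5-shot episode has 25 examples.
toy = {c: [f"{c}_{i}" for i in range(20)] for c in "ABCDEFG"}
episode = sample_task(toy, num_classes=5, num_shots=5)
```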

Features

  • Monitor training with TensorboardX.
  • Interrupt and resume training.
  • Train and evaluate on Omniglot.
  • Meta-batch size > 1.
  • Train and evaluate on Mini-Imagenet.
  • Clarify Transductive vs. Non-transductive setting.
  • Add training curves in README.
  • Reproduce all settings from OpenAI's code.
  • Shell script to download datasets.

How to train on Omniglot

Download the two parts of the Omniglot dataset:

Create an omniglot/ folder in the repo, then unzip and merge the two parts so that you have the following folder structure:

./train_omniglot.py
...
./omniglot/Alphabet_of_the_Magi/
./omniglot/Angelic/
./omniglot/Anglo-Saxon_Futhorc/
...
./omniglot/ULOG/

Now start training with

python train_omniglot.py log --cuda 0 $HYPERPARAMETERS  # with CUDA (GPU 0)
python train_omniglot.py log $HYPERPARAMETERS  # with CPU

where $HYPERPARAMETERS depends on your task and hyperparameters.

Behavior:

  • If no checkpoints are found in log/, this will create a log/ folder to store tensorboard information and checkpoints.
  • If checkpoints are found in log/, this will resume from the last checkpoint.

Training can be interrupted at any time with ^C, and resumed from the last checkpoint by re-running the same command.
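The resume behavior described above can be sketched as follows. The checkpoint naming scheme here is an assumption for illustration; the actual layout used by train_omniglot.py may differ:

```python
import glob
import os
import re

def find_latest_checkpoint(logdir):
    """Sketch of checkpoint resumption (hypothetical naming scheme:
    checkpoints saved as <logdir>/checkpoint-<meta_iteration>.pt).

    Returns (path, meta_iteration) of the latest checkpoint, or
    (None, 0) for a fresh run, in which case the caller creates
    `logdir` and starts from meta-iteration 0.
    """
    paths = glob.glob(os.path.join(logdir, "checkpoint-*.pt"))
    if not paths:
        return None, 0
    def iteration(p):
        return int(re.search(r"checkpoint-(\d+)\.pt$", p).group(1))
    latest = max(paths, key=iteration)
    return latest, iteration(latest)
```

Because the state lives entirely in the checkpoint file, re-running the same command after ^C picks up where training left off.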

Omniglot Hyperparameters

The following hyperparameters work decently. They are taken from the OpenAI implementation and adapted slightly for meta-batch size 1.

For 5-way 5-shot (red curve):

python train_omniglot.py log/o55 --classes 5 --shots 5 --train-shots 10 --meta-iterations 100000 --iterations 5 --test-iterations 50 --batch 10 --meta-lr 0.2 --lr 0.001

For 5-way 1-shot (blue curve):

python train_omniglot.py log/o51 --classes 5 --shots 1 --train-shots 12 --meta-iterations 200000 --iterations 12 --test-iterations 86 --batch 10 --meta-lr 0.33 --lr 0.00044

reptile-pytorch's People

Contributors

gabrielhuang


reptile-pytorch's Issues

Updating Meta learner parameters

The original article gives the update rule Phi <- Phi + Eps * (W - Phi).
Why doesn't the code multiply by Eps, and why does it use (Phi - W)?
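For reference, the two forms are algebraically identical once the outer optimizer's learning rate is read as $\epsilon$:

```latex
\Phi \;\leftarrow\; \Phi + \epsilon\,(W - \Phi) \;=\; \Phi - \epsilon\,(\Phi - W)
```

So storing $(\Phi - W)$ as the "gradient" and taking one SGD step with learning rate $\epsilon$ (the `--meta-lr` flag) reproduces the paper's rule; $\epsilon$ is not omitted, it is supplied by the outer optimizer.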

Epsilon implicitly set to 1

Hi Gabriel,

Thanks a lot for the nice code!

When you add the parameter difference in ReptileModel.point_grad_to, you add the difference directly rather than multiplying with epsilon as in the paper:

p.grad.data.add_(p.data - target_p.data)

Is there any specific reason for this?

Thanks,
Marco
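The question above can be answered numerically. The sketch below uses plain floats (no torch) to show that treating (phi - w) as a gradient and taking one SGD step with learning rate epsilon gives exactly the paper's update, which is presumably why the code leaves epsilon to the outer optimizer:

```python
def reptile_meta_update(phi, w, eps):
    """One Reptile outer step, written two equivalent ways."""
    grad = phi - w                      # what point_grad_to stores in p.grad
    sgd_step = phi - eps * grad         # SGD with lr=eps applied to that "gradient"
    paper_rule = phi + eps * (w - phi)  # update rule as written in the paper
    return sgd_step, paper_rule

a, b = reptile_meta_update(phi=1.0, w=3.0, eps=0.2)
# both compute 1.0 + 0.2 * (3.0 - 1.0) = 1.4
```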

Not Achieving Same Accuracy Results

Hello, I'm running this code on my local machine. It's a direct fork of this repo; the only changes I've made are for Python 3 compatibility. Are you certain you see ~95% accuracy with these hyperparameters?
--classes 5 --shots 5 --train-shots 10 --meta-iterations 100000 --iterations 5 --test-iterations 50 --batch 10 --meta-lr 0.2 --lr 0.001

Because I am not. Here is my final log output:
Meta-train
average metaloss 0.9020509621277452
average accuracy 0.7628000104278326

Meta-val
average metaloss 0.9079857150167226
average accuracy 0.7575000119954347

Any ideas why I may be seeing much lower accuracies? Any input would be appreciated. Thanks.

Validation accuracy is higher than training accuracy

Hi, thanks for your implementation!
When I ran your code, I found that the accuracy on the "Meta-Val" set starts at a very high level, around 0.7, whereas the accuracy on the "Meta-Train" set is only about 0.5. As training goes on this gap shrinks, but validation accuracy remains slightly higher than training accuracy. Is this a normal phenomenon?

A little confusion about the implementation

Hi Gabriel,

Thank you for sharing the code!

I'm a little confused by this implementation.

In my understanding, meta-learning uses the meta-train set for several updates, then sums the losses computed on the meta-test set across episodes, backpropagates that total loss through the meta-network to get gradients, and updates the base network with them. But in your implementation I don't see a meta-test set. It seems that you directly use the difference between the base network's parameters and the meta-network's parameters, computed from the meta-train set.
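The loop the questioner expected is closer to MAML; Reptile as described in the paper needs no meta-test split and no second-order backward pass. A toy 1-D illustration (assumed setup: each task is matching a random scalar target) shows the shape of the loop:

```python
import random

def inner_sgd(phi, target, steps=5, lr=0.1):
    """Inner loop: a few plain SGD steps on loss (w - target)^2,
    starting from the meta-parameters phi."""
    w = phi
    for _ in range(steps):
        w -= lr * 2 * (w - target)   # gradient of (w - target)^2 is 2*(w - target)
    return w

phi, meta_lr = 0.0, 0.5
random.seed(0)
for _ in range(200):                 # each episode: a fresh task, no meta-test set
    target = random.uniform(-1, 1)
    w = inner_sgd(phi, target)       # adapt on the task's (meta-train) data only
    phi += meta_lr * (w - phi)       # Reptile meta-update: move phi toward W
```

The meta-update uses only the parameter difference (W - phi), not a loss evaluated on held-out task data, which is why no meta-test set appears in the code.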
