stjordanis / pba

This project forked from arcelien/pba


Efficient Learning of Augmentation Policy Schedules

Home Page: http://arxiv.org/abs/1905.05393

License: Apache License 2.0

Python 96.45% Shell 3.55%


Population Based Augmentation (PBA)

Table of Contents

  1. Introduction
  2. Getting Started
  3. Reproduce Results
  4. Run PBA Search
  5. Citation

Introduction

Population Based Augmentation (PBA) is an algorithm that quickly and efficiently learns data augmentation functions for neural network training. PBA matches state-of-the-art results on CIFAR with one thousand times less compute, enabling researchers and practitioners to effectively learn new augmentation policies using a single workstation GPU.

This repository contains code for the work "Population Based Augmentation: Efficient Learning of Augmentation Policy Schedules" (http://arxiv.org/abs/1905.05393) in TensorFlow and Python 2. It includes code for training models with the reported augmentation schedules and for discovering new augmentation policy schedules.
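PBA builds on Population Based Training (PBT). A population of models trains in parallel, and at regular intervals the poorly performing trials copy the weights and augmentation hyperparameters of stronger trials (exploit) and then perturb those hyperparameters (explore). Repeating this over training is what yields a schedule rather than a fixed policy. The following is a minimal sketch of one exploit/explore step on a toy population, not the repository's implementation. The 25% truncation fraction, the 0.2 resampling probability, the 0-3 step shift and the integer 0-10 hyperparameter range are assumptions chosen for illustration.

```python
from __future__ import print_function

import copy
import random

# Illustrative settings (assumptions, not values read from the PBA code).
TRUNCATION_FRACTION = 0.25     # bottom quarter copies from the top quarter
RESAMPLE_PROB = 0.2            # chance of resampling a hyperparameter outright
MIN_VALUE, MAX_VALUE = 0, 10   # discrete range for each hyperparameter


def explore(hparams, rng):
    """Return a perturbed copy of a dict of integer hyperparameters."""
    new = {}
    for name, value in hparams.items():
        if rng.random() < RESAMPLE_PROB:
            # Resample uniformly from the whole range.
            new[name] = rng.randint(MIN_VALUE, MAX_VALUE)
        else:
            # Shift by 0-3 steps in a random direction, clipped to the range.
            shift = rng.randint(0, 3) * rng.choice([-1, 1])
            new[name] = min(MAX_VALUE, max(MIN_VALUE, value + shift))
    return new


def exploit_and_explore(population, rng):
    """Apply one PBT step in place to trials with score/weights/hparams keys."""
    ranked = sorted(population, key=lambda t: t["score"], reverse=True)
    cutoff = max(1, int(len(ranked) * TRUNCATION_FRACTION))
    top, bottom = ranked[:cutoff], ranked[-cutoff:]
    for trial in bottom:
        donor = rng.choice(top)
        # Exploit: clone the donor's model state and hyperparameters.
        trial["weights"] = copy.deepcopy(donor["weights"])
        # Explore: perturb the cloned hyperparameters.
        trial["hparams"] = explore(donor["hparams"], rng)


if __name__ == "__main__":
    rng = random.Random(0)
    population = [
        {"score": s, "weights": [s], "hparams": {"rotate_prob": 5, "rotate_mag": 5}}
        for s in (0.9, 0.7, 0.5, 0.1)
    ]
    exploit_and_explore(population, rng)
    for trial in population:
        print(trial["score"], trial["weights"], trial["hparams"])
```

The worst trial ends up with the best trial's weights and a nearby hyperparameter setting. Its score is refreshed only when it is evaluated again after further training.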

See below for a visualization of our augmentation strategy.

Getting Started

Code currently only supports Python 2.

Install requirements

pip install -r requirements.txt

Download CIFAR-10/CIFAR-100 datasets

bash datasets/cifar10.sh
bash datasets/cifar100.sh

Reproduce Results

Dataset            Model                      Test Error (%)
CIFAR-10           Wide-ResNet-28-10          2.58
CIFAR-10           Shake-Shake (26 2x32d)     2.54
CIFAR-10           Shake-Shake (26 2x96d)     2.03
CIFAR-10           Shake-Shake (26 2x112d)    2.03
CIFAR-10           PyramidNet+ShakeDrop       1.46
Reduced CIFAR-10   Wide-ResNet-28-10          12.82
Reduced CIFAR-10   Shake-Shake (26 2x96d)     10.64
CIFAR-100          Wide-ResNet-28-10          16.73
CIFAR-100          Shake-Shake (26 2x96d)     15.31
CIFAR-100          PyramidNet+ShakeDrop       10.94
SVHN               Wide-ResNet-28-10          1.18
SVHN               Shake-Shake (26 2x96d)     1.13
Reduced SVHN       Wide-ResNet-28-10          7.83
Reduced SVHN       Shake-Shake (26 2x96d)     6.46

Scripts to reproduce results are located in scripts/table_*.sh. Each script requires one argument, the model name. The available options are those reported for each dataset in Tables 1-4 of the paper, chosen from: wrn_28_10, ss_32, ss_96, ss_112, pyramid_net. Note that the Reduced SVHN example below passes rsvhn_ss_96, a dataset-prefixed form of ss_96. Hyperparameters are set inside each script file.
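Every reproduction script takes a single model-name argument. A small Python wrapper can therefore check the name before launching a multi-day run. The sketch below is hypothetical: `build_command` and `run` are not part of the repository. They only assemble the same `bash scripts/...` invocation shown in the examples and, when `dry_run` is off, execute it with `subprocess`. The check for prefixed names such as rsvhn_ss_96 is deliberately loose and only tests the suffix.

```python
import subprocess

# Model names listed in this README for the scripts/table_*.sh scripts.
MODEL_CHOICES = ("wrn_28_10", "ss_32", "ss_96", "ss_112", "pyramid_net")


def build_command(script, model):
    """Return the argv for one reproduction script (hypothetical helper).

    Accepts a listed model name, or a dataset-prefixed form of one such as
    "rsvhn_ss_96" from the Reduced SVHN example.
    """
    if not any(model == m or model.endswith("_" + m) for m in MODEL_CHOICES):
        raise ValueError("unknown model name: %r" % (model,))
    return ["bash", "scripts/" + script, model]


def run(script, model, dry_run=True):
    """Print the command (dry run) or execute it and return its exit code."""
    cmd = build_command(script, model)
    if dry_run:
        print(" ".join(cmd))
        return 0
    return subprocess.call(cmd)


if __name__ == "__main__":
    run("table_1_cifar10.sh", "wrn_28_10")  # bash scripts/table_1_cifar10.sh wrn_28_10
    run("table_4_svhn.sh", "rsvhn_ss_96")   # bash scripts/table_4_svhn.sh rsvhn_ss_96
```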

For example, to reproduce CIFAR-10 results on Wide-ResNet-28-10:

bash scripts/table_1_cifar10.sh wrn_28_10

To reproduce Reduced SVHN results on Shake-Shake (26 2x96d):

bash scripts/table_4_svhn.sh rsvhn_ss_96

A good place to start is Reduced SVHN on Wide-ResNet-28-10, which can complete in under 10 minutes on a Titan XP GPU and reaches 91%+ test accuracy.
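Test accuracy is 100 minus test error. The 91%+ figure is therefore in line with the final Reduced SVHN errors in the results table. The snippet only converts values copied from that table.

```python
# Final test errors (%) for Reduced SVHN, copied from the results table.
REDUCED_SVHN_ERROR = {
    "Wide-ResNet-28-10": 7.83,
    "Shake-Shake (26 2x96d)": 6.46,
}


def accuracy_from_error(error_pct):
    """Convert a test error percentage into a test accuracy percentage."""
    return round(100.0 - error_pct, 2)


for model, err in sorted(REDUCED_SVHN_ERROR.items()):
    print("%s: %.2f%% accuracy" % (model, accuracy_from_error(err)))
# Shake-Shake (26 2x96d): 93.54% accuracy
# Wide-ResNet-28-10: 92.17% accuracy
```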

Running the larger models for 1800 epochs may take multiple days of training. For example, CIFAR-10 PyramidNet+ShakeDrop takes around 9 days on a Tesla V100 GPU.
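As a rough budgeting aid, 9 days for 1800 epochs works out to about 7.2 minutes per epoch, assuming that run covers the full 1800 epochs and that epochs take roughly equal time. The helper below only performs this division.

```python
def minutes_per_epoch(total_days, epochs):
    """Average wall-clock minutes per epoch for a run of the given length."""
    return total_days * 24 * 60 / float(epochs)


# CIFAR-10 PyramidNet+ShakeDrop: about 9 days for 1800 epochs on a Tesla V100.
print("%.1f minutes per epoch" % minutes_per_epoch(9, 1800))  # 7.2 minutes per epoch
```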

Run PBA Search

Run PBA search on Wide-ResNet-40-2 with scripts/search.sh. One argument, the dataset name, is required; the choices are rsvhn or rcifar10.

The script requests a fractional GPU for each trial, so multiple trials run concurrently on the same GPU. The Reduced SVHN search takes around an hour on a Titan XP GPU, and the Reduced CIFAR-10 search takes around 5 hours.

CUDA_VISIBLE_DEVICES=0 bash scripts/search.sh rsvhn
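A partial (fractional) GPU size works because Ray Tune accepts a fractional GPU value in a trial's resource request. Trials then share a device as long as their requests sum to at most one GPU. Ignoring CPU and memory limits, the number of trials that can run at once is the number of GPUs times floor(1 / fraction). The helper below only does that arithmetic. The fractions in it are illustrative, because the value used in scripts/search.sh is not given here.

```python
import math


def concurrent_trials(num_gpus, gpu_per_trial):
    """Upper bound on simultaneous trials for a fractional per-trial GPU request.

    Each GPU holds floor(1 / gpu_per_trial) trials; CPU and memory limits,
    which can lower this further, are ignored.
    """
    if not 0 < gpu_per_trial <= 1:
        raise ValueError("gpu_per_trial must be in (0, 1]")
    # Small epsilon guards against float error for fractions like 1/3.
    per_gpu = int(math.floor(1.0 / gpu_per_trial + 1e-9))
    return num_gpus * per_gpu


# Illustrative values only; check scripts/search.sh for the real setting.
print(concurrent_trials(1, 0.125))  # 8
print(concurrent_trials(1, 0.25))   # 4
```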

The schedules discovered during search can be retrieved from the Ray result directory, and the log files can be converted into policy schedules with the parse_log() function in pba/utils.py. For example, the policy schedule learned on Reduced CIFAR-10 over 200 epochs is split into probability and magnitude hyperparameter values (the two values for each augmentation operation are merged) and visualized below:

[Figure panels: Probability Hyperparameters over Time; Magnitude Hyperparameters over Time]
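To make the probability/magnitude representation concrete, the sketch below looks up the policy in effect at a given epoch and samples which operations to apply to a single image. Several parts are assumptions for this sketch. The (start_epoch, policy) schedule format is not the output format of parse_log(). The per-image count of 0, 1 or 2 operations with probabilities 0.2, 0.3 and 0.5 is assumed. The operation names, probabilities and magnitudes are hypothetical. The actual image transformations are left out.

```python
from __future__ import print_function

import bisect
import random

# Hypothetical policy at one point of a schedule: (operation, probability, magnitude).
POLICY = [
    ("Rotate", 0.5, 3),
    ("Posterize", 0.2, 6),
    ("Cutout", 0.8, 4),
    ("Equalize", 0.1, 0),
]


def policy_at(schedule, epoch):
    """Return the policy in effect at `epoch`.

    `schedule` is a list of (start_epoch, policy) pairs sorted by start_epoch;
    this format is an assumption, not the output format of parse_log().
    """
    starts = [start for start, _ in schedule]
    idx = bisect.bisect_right(starts, epoch) - 1
    if idx < 0:
        raise ValueError("schedule does not cover epoch %d" % epoch)
    return schedule[idx][1]


def sample_operations(policy, rng):
    """Pick the (operation, magnitude) pairs to apply to a single image.

    Assumes the number of operations is drawn from {0, 1, 2} with
    probabilities 0.2, 0.3 and 0.5; operations are then visited in random
    order and each fires with its own probability until that number is used up.
    """
    ops = list(policy)
    rng.shuffle(ops)
    r = rng.random()
    count = 0 if r < 0.2 else (1 if r < 0.5 else 2)
    chosen = []
    for name, prob, magnitude in ops:
        if count == 0:
            break
        if rng.random() < prob:
            chosen.append((name, magnitude))
            count -= 1
    return chosen


if __name__ == "__main__":
    rng = random.Random(0)
    schedule = [(0, []), (100, POLICY)]  # no augmentation before epoch 100
    for epoch in (10, 150):
        print(epoch, sample_operations(policy_at(schedule, epoch), rng))
```

Because the policy changes with the epoch, the same image can receive different augmentations early and late in training. This time dependence is what separates a schedule from a fixed policy.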

Citation

If you use PBA in your research, please cite:

@inproceedings{ho2019pba,
  title     = {Population Based Augmentation: Efficient Learning of Augmentation Policy Schedules},
  author    = {Daniel Ho and
               Eric Liang and
               Ion Stoica and
               Pieter Abbeel and
               Xi Chen
  },
  booktitle = {ICML},
  year      = {2019}
}

Contributors

arcelien, cclauss
