glard / progressive-neural-architecture-search

This project forked from titu1994/progressive-neural-architecture-search

Implementation of Progressive Neural Architecture Search in Keras and Tensorflow

License: MIT License

Python 100.00%

progressive-neural-architecture-search's Introduction

Progressive Neural Architecture Search with Encoder RNN

Basic implementation of Encoder RNN from Progressive Neural Architecture Search.

  • Uses Keras to define and train the child (generated) networks, which are found via sequential model-based optimization in TensorFlow and ranked by the Encoder RNN.
  • Define a state space with StateSpace, a manager that maintains input states and handles communication between the Encoder RNN and the user.
  • Encoder manages the training and evaluation of the Encoder RNN.
  • NetworkManager handles the training and reward computation of the child Keras models.
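
To make the idea of a state space concrete, here is a minimal sketch that enumerates every one-block cell. It is independent of this project's StateSpace class. It assumes each block is encoded as (input_1, op_1, input_2, op_2), that the first block can read from the two preceding cells (labelled -1 and -2 here), and that the operator set is the eight operators listed in the paper. The names OPERATORS and enumerate_first_blocks are hypothetical.

```python
from itertools import product

# The eight operators listed in the PNAS paper (the string labels are illustrative).
OPERATORS = [
    "3x3 dconv", "5x5 dconv", "7x7 dconv", "1x7-7x1 conv",
    "identity", "3x3 avgpool", "3x3 maxpool", "3x3 dilated conv",
]


def enumerate_first_blocks(inputs=(-1, -2), operators=OPERATORS):
    """Return every (input_1, op_1, input_2, op_2) tuple for a one-block cell."""
    return list(product(inputs, operators, inputs, operators))


blocks = enumerate_first_blocks()
print(len(blocks))  # 2 * 8 * 2 * 8 ordered encodings
```

Ordered tuples count symmetric blocks twice, so the number of structurally distinct blocks is smaller than this count.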

Usage

At a high level (for full training details, please see train.py):

# construct a state space (the default operators are from the paper)
state_space = StateSpace(B,  # B = number of blocks in each cell
                         operators=None,  # custom operators, or None to use the default ones from the paper
                         input_lookback_depth=0,  # limit number of combined inputs from previous cell
                         input_lookforward_depth=0,  # limit number of combined inputs in same cell
                         )

# create the managers
controller = Encoder(tf_session, state_space, B, K)  # K = number of children networks to train after initial step
manager = NetworkManager(dataset, epochs=max_epochs, batchsize=batchsize)

# run `B` trials
for trial in range(B):
    actions = controller.get_actions(K)  # get all the children models to train in this trial

    rewards = []
    for child in actions:
        rewards.append(manager.get_reward(child))  # train the child and store its reward

    controller.train(rewards)  # train the Encoder RNN with a surrogate loss function
    controller.update()  # build the next set of children to train in the next trial, and sort them
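
The loop above follows the progressive search pattern: train every one-block cell, then repeatedly expand the surviving cells by one block, rank the expansions with a surrogate, and train only the top K. The sketch below is a toy, self-contained version of that pattern and does not use this project's classes. The surrogate here is a simple per-block average of observed rewards standing in for the Encoder RNN, and `evaluate` stands in for training a child network; both, along with all function names, are assumptions made for illustration.

```python
def expand(parents, block_choices):
    """Grow every parent cell by one block."""
    return [parent + (block,) for parent in parents for block in block_choices]


def progressive_search(block_choices, B, K, evaluate):
    """Toy progressive search; returns the best (cell, reward) pair seen."""
    block_rewards = {}  # block -> rewards of the trained cells that contain it

    def predict(cell):
        # Stand-in surrogate: average, over the cell's blocks, of each block's
        # mean observed reward. Blocks never seen before score 0.0.
        means = []
        for block in cell:
            seen = block_rewards.get(block, [])
            means.append(sum(seen) / len(seen) if seen else 0.0)
        return sum(means) / len(means)

    candidates = [(block,) for block in block_choices]  # all one-block cells
    best = None
    for size in range(1, B + 1):
        if size > 1:
            # Rank the expanded cells with the surrogate and keep the top K.
            candidates = sorted(candidates, key=predict, reverse=True)[:K]
        for cell in candidates:
            reward = evaluate(cell)  # "train" the child and measure its reward
            for block in cell:
                block_rewards.setdefault(block, []).append(reward)
            if best is None or reward > best[1]:
                best = (cell, reward)
        if size < B:
            candidates = expand(candidates, block_choices)
    return best


# Toy reward: each block has a hidden integer value and a cell's reward is the sum.
values = {"a": 1, "b": 5, "c": 9}
best_cell, best_reward = progressive_search(
    ["a", "b", "c"], B=3, K=2, evaluate=lambda cell: sum(values[b] for b in cell)
)
print(best_cell, best_reward)
```

With three block choices, B=3 and K=2, only 3 + 2 + 2 = 7 cells are evaluated instead of all 39 cells of size one to three, which is the point of ranking with a surrogate before training.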

Implementation details

This is a very limited project.

  • It is not a faithful re-implementation of the original paper. Several small details are not incorporated (such as bias initialization, actually using the Hc-2 - Hcb-1 values, etc.).
  • It doesn't support skip connections via 'anchor points' etc. (though it may not be that hard to implement them as a special state).
  • Learning rate, number of epochs to train per B_i, regularization strength, etc. are all arbitrary values (which make some sense to me).
  • Single-GPU model only. Multi-GPU training would need a lot of modifications (and I have just 1 GPU).

Result

I tried a toy CNN model with 2 CNN cells and a custom search space, trained for just 5 epochs on CIFAR-10.

The top 5 models can be listed by running the rank_architectures.py script, which parses train_history.csv.
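
As a rough illustration of this kind of ranking, the sketch below reads a CSV history and returns the highest-reward rows. It is not the code in rank_architectures.py. It assumes, without confirmation from the source, that each row stores the reward in the first column followed by the architecture states. The function name top_architectures and the sample rows are made up.

```python
import csv
import io


def top_architectures(csv_file, n=5):
    """Return the n (reward, states) rows with the highest reward, best first."""
    rows = []
    for row in csv.reader(csv_file):
        if not row:
            continue  # skip blank lines
        # Assumed layout: reward first, then the states describing the architecture.
        rows.append((float(row[0]), row[1:]))
    rows.sort(key=lambda item: item[0], reverse=True)
    return rows[:n]


# In-memory stand-in for train_history.csv, with made-up values.
sample = io.StringIO(
    "0.61,0,3x3 conv,0,3x3 maxpool\n"
    "0.72,0,5x5 conv,0,3x3 conv\n"
    "0.55,0,identity,0,3x3 conv\n"
)
for reward, states in top_architectures(sample, n=2):
    print(f"{reward:.2f}  {states}")
```

For a real history file, the same function could be called with an open file handle, for example `open("train_history.csv", newline="")`, provided the file matches the assumed layout.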

Requirements

  • Keras >= 2.1.2
  • Tensorflow-gpu >= 1.2

Acknowledgements

Code somewhat inspired by wallarm/nascell-automl

progressive-neural-architecture-search's People

Contributors

titu1994
