
drnas's Introduction

DrNAS

About

Code accompanying the paper
ICLR 2021: DrNAS: Dirichlet Neural Architecture Search
Xiangning Chen*, Ruochen Wang*, Minhao Cheng*, Xiaocheng Tang, Cho-Jui Hsieh

This code is based on the implementation of NAS-Bench-201 and PC-DARTS.

This paper proposes a novel differentiable architecture search method by formulating the search as a distribution learning problem. We treat the continuously relaxed architecture mixing weights as random variables, modeled by a Dirichlet distribution. With recently developed pathwise derivatives, the Dirichlet parameters can be easily optimized with a gradient-based optimizer in an end-to-end manner. This formulation improves generalization and induces stochasticity that naturally encourages exploration of the search space. Furthermore, to alleviate the large memory consumption of differentiable NAS, we propose a simple yet effective progressive learning scheme that enables searching directly on large-scale tasks, eliminating the gap between the search and evaluation phases. Extensive experiments demonstrate the effectiveness of our method: we obtain a test error of 2.46% on CIFAR-10 and 23.7% on ImageNet under the mobile setting. On NAS-Bench-201, we also achieve state-of-the-art results on all three datasets and provide insights for the effective design of neural architecture search algorithms.
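As a rough illustration of the distribution-learning formulation, the mixing weights of one edge can be drawn from a Dirichlet whose concentration is learned by backpropagation through the pathwise sampler. This is a minimal sketch, not the repository's actual code; the operation set and parameter names are hypothetical:

import torch
import torch.nn as nn
from torch.distributions import Dirichlet

# Hypothetical candidate operations on one edge of the supernet.
ops = nn.ModuleList([nn.Identity(), nn.Tanh(), nn.ReLU()])

# Learnable Dirichlet concentration; exp() keeps it positive.
log_beta = nn.Parameter(torch.zeros(len(ops)))

x = torch.randn(4, 8)
# rsample() is a pathwise (reparameterized) sample, so gradients of the
# loss flow back into the concentration parameters end to end.
weights = Dirichlet(log_beta.exp()).rsample()
mixed = sum(w * op(x) for w, op in zip(weights, ops))
loss = mixed.pow(2).mean()
loss.backward()  # log_beta.grad is populated via the pathwise derivative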

Results

On NAS-Bench-201

The table below shows test accuracy on the NAS-Bench-201 space. We achieve state-of-the-art results on all three datasets. On CIFAR-100, DrNAS even reaches the global optimum with zero variance!

Method         CIFAR-10 (test)  CIFAR-100 (test)  ImageNet-16-120 (test)
ENAS           54.30 ± 0.00     10.62 ± 0.27      16.32 ± 0.00
DARTS          54.30 ± 0.00     38.97 ± 0.00      18.41 ± 0.00
SNAS           92.77 ± 0.83     69.34 ± 1.98      43.16 ± 2.64
PC-DARTS       93.41 ± 0.30     67.48 ± 0.89      41.31 ± 0.22
DrNAS (ours)   94.36 ± 0.00     73.51 ± 0.00      46.34 ± 0.00
optimal        94.37            73.51             47.31

For every search process, we sample 100 architectures from the current Dirichlet distribution and plot their accuracy range along with the architecture selected by the Dirichlet mean (solid line). The figure below shows that the accuracy range of the sampled architectures starts wide but narrows gradually during the search phase. This indicates that DrNAS encourages exploration in the early stages and then gradually reduces it toward the end, as the algorithm grows more confident in its current choice. Moreover, the performance of the selected architecture consistently matches the best performance of the sampled architectures, indicating the effectiveness of DrNAS.
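A sketch of how such samples and the reported architecture could be produced. The shapes and names here are assumptions for illustration (e.g. the 6-edge, 5-operation NAS-Bench-201 layout), not the repository's plotting code:

import torch
from torch.distributions import Dirichlet

# Hypothetical: 6 edges, 5 candidate ops per edge (NAS-Bench-201 cell).
beta = torch.rand(6, 5) + 0.5            # current concentration parameters

# Sample 100 architectures: the argmax op per edge for each draw.
samples = Dirichlet(beta).sample((100,))  # shape (100, 6, 5)
sampled_archs = samples.argmax(dim=-1)    # shape (100, 6), op indices

# The reported architecture uses the Dirichlet mean instead of a sample.
mean_weights = beta / beta.sum(dim=-1, keepdim=True)
selected_arch = mean_weights.argmax(dim=-1)  # shape (6,), op indices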

On DARTS Space (CIFAR-10)

DrNAS achieves an average test error of 2.46%, placing it among the best recent NAS results.

Method         Test Error (%)  Params (M)  Search Cost (GPU days)
ENAS           2.89            4.6         0.5
DARTS          2.76 ± 0.09     3.3         1.0
SNAS           2.85 ± 0.02     2.8         1.5
PC-DARTS       2.57 ± 0.07     3.6         0.1
DrNAS (ours)   2.46 ± 0.03     4.1         0.6

On DARTS Space (ImageNet)

DrNAS can perform a direct search on ImageNet and achieves a top-1 test error below 24.0%!

Method         Top-1 Error (%)  Params (M)  Search Cost (GPU days)
DARTS*         26.7             4.7         1.0
SNAS*          27.3             4.3         1.5
PC-DARTS       24.2             5.3         3.8
DSNAS          25.7             -           -
DrNAS (ours)   23.7             5.7         4.6

* Not searched directly on ImageNet.

Usage

Architecture Search

Search on NAS-Bench-201 Space (three datasets to choose from):

  • Data preparation: Please first download the NAS-Bench-201 benchmark file and prepare the API by following this repository. (A minimal query sketch follows this list.)

  • cd 201-space && python train_search.py

  • With progressive pruning: cd 201-space && python train_search_progressive.py
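For reference, querying the benchmark typically looks like the following. This is a sketch assuming the nas_201_api package from the NAS-Bench-201 repository and its standard benchmark file name; check that repository for the exact version and usage:

from nas_201_api import NASBench201API as API

# Path to the downloaded benchmark file (name may differ by release).
api = API('NAS-Bench-201-v1_1-096897.pth')

# Look up the ground-truth performance of an architecture string.
index = api.query_index_by_arch(
    '|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|'
    '+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|')
info = api.query_meta_info_by_index(index)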

Search on DARTS Space:

  • Data preparation: For a direct search on ImageNet, we follow PC-DARTS and sample 10% and 2.5% of the images of each class as the training and validation sets, respectively. (A subsampling sketch follows this list.)

  • CIFAR-10: cd DARTS-space && python train_search.py

  • ImageNet: cd DARTS-space && python train_search_imagenet.py
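A minimal sketch of the per-class subsampling described above, assuming a standard ImageFolder directory layout. PC-DARTS ships its own sampling script, so this is only illustrative:

import os
import random
import shutil

def sample_subset(src, dst, fraction, seed=0):
    """Copy `fraction` of the images of each class from src/ to dst/."""
    random.seed(seed)
    for cls in os.listdir(src):
        images = os.listdir(os.path.join(src, cls))
        keep = random.sample(images, max(1, int(len(images) * fraction)))
        os.makedirs(os.path.join(dst, cls), exist_ok=True)
        for name in keep:
            shutil.copy(os.path.join(src, cls, name),
                        os.path.join(dst, cls, name))

# 10% of each class for search-time training, 2.5% for validation:
# sample_subset('imagenet/train', 'imagenet_search/train', 0.10)
# sample_subset('imagenet/train', 'imagenet_search/val', 0.025)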

Architecture Evaluation

  • CIFAR-10: cd DARTS-space && python train.py --cutout --auxiliary (cutout augmentation plus an auxiliary classifier head; a sketch of the auxiliary loss follows this list)

  • ImageNet: cd DARTS-space && python train_imagenet.py --auxiliary
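During evaluation training, --auxiliary enables an auxiliary classifier head partway through the network whose loss is added to the main loss with a small weight. Below is a sketch of the usual DARTS-style combination; the weight 0.4 and the function names are illustrative assumptions, not necessarily this repo's exact code:

import torch.nn as nn

criterion = nn.CrossEntropyLoss()
aux_weight = 0.4  # typical value in DARTS-style evaluation scripts

def training_loss(model, images, targets):
    # DARTS-style evaluation networks return (logits, logits_aux).
    logits, logits_aux = model(images)
    loss = criterion(logits, targets)
    # The auxiliary head regularizes training; it is discarded at test time.
    return loss + aux_weight * criterion(logits_aux, targets)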

Reference

If you find this code useful in your research please cite

@inproceedings{chen2021drnas,
    title={Dr{NAS}: Dirichlet Neural Architecture Search},
    author={Xiangning Chen and Ruochen Wang and Minhao Cheng and Xiaocheng Tang and Cho-Jui Hsieh},
    booktitle={International Conference on Learning Representations},
    year={2021},
    url={https://openreview.net/forum?id=9FWas6YbmB3}
}


drnas's Issues

raw_id in configure_optimizer

Hi!

Thank you for this code! While trying it out, I ran into an error when running configure_optimizer(optimizer_old, optimizer_new):

state_old = optimizer_old.state_dict()['state'][p.raw_id]
KeyError: 140435163178672

In general, I do not see how raw_id can be mapped to a key of that dictionary.

Thanks for any pointers/help!

Error in configure_optimizer

Hey, I'm trying to run 201-space/train_search_progressive.py, but there seems to be a problem in the configure_optimizer function; it raises the following error. Can you help me figure this out?

Traceback (most recent call last):
File "/content/drive/MyDrive/DrNAS-master/201-space/train_search_progressive.py", line 336, in
main()
File "/content/drive/MyDrive/DrNAS-master/201-space/train_search_progressive.py", line 251, in main
weight_decay=args.weight_decay))
File "/content/drive/MyDrive/DrNAS-master/net2wider.py", line 84, in configure_optimizer
state_old = optimizer_old.state_dict()['state'][p.raw_id]
KeyError: 140076590179920

Reading this link, I believe the problem might be due to the PyTorch version I'm using (1.10). I don't know which version your code requires.

Key Error

After the 25th epoch, I receive a KeyError:

Traceback (most recent call last):
  File "train_search.py", line 238, in <module>
    main()
  File "train_search.py", line 151, in main
    optimizer = configure_optimizer(optimizer, torch.optim.SGD(
  File "../net2wider.py", line 84, in configure_optimizer
    state_old = optimizer_old.state_dict()['state'][p.raw_id]
KeyError: 139737672243200

Any idea how to solve this? Is that a version conflict issue?
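The three reports above hit the same line (net2wider.py, line 84). It looks like a PyTorch version mismatch: older releases appear to have keyed optimizer.state_dict()['state'] by id(param), while newer releases use integer indices, so a stored id() no longer matches. One possible, untested workaround, assuming raw_id holds the id() of the original parameter, is to go through the live optimizer.state mapping, which is keyed by the parameter objects themselves rather than by serialized keys:

# Hypothetical patch for net2wider.configure_optimizer, assuming each new
# parameter's raw_id was set to id(old_param) before the network was widened.
def old_state_for(optimizer_old, p):
    # optimizer.state is keyed by parameter objects, so it does not depend
    # on how state_dict() serializes its keys across PyTorch versions.
    by_id = {id(old_p): old_p
             for group in optimizer_old.param_groups
             for old_p in group['params']}
    old_p = by_id.get(p.raw_id)
    return optimizer_old.state.get(old_p) if old_p is not None else None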

Reproducing results

Hi,

Thank you for the great work and for open sourcing your code!

I have tried reproducing the results for NB201 and the DARTS search space.

On the NB201 search space using cd 201-space && python train_search.py, I can reproduce the results from your paper as follows:

  • CIFAR-10: 94.36
  • ImageNet16-120: 46.34

For CIFAR-100 on the NAS-Bench-201 search space, however, I could only obtain 70.88, contrary to the reported 73.51. On the DARTS search space, when searching on CIFAR-10 with cd DARTS-space && python train_search.py, I am not able to obtain the same genotype as the one in the repo. With the new genotype I get an error of 2.89 ± 0.091.

Second, when evaluating on the DARTS search space with cd DARTS-space && python train.py --cutout --auxiliary and the DrNAS_cifar10 genotype from the repo, I obtain an error of 2.67 ± 0.090, which is higher than the 2.46 ± 0.03 reported in the paper.

Any help in replicating the results on the DARTS search space would be greatly appreciated! Thanks!

Mobile Setting for ImageNet

Thanks for the great work!

  1. Is train_search_imagenet.py run under the mobile setting?
  2. Could you explain the details of the mobile setting?
    1. I only found "input image size is 224×224" and "the number of multiply-add operations in the model is restricted to be less than 600M". Is that all?
    2. Is this setting applied during search or during training from scratch, and how is it applied? I could not find any FLOPs constraint in either phase. (See the sketch after this issue.)

Thank you!
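For context, the mobile-setting constraint (≤600M multiply-adds at 224×224 input) is typically verified on the final derived network rather than enforced during search, which matches the reporter's observation that no FLOPs constraint appears in the code. A sketch of such a check using the third-party thop package; this is an assumption for illustration, and the authors may count operations differently:

import torch
from thop import profile  # pip install thop

def check_mobile_setting(model):
    # Mobile setting: 224x224 input, fewer than 600M multiply-adds.
    dummy = torch.randn(1, 3, 224, 224)
    macs, params = profile(model, inputs=(dummy,))
    print(f'MACs: {macs / 1e6:.1f}M, params: {params / 1e6:.1f}M')
    return macs < 600e6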

Reproducing nb201 results

Hi,

Thank you for the great work!

I was trying to reproduce the NB201 results. While I was able to reproduce the results for CIFAR-10 and ImageNet16-120, I could not do the same for CIFAR-100. I am running the train_search.py file, changing only the dataset argument.

The results I am getting for 100 epochs using the default hyperparameters provided in the script are:

  • NB201 test accuracy CIFAR-10: 94.36%
  • NB201 test accuracy CIFAR-100: 70.47%
  • NB201 test accuracy ImageNet16-120: 46.34%

Any help would be appreciated!

Version issue and replication

Hello!

I am trying to replicate the DrNAS results for the NAS-Bench-201 space in Table 4 of your paper, specifically the CIFAR-10 test accuracy (94.36%).

  1. Did you use progressive training (train_search_progressive.py) to get the results in Table 4? You have also provided instructions for the evaluation phase in the DARTS space, but which script do you use for the evaluation phase on NAS-Bench-201?
  2. Could you share which versions of PyTorch, the CUDA toolkit, torchvision, TensorBoard, etc. you used? It would be great if you had a snapshot of the environment you used to get the results.

Thanks a lot for your help!

Reproducing the results

Hi author,

I am having difficulty reproducing the results on CIFAR-10. The paper claims a test error of 2.46 ± 0.03 with 600 epochs, but when I evaluate the provided DrNAS_cifar10 genotype, I only get an accuracy of 94.86 with 600 epochs and 97.44 with 1200 epochs. The default parameters in the code seem to match those claimed in the paper; did I miss something?

I saw on the PC-DARTS GitHub page that there is a lot of randomness in training on CIFAR-10, so the results are not stable. Is that also the case here?

Thank you for your response.
