
jaminfong / darts-multi_gpu


This project forked from quark0/darts


Differentiable architecture search for convolutional and recurrent networks

Home Page: https://arxiv.org/abs/1806.09055

License: Apache License 2.0

Python 100.00%

darts-multi_gpu's Introduction

Differentiable Architecture Search

This fork adds a multi-GPU implementation and modifies GPU utilization.

The --unrolled version is not available yet.

Code accompanying the paper

DARTS: Differentiable Architecture Search
Hanxiao Liu, Karen Simonyan, Yiming Yang.
arXiv:1806.09055.


The algorithm is based on continuous relaxation and gradient descent in the architecture space. It is able to efficiently design high-performance convolutional architectures for image classification (on CIFAR-10 and ImageNet) and recurrent architectures for language modeling (on Penn Treebank and WikiText-2). Only a single GPU is required.
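As a rough illustration of what "continuous relaxation" means here, the sketch below mixes the outputs of several candidate operations on one edge using softmax weights over architecture parameters, then picks the strongest operation once search is done. It is a toy under stated assumptions: the scalar operations and the names `CANDIDATE_OPS`, `mixed_op`, and `discretize` are hypothetical and are not taken from this repository.

```python
import math

def softmax(alphas):
    """Numerically stable softmax over a list of architecture weights."""
    peak = max(alphas)
    exps = [math.exp(a - peak) for a in alphas]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical scalar stand-ins for the candidate operations on one edge.
CANDIDATE_OPS = {
    "identity": lambda x: x,
    "double": lambda x: 2.0 * x,
    "zero": lambda x: 0.0,
}

def mixed_op(x, alphas):
    """Relaxed edge: softmax-weighted sum of every candidate op's output."""
    weights = softmax(alphas)
    return sum(w * op(x) for w, op in zip(weights, CANDIDATE_OPS.values()))

def discretize(alphas):
    """After search, keep only the op with the largest architecture weight."""
    names = list(CANDIDATE_OPS)
    return names[max(range(len(alphas)), key=lambda i: alphas[i])]
```

Because the mixed output is differentiable in `alphas`, the architecture weights can be updated by gradient descent alongside the network weights. In the real code the operations are network layers acting on tensors rather than scalar functions.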

Requirements

Python >= 3.5.5, PyTorch >= 0.4

NOTE: It's best to use PyTorch 1.0

Datasets

Instructions for acquiring PTB and WT2 can be found here. While CIFAR-10 can be automatically downloaded by torchvision, ImageNet needs to be manually downloaded (preferably to an SSD) following the instructions here.

Pretrained models

The easiest way to get started is to evaluate our pretrained DARTS models.

CIFAR-10 (cifar10_model.pt)

cd cnn && python test.py --auxiliary --model_path cifar10_model.pt
  • Expected result: 2.63% test error rate with 3.3M model params.

PTB (ptb_model.pt)

cd rnn && python test.py --model_path ptb_model.pt
  • Expected result: 55.68 test perplexity with 23M model params.

ImageNet (imagenet_model.pt)

cd cnn && python test_imagenet.py --auxiliary --model_path imagenet_model.pt
  • Expected result: 26.7% top-1 error and 8.7% top-5 error with 4.7M model params.

Architecture search (using small proxy models)

To carry out architecture search using the first-order approximation, run

cd cnn && python train_search.py      # for conv cells on CIFAR-10

Note the validation performance in this step does not indicate the final performance of the architecture. One must train the obtained genotype/architecture from scratch using full-sized models, as described in the next section.

Also be aware that different runs may end up in different local minima. To get the best result, it is crucial to repeat the search process with different seeds and select the best cell(s) based on validation performance (obtained by training the derived cell from scratch for a small number of epochs). Please refer to fig. 3 and sect. 3.2 in our arXiv paper.
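The repeat-and-select procedure can be sketched as a small driver script. This is only an outline under assumptions: it assumes train_search.py accepts a `--seed` flag (as the upstream quark0/darts script does, but check this fork), and the helper names `build_search_commands` and `pick_best` are hypothetical. The validation errors passed to `pick_best` would come from your own short from-scratch training of each derived cell.

```python
import shlex

def build_search_commands(seeds, script="train_search.py"):
    """Return one argv list per seed; the `--seed` flag is an assumption."""
    return [["python", script, "--seed", str(seed)] for seed in seeds]

def pick_best(val_errors):
    """Return the seed whose derived cell had the lowest validation error.

    `val_errors` maps seed -> validation error measured after training the
    derived cell from scratch for a small number of epochs.
    """
    return min(val_errors, key=val_errors.get)

if __name__ == "__main__":
    # Print the commands instead of launching them, so nothing runs by accident.
    for argv in build_search_commands([0, 1, 2, 3]):
        print("cd cnn &&", " ".join(shlex.quote(a) for a in argv))
```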

Figure: Snapshots of the most likely normal conv, reduction conv, and recurrent cells over time.

Architecture evaluation (using full-sized models)

To evaluate our best cells by training from scratch, run

cd cnn && python train.py --auxiliary --cutout            # CIFAR-10
cd rnn && python train.py                                 # PTB
cd rnn && python train.py --data ../data/wikitext-2 \     # WT2
            --dropouth 0.15 --emsize 700 --nhidlast 700 --nhid 700 --wdecay 5e-7
cd cnn && python train_imagenet.py --auxiliary            # ImageNet

Customized architectures are supported through the --arch flag once specified in genotypes.py.
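For orientation, a custom entry for the convolutional search space might look like the sketch below. It assumes genotypes.py follows the upstream quark0/darts layout: a `Genotype` namedtuple with `normal`, `normal_concat`, `reduce`, and `reduce_concat` fields, where each cell lists two `(operation, input_index)` pairs per intermediate node. `MY_ARCH` and `check_cell` are hypothetical names, and the cell itself is a made-up illustration, not a searched or recommended architecture.

```python
from collections import namedtuple

# Assumed to match the tuple defined in genotypes.py (upstream DARTS layout).
Genotype = namedtuple("Genotype", "normal normal_concat reduce reduce_concat")

# Hypothetical cell for illustration only.
MY_ARCH = Genotype(
    normal=[
        ("sep_conv_3x3", 0), ("sep_conv_3x3", 1),   # intermediate node 0
        ("sep_conv_3x3", 0), ("skip_connect", 2),   # intermediate node 1
        ("dil_conv_3x3", 1), ("skip_connect", 0),   # intermediate node 2
        ("sep_conv_3x3", 3), ("skip_connect", 2),   # intermediate node 3
    ],
    normal_concat=[2, 3, 4, 5],
    reduce=[
        ("max_pool_3x3", 0), ("max_pool_3x3", 1),
        ("skip_connect", 2), ("max_pool_3x3", 1),
        ("max_pool_3x3", 0), ("skip_connect", 2),
        ("skip_connect", 2), ("avg_pool_3x3", 1),
    ],
    reduce_concat=[2, 3, 4, 5],
)

def check_cell(edges, inputs_per_node=2):
    """Check that intermediate node i only reads from states 0 .. i+1.

    States 0 and 1 are the two cell inputs; state i+2 is intermediate node i.
    """
    if len(edges) % inputs_per_node != 0:
        return False
    for position, (_, source) in enumerate(edges):
        node = position // inputs_per_node
        if not 0 <= source <= node + 1:
            return False
    return True
```

With such an entry added to genotypes.py, training would be started with `--arch MY_ARCH`.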

The CIFAR-10 result at the end of training is subject to variance due to the non-determinism of cuDNN back-prop kernels. It would be misleading to report the result of only a single run. By training our best cell from scratch, one should expect the average test error of 10 independent runs to fall in the range of 2.76 +/- 0.09% with high probability.
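A small helper for reporting several independent runs in the same "mean +/- deviation" form might look like this. It is a sketch: `summarize_runs` and `format_summary` are hypothetical names, and the use of the sample standard deviation is an assumption, since the text does not say which estimator was used.

```python
import statistics

def summarize_runs(test_errors):
    """Return (mean, sample standard deviation) of per-run test errors in %."""
    mean = statistics.mean(test_errors)
    # The sample standard deviation needs at least two runs.
    std = statistics.stdev(test_errors) if len(test_errors) > 1 else 0.0
    return mean, std

def format_summary(test_errors):
    """Format the runs as 'mean +/- std% over N runs'."""
    mean, std = summarize_runs(test_errors)
    return "{:.2f} +/- {:.2f}% over {} runs".format(mean, std, len(test_errors))
```

Pass it the final test error of each run, collected from your own training logs.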

Figure: Expected learning curves on CIFAR-10 (4 runs), ImageNet and PTB.

Visualization

The graphviz package is required to visualize the learned cells:

python visualize.py DARTS

where DARTS can be replaced by any customized architecture defined in genotypes.py.

Citation

If you use any part of this code in your research, please cite our paper:

@article{liu2018darts,
  title={DARTS: Differentiable Architecture Search},
  author={Liu, Hanxiao and Simonyan, Karen and Yang, Yiming},
  journal={arXiv preprint arXiv:1806.09055},
  year={2018}
}

darts-multi_gpu's People

Contributors

jaminfong, quark0


darts-multi_gpu's Issues

'Network' object has no attribute '_loss'

Hello! Have you completed the code?
I cloned your project and tried training with multiple GPUs, and received the following log.

pytorch 0.4

log:
Traceback (most recent call last):
  File "train_search.py", line 202, in <module>
    main()
  File "train_search.py", line 127, in main
    train_acc, train_obj = train(train_queue, valid_queue, model, architect, criterion, optimizer, lr)
  File "train_search.py", line 155, in train
    architect.step(input, target, input_search, target_search, lr, optimizer, unrolled=args.unrolled)
  File "/darts-multi_gpu/cnn/architect.py", line 36, in step
    self._backward_step_unrolled(input_train, target_train, input_valid, target_valid, eta, network_optimizer)
  File "/darts-multi_gpu/cnn/architect.py", line 48, in _backward_step_unrolled
    unrolled_model = self._compute_unrolled_model(input_train, target_train, eta, network_optimizer)
  File "/darts-multi_gpu/cnn/architect.py", line 23, in _compute_unrolled_model
    loss = self.model.module._loss(input, target)
  File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 518, in __getattr__
    type(self).__name__, name))
AttributeError: 'Network' object has no attribute '_loss'

Very slow when multi-gpu

Environment (requirements):
numpy==1.15.4
graphviz==0.8.4
torch==1.0.0
torchvision==0.2.1
tensorboard==1.13.0
tensorboardX==1.6

When I use 8 GPUs (GPU-Util about 4-9%), training is much slower than with a single GPU running the original code (GPU-Util about 30-90%). Why?

A line-by-line timing profile was attached as an image.

Multi GPU setup

Thanks for your code!

I got an error when running the original code:
Traceback (most recent call last):
  File "train_search.py", line 203, in <module>
    main()
  File "train_search.py", line 82, in main
    arch_params = list(map(id, model.module.arch_parameters()))
  File "/opt/conda/envs/darts-gpus/lib/python3.5/site-packages/torch/nn/modules/module.py", line 535, in __getattr__
    type(self).__name__, name))
AttributeError: 'Network' object has no attribute 'module'

So I changed line 64 of train_search.py from
gpus = [int(i) for i in args.gpu.split(',')]
to
gpus = [0,1,2,3] (I have 4 V100s)
and after that it ran.
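One possible reading of the traceback above is that `.module` was accessed on a model that had not been wrapped by a multi-GPU container (torch.nn.DataParallel, for example, exposes the wrapped network as `.module`). The sketch below shows a defensive pattern under that assumption. `parse_gpu_ids` and `unwrap` are hypothetical helpers, not functions from this repository.

```python
def parse_gpu_ids(spec):
    """Turn a string such as "0,1,2,3" into [0, 1, 2, 3], skipping blanks."""
    ids = [int(part) for part in spec.split(",") if part.strip()]
    if len(set(ids)) != len(ids):
        raise ValueError("duplicate GPU id in %r" % spec)
    return ids

def unwrap(model):
    """Return model.module when the model is wrapped, else the model itself."""
    return getattr(model, "module", model)
```

With these, a call such as `unwrap(model).arch_parameters()` would work whether or not the model was wrapped, and the GPU list would still come from the command line rather than being hard-coded.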

How to set multi-gpu?

Thank you for sharing. I couldn't find where you changed the code to support multiple GPUs. Have you updated your code?
