Multi-way Backpropagation for Training Compact Deep Neural Networks
Training code for Multi-way BP. Both PyTorch and Torch implementations are available.
PyTorch Implementation
Requirements
- PyTorch 1.0.0
- Python 2.7
Training Method
- Prepare data: download the training data (e.g., CIFAR-10) and put it in a directory of your choice.
- Train deep models with Multi-way BP:
  cd multiwaybp-pytorch
  python main.py
See options.py for the full list of arguments.
Some key arguments:
- dataPath: path to the dataset
- save_path: path for saving model files
- nGPU: number of GPUs to use; multi-GPU training is supported (default: 1)
- netType: network type of the baseline model
- pivotSet: positions at which auxiliary classifiers are added
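The pivotSet option controls where auxiliary classifiers are attached. As a rough sketch of how their outputs could be combined during training, the snippet below sums one softmax cross-entropy term per output head (each auxiliary classifier plus the final classifier). The equal default weights and the names `cross_entropy` and `multiway_loss` are assumptions for illustration, not taken from this repository's code. It uses no framework and sticks to syntax that runs under both Python 2.7 and Python 3.

```python
import math


def cross_entropy(logits, target):
    """Softmax cross-entropy of one sample, using log-sum-exp for stability."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


def multiway_loss(head_logits, target, weights=None):
    """Weighted sum of per-head losses (hypothetical helper).

    head_logits: one list of class scores per output head, i.e. each
    auxiliary classifier at a pivot position followed by the final classifier.
    Assumption: every head gets weight 1.0 unless weights is given.
    """
    if weights is None:
        weights = [1.0] * len(head_logits)
    if len(weights) != len(head_logits):
        raise ValueError("expected one weight per output head")
    return sum(w * cross_entropy(z, target) for w, z in zip(weights, head_logits))


if __name__ == "__main__":
    # Hypothetical scores for 3 classes from two auxiliary heads and the final head.
    heads = [[0.2, 1.0, -0.5], [0.1, 2.0, -1.0], [0.0, 3.0, -2.0]]
    print(multiway_loss(heads, target=1))
```

In an autograd framework such as PyTorch, calling backward on a single summed loss like this sends gradients from every head back through the shared layers, which is one way to read "multi-way" backpropagation. Check main.py for how the repository actually weights and combines the terms.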
Torch Implementation
Requirements
See the installation instructions for a step-by-step guide.
If you already have Torch installed, update the luarocks packages nn, cunn, and cudnn.
Training Method
- Prepare data: download the training data (e.g., CIFAR-10) and put it in a directory of your choice.
- Train deep models with Multi-way BP:
  cd multiwaybp-torch
  th train.lua
Testing Method
- Test pre-trained models:
  - CIFAR10-MwResNet-56-2
  - CIFAR10-MwResNet-56-5
  - CIFAR10-MwResNet-26-2/10
  - CIFAR100-MwResNet-56-2
  - CIFAR100-MwResNet-56-5
  - CIFAR100-MwResNet-26-2/10
To test the performance of the MwResNet models, download the pre-trained models and move them into the ./pretrained directory.
Then run test.lua, for example:
th test.lua -dataset cifar10 -model cifar10-mwresnet-26-2-wide-10
- Test intermediate models
During training, Multi-way BP simultaneously produces multiple models of different depths. Take CIFAR10-MwResNet-56-5 (together with its "auxiliary outputs" file) as an example:
| Intermediate model | Depth | #Params |
|---|---|---|
| model-15 | 15 | 0.03M |
| model-25 | 25 | 0.09M |
| model-35 | 35 | 0.18M |
| model-45 | 45 | 0.48M |
| model-56 | 56 | 0.85M |
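The table also makes the size trade-off easy to quantify. The sketch below copies its numbers into a list, expresses each intermediate model's size as a fraction of the full 56-layer model, and picks the deepest model that fits a given parameter budget. The helper names are hypothetical; since the table lists only depth and size, the selection is by parameter count alone and says nothing about accuracy.

```python
# (name, depth, #params in millions), copied from the table above.
INTERMEDIATE = [
    ("model-15", 15, 0.03),
    ("model-25", 25, 0.09),
    ("model-35", 35, 0.18),
    ("model-45", 45, 0.48),
    ("model-56", 56, 0.85),
]
FULL_PARAMS_M = INTERMEDIATE[-1][2]  # the full 56-layer model


def param_fraction(params_m):
    """Size of a model relative to the full 56-layer model."""
    return params_m / FULL_PARAMS_M


def deepest_within_budget(budget_m):
    """Name of the deepest intermediate model with at most budget_m million
    parameters, or None if none fits."""
    fitting = [row for row in INTERMEDIATE if row[2] <= budget_m]
    if not fitting:
        return None
    return max(fitting, key=lambda row: row[1])[0]


if __name__ == "__main__":
    for name, depth, params_m in INTERMEDIATE:
        print("%s: depth %d, %.2fM params, %.1f%% of model-56"
              % (name, depth, params_m, 100 * param_fraction(params_m)))
```

For example, model-15 needs roughly 3.5% of the parameters of model-56, and a budget of 0.5M parameters selects model-45.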
To test the intermediate models, run intermediate.lua:
th intermediate.lua -dataset cifar10 -model cifar10-mwresnet-56-5 -outputs cifar10-mwresnet-56-5-outputs