lmbxmu / clr-rnf

PyTorch implementation of our TNNLS paper -- Pruning Networks with Cross-Layer Ranking & k-Reciprocal Nearest Filters

Home Page: https://arxiv.org/abs/2202.07190

Topics: network-pruning, acceleration, compression

clr-rnf's Introduction

Pruning Networks with Cross-Layer Ranking & k-Reciprocal Nearest Filters

PyTorch implementation of our TNNLS paper -- Pruning Networks with Cross-Layer Ranking & k-Reciprocal Nearest Filters
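As a rough, self-contained illustration of the k-reciprocal nearest-neighbour relation named in the title (a sketch only -- the repository builds its filter graph from real convolution weights with its own distance and selection logic; `euclidean`, `knn`, `k_reciprocal`, and the toy points below are all hypothetical):

```python
# Illustrative sketch (not the repository's code): two filters are
# k-reciprocal nearest neighbours iff each lies in the other's k-NN set.
import math

def euclidean(a, b):
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

def knn(filters, i, k):
    """Indices of the k nearest filters to filter i (excluding i)."""
    order = sorted((j for j in range(len(filters)) if j != i),
                   key=lambda j: euclidean(filters[i], filters[j]))
    return set(order[:k])

def k_reciprocal(filters, i, k):
    """j is a k-reciprocal neighbour of i iff i and j are in each other's k-NN set."""
    return {j for j in knn(filters, i, k) if i in knn(filters, j, k)}

# Toy example: four 2-D "filters"; 0 and 1 are mutually closest.
toy = [(0.0, 0.0), (0.1, 0.0), (5.0, 5.0), (5.1, 5.0)]
print(k_reciprocal(toy, 0, 1))  # {1}
```

The reciprocity requirement makes the relation stricter than plain k-NN, which is what lets mutually similar filters form tight groups in the graph.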

Running code

You can run the following command to prune a model on CIFAR-10:

python cifar.py \
  --arch vgg_cifar \
  --cfg vgg16 \
  --data_path /data/cifar \
  --job_dir ./experiment/cifar/vgg_1 \
  --pretrain_model /home/pretrain/vgg16_cifar10.pt \
  --lr 0.01 \
  --lr_decay_step 50 100 \
  --weight_decay 0.005 \
  --num_epochs 150 \
  --gpus 0 \
  --pr_target 0.7 \
  --graph_gpu

You can run the following command to prune ResNets on ImageNet:

python imagenet.py \
  --dataset imagenet \
  --data_path /data/ImageNet/ \
  --pretrain_model /data/model/resnet50.pth \
  --job_dir /data/experiment/resnet50 \
  --arch resnet_imagenet \
  --cfg resnet50 \
  --lr 0.1 \
  --lr_type step \
  --num_epochs 90 \
  --train_batch_size 256 \
  --weight_decay 1e-4 \
  --gpus 0 1 2 \
  --pr_target 0.7 \
  --graph_gpu

You can run the following command to prune mobilenet_v1 on ImageNet:

python imagenet.py \
  --dataset imagenet \
  --arch mobilenet_v1 \
  --cfg mobilenet_v1 \
  --data_path /media/disk2/zyc/ImageNet2012 \
  --pretrain_model ./pretrain/checkpoints/mobilenet_v1.pth.tar \
  --job_dir ./experiment/imagenet/mobilenet_v1 \
  --lr 0.1 \
  --lr_type cos \
  --weight_decay 4e-5 \
  --num_epochs 150 \
  --gpus 0 \
  --train_batch_size 256 \
  --eval_batch_size 256 \
  --pr_target 0.58 \
  --graph_gpu

You can run the following command to prune mobilenet_v2 on ImageNet:

python imagenet.py \
  --dataset imagenet \
  --arch mobilenet_v2 \
  --cfg mobilenet_v2 \
  --data_path /media/disk2/zyc/ImageNet2012 \
  --pretrain_model ./pretrain/checkpoints/mobilenet_v2.pth.tar \
  --job_dir ./experiment/imagenet/mobilenet_v2 \
  --lr 0.1 \
  --lr_type cos \
  --weight_decay 4e-5 \
  --num_epochs 150 \
  --gpus 0 \
  --train_batch_size 256 \
  --eval_batch_size 256 \
  --pr_target 0.25 \
  --graph_gpu

You can run the following command to get the FLOPs pruning ratio under a given parameter pruning target:

python get_flops.py \
  --arch resnet_imagenet \
  --cfg resnet50 \
  --pretrain_model /media/disk2/zyc/prune_result/resnet_50/pruned_checkpoint/resnet50-19c8e357.pth \
  --job_dir ./experiment/imagenet/resnet50_flop \
  --graph_gpu \
  --pr_target 0.1
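To see why a parameter pruning target and a FLOPs pruning ratio differ, here is a back-of-envelope sketch (not `get_flops.py` itself; `conv_params`, `conv_flops`, and the two-layer toy network are illustrative assumptions): per-layer FLOPs weight the same channel counts by the output feature-map size, so pruning a layer that operates on a small map removes relatively more parameters than FLOPs.

```python
# Back-of-envelope sketch of parameter vs FLOPs pruning ratios for conv layers.
def conv_params(c_in, c_out, k):
    return c_out * c_in * k * k

def conv_flops(c_in, c_out, k, hw):
    # multiply-accumulate count for one forward pass over an hw-pixel output map
    return conv_params(c_in, c_out, k) * hw

# Two-layer toy: an early layer on a 32x32 map, a late layer on an 8x8 map.
base   = [(16, 32, 3, 32 * 32), (32, 64, 3, 8 * 8)]
# Prune half of the *late* layer's output channels; the early layer is untouched.
pruned = [(16, 32, 3, 32 * 32), (32, 32, 3, 8 * 8)]

p0 = sum(conv_params(ci, co, k) for ci, co, k, _ in base)
p1 = sum(conv_params(ci, co, k) for ci, co, k, _ in pruned)
f0 = sum(conv_flops(*layer) for layer in base)
f1 = sum(conv_flops(*layer) for layer in pruned)
print(f"param PR {1 - p1 / p0:.2f}, FLOPs PR {1 - f1 / f0:.2f}")  # param PR 0.40, FLOPs PR 0.10
```

In this toy case, a 40% parameter reduction yields only a 10% FLOPs reduction, which is why a tool that maps a parameter target to its implied FLOPs ratio is useful.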

You can run the following command to compare the loss among the graph-based, k-means, and random methods:

python cal_graph_loss.py \
  --arch vgg_cifar \
  --cfg vgg16 \
  --data_path /data/cifar \
  --job_dir ./experiment/vgg \
  --pretrain_model pretrain/vgg16_cifar10.pt \
  --gpus 0 \
  --graph_gpu

You can run the following command to test our model:

python test.py \
  --arch resnet_imagenet \
  --cfg resnet50 \
  --data_path /media/disk2/zyc/ImageNet2012 \
  --resume ./pretrain/checkpoints/model_best.pt \
  --pretrain_model /media/disk2/zyc/prune_result/resnet_50/pruned_checkpoint/resnet50-19c8e357.pth \
  --pr_target 0.44 \
  --job_dir ./experiment/imagenet/test \
  --eval_batch_size 256
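The Top1/Top5 accuracies reported in the tables below can be understood with a plain-Python sketch of top-k accuracy (illustrative only; the repository's evaluation presumably works over PyTorch tensors, and `topk_correct` and the toy batch are assumptions):

```python
# Illustrative sketch of top-k accuracy: a prediction counts as correct
# if the true label is among the k highest-scoring classes.
def topk_correct(scores, label, k):
    ranked = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)
    return int(label in ranked[:k])

# Toy batch of two samples over three classes.
batch_scores = [[0.1, 0.5, 0.4], [0.7, 0.2, 0.1]]
batch_labels = [2, 0]

top1 = sum(topk_correct(s, y, 1) for s, y in zip(batch_scores, batch_labels)) / len(batch_labels)
top2 = sum(topk_correct(s, y, 2) for s, y in zip(batch_scores, batch_labels)) / len(batch_labels)
print(top1, top2)  # 0.5 1.0
```

On ImageNet the convention is k=1 and k=5 over 1000 classes, which is what Top1-Acc and Top5-Acc denote below.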

CIFAR-10

| Full Model | Parameters (PR) | FLOPs (PR) | lr_type | Accuracy | Model |
|---|---|---|---|---|---|
| VGG-16 (Baseline) | 14.73M (0.0%) | 314.04M (0.0%) | step | 93.02% | pre-trained |
| VGG-16-0.86 | 0.74M (94.95%) | 81.31M (74.11%) | step | 93.32% | pruned |
| ResNet-56 (Baseline) | 0.85M (0.0%) | 126.56M (0.0%) | step | 93.26% | pre-trained |
| ResNet-56-0.56 | 0.39M (54.47%) | 55.26M (56.34%) | step | 93.27% | pruned |
| ResNet-110 (Baseline) | 1.73M (0.0%) | 254.99M (0.0%) | step | 93.53% | pre-trained |
| ResNet-110-0.69 | 0.53M (69.14%) | 86.80M (65.96%) | step | 93.71% | pruned |
| GoogLeNet (Baseline) | 6.17M (0.0%) | 1529.43M (0.0%) | step | 95.03% | pre-trained |
| GoogLeNet-0.91 | 2.18M (64.70%) | 491.54M (67.86%) | step | 94.85% | pruned |

ImageNet

| Architecture | Parameters (PR) | FLOPs (PR) | lr_type | Top1-Acc | Top5-Acc | Model |
|---|---|---|---|---|---|---|
| ResNet-50 (Baseline) | 25.56M (0.0%) | 4113.56M (0.0%) | step | 76.01% | 92.96% | pre-trained |
| ResNet-50-0.52 | 6.90M (72.98%) | 931.02M (77.37%) | step | 71.112% | 90.424% | pruned |
| ResNet-50-0.44 | 9.00M (64.77%) | 1227.23M (70.17%) | step | 72.656% | 91.085% | pruned |
| ResNet-50-0.2 | 16.92M (33.80%) | 2445.83M (40.54%) | step | 74.851% | 92.305% | pruned |
| ResNet-50-0.44 | 9.00M (64.77%) | 1227.23M (70.17%) | cos | 73.344% | 91.271% | pruned |
| ResNet-50-0 | | | | | | pruned |

Other Arguments

optional arguments:
  -h, --help            show this help message and exit
  --gpus GPUS [GPUS ...]
                        Select gpu_id to use. default:[0]
  --dataset DATASET     Select dataset to train. default:cifar10
  --data_path DATA_PATH
                        The directory where the input is stored.
                        default:/home/data/cifar10/
  --job_dir JOB_DIR     The directory where the summaries will be stored.
                        default:./experiments
  --arch ARCH           Architecture of model. default:resnet_imagenet. optional:resnet_cifar/mobilenet_v1/mobilenet_v2
  --cfg CFG             Detailed architecture of model. default:resnet56. optional:resnet110/18/34/50/101/152 mobilenet_v1/mobilenet_v2
  --graph_gpu           Use gpu to graph the filters or not. default:False
  --init_method INIT_METHOD
                        Initialization method of pruned model. default:direct_project. optional:random_project
  --pr_target           Target prune ratio of parameters 
  --lr_type             lr scheduler. default: step. optional:exp/cos/step/fixed
  --criterion           Loss function. default:Softmax. optional:SmoothSoftmax
  --graph_method        Method to construct the graph of filters. default:knn. optional:kmeans/random
  --resume              Continue training from a specific checkpoint. For example: ./experiment/imagenet/resnet50_redidual/checkpoint/model_last.pt
  --use_dali            If this parameter is given, use the DALI module to load ImageNet data.
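A minimal `argparse` sketch of how a few of the flags above can be declared (illustrative only; the repository defines its own parser, and the defaults here are taken from the help text above):

```python
# Sketch of the CLI surface: list-valued --gpus, float --pr_target,
# and a boolean --graph_gpu switch.
import argparse

parser = argparse.ArgumentParser(description="CLR-RNF pruning (sketch)")
parser.add_argument("--gpus", type=int, nargs="+", default=[0],
                    help="Select gpu_id to use. default:[0]")
parser.add_argument("--pr_target", type=float, default=0.5,
                    help="Target prune ratio of parameters")
parser.add_argument("--graph_gpu", action="store_true",
                    help="Use gpu to graph the filters or not. default:False")

args = parser.parse_args(["--pr_target", "0.7", "--graph_gpu"])
print(args.pr_target, args.graph_gpu)  # 0.7 True
```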

clr-rnf's People

Contributors: lmbxmu, zyxxmu
clr-rnf's Issues

Long-tailed distribution in ResNet-style residual networks?

Hello, I am very interested in this work.
I verified the long-tailed distribution on VGGNet, consistent with the paper.
But given the residual connections, does the long-tailed distribution also hold in ResNet-style networks?
I printed the weights of ResNet-56 and found no clear pattern between shallow and deep layers.
Is the long-tailed distribution only applicable to VGGNet, or how does it generalize to ResNet?
Looking forward to your reply.

ResNet pruning

Which part of the code manually keeps the same pruning rate for the input and output layers of each residual block?

imagenet.py

Hello, why are graph_mobilenet_v1 and graph_mobilenet_v2 not defined in imagenet.py? It raises an error.

Package versions

Do you have version requirements for the individual packages?

resnet_imagenet.py

Should inplanes on line 174 of model/resnet_imagenet.py be self.inplanes?

flops_cfg and flops_lambda

How were the values of flops_cfg and flops_lambda in the training scripts determined? When applying the method to a new network, how should flops_cfg and flops_lambda be obtained?
