lmbxmu / hrank
PyTorch implementation of our paper accepted by CVPR 2020 (Oral) -- HRank: Filter Pruning using High-Rank Feature Map
Home Page: https://128.84.21.199/abs/2002.10179
Thanks for your excellent work! When trying to reproduce HRank on other networks, I found an interesting result:
How do you handle comparing feature maps in later layers if you don't upsample the original input images? Wouldn't the feature maps at the last convolutional layers be too small?
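Not the authors' answer, but note that ranks are only compared among filters of the same layer, so the small-map cap mostly coarsens the scores within a late layer rather than breaking the comparison. A tiny numpy illustration (shapes are illustrative, not the repo's hook):

```python
import numpy as np

# The rank of an h x w feature map can never exceed min(h, w), so a
# 4 x 4 late-layer map yields scores in {0, ..., 4} while a 32 x 32
# early-layer map can score up to 32. HRank only ranks filters
# *within* a layer, where the cap is shared by all candidates.
rng = np.random.default_rng(0)
late_map = rng.standard_normal((4, 4))     # one channel, late layer
early_map = rng.standard_normal((32, 32))  # one channel, early layer
print(np.linalg.matrix_rank(late_map), np.linalg.matrix_rank(early_map))
# random dense maps are full-rank: 4 32
```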
Thanks for sharing. How do you determine the compress rate for different models or different target FLOPs?
It has been a while since HRank was published. Let me start by thanking you for sharing this interesting piece of work and bringing a novel perspective to the pruning realm. However, while trying to replicate the HRank results for benchmarking, we noticed the following issue.
From the following lines, it looks like at the start of every new epoch, the checkpoint with the best test accuracy is automatically loaded and training then resumes from it:
Lines 232 to 244 in 33050a1
Lines 305 to 306 in 33050a1
We also confirmed this empirically by checking the accuracy and printing out a portion of a conv tensor:
While I understand that it is common and acceptable practice to report the epoch with the best test accuracy [1], resuming every epoch from the checkpoint with the best test accuracy looks like a potential data leak, since test-set information is used to decide training operations. It also looks like HRank may perform reasonably well without this setting (i.e. by simply continuing from the latest epoch). Is this behaviour intentional or accidental?
[1] Li et al. Pruning Filters for Efficient ConvNets. ICLR 2017
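To make the concern above concrete, here is a toy sketch (plain Python, all names hypothetical) contrasting the two resume policies; only the "latest" policy keeps test accuracy out of the choice of starting weights:

```python
# Toy sketch of the two resume policies. The weight "state" is a string
# recording which epochs the final weights were actually trained through.
def run(policy, accs):
    latest, best, best_acc = "init", "init", -1.0
    for epoch, acc in enumerate(accs):
        start = best if policy == "best" else latest
        latest = f"{start}->e{epoch}"        # stand-in for one epoch of SGD
        if acc > best_acc:                   # test acc decides the checkpoint
            best_acc, best = acc, latest
    return latest

# With the best-acc reload, epoch 3 silently discards epoch 2's updates:
print(run("best",   [0.6, 0.7, 0.65, 0.6]))   # init->e0->e1->e3
print(run("latest", [0.6, 0.7, 0.65, 0.6]))   # init->e0->e1->e2->e3
```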
Could you explain Eq. (5) in the paper? Why is there no 1/g factor after the approximation sign for the expectation?
Can your compression method be used to compress other networks, e.g. HRNet?
I'm very interested in your research,
but I can't quite follow your code. Which part actually performs the pruning?
Hello!
After pruning with your code, does the generated model actually keep all channels, i.e., the model structure is not changed and the weights of the channels to be pruned are simply set to zero?
If so, does the saved pruned model occupy the same disk space as before pruning? And are the reported savings in params and FLOPs computed manually from the pruning rates?
Thanks! (I'm pruning ResNet myself and want the structure to actually change, but because of the downsample branches, pruning every layer is problematic and I don't know how to handle it, so I'd like to confirm whether your code implements full pruning that changes the model structure.)
Thanks!!
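If the zero-mask behaviour described above is indeed what happens, the storage effect is easy to demonstrate; a minimal numpy sketch (shapes illustrative, standing in for a conv weight of shape [out_ch, in_ch, k, k]):

```python
import numpy as np

# Stand-in for a conv weight tensor [out_ch=64, in_ch=32, 3, 3].
w = np.ones((64, 32, 3, 3), dtype=np.float32)

# Mask-style pruning: zero half the output channels. The shape is
# unchanged, so the saved file is the same size and the FLOPs/params
# savings must be derived from the pruning rate by hand.
masked = w.copy()
masked[32:] = 0.0

# Structural pruning: actually drop the channels -> a smaller tensor.
sliced = w[:32]

print(masked.nbytes, sliced.nbytes)  # 73728 36864
```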
Hello! I have two questions:
Hi, I used your cal_flops_params.py (in both HRank and HRankPlus) to measure the FLOPs and parameters of VGG-16 before pruning and got 314.572M (FLOPs) and 14.992M (Params), but the results in your paper are 313.73M (FLOPs) and 14.98M (Params). Could you explain the reason for the gap?
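As a sanity check, summing multiply-accumulates over just the 13 conv layers of the CIFAR-10 VGG-16 lands near the paper's figure; counters that additionally include bias, BN, or the classifier report slightly more, which is one plausible source of the 314.572M vs 313.73M gap (a hand sketch, not the repo's cal_flops_params.py):

```python
# Conv-only MAC count for the CIFAR VGG-16 (32x32 input, 3x3 convs
# with padding 1, 2x2 max-pools at the "M" markers; no BN/bias/FC).
cfg = [64, 64, "M", 128, 128, "M", 256, 256, 256, "M",
       512, 512, 512, "M", 512, 512, 512]
flops, in_c, hw = 0, 3, 32
for v in cfg:
    if v == "M":
        hw //= 2                         # pooling halves each spatial dim
        continue
    flops += hw * hw * v * in_c * 3 * 3  # out_hw^2 * out_c * in_c * k^2
    in_c = v
print(flops / 1e6)  # ~313.2M -- close to the paper's 313.73M
```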
Hi, @lmbxmu, which file should I run first? rank_generation.py needs pre-trained weights (obtained by training with main.py), but main.py in turn needs rank_conv*.npy (generated by rank_generation.py)?
I'd like to use your code (rank_generation.py and main.py) for pruning with ResNet-50 on ImageNet.
However, it doesn't work and I get the following error:
File "<stdin>", line 1, in <module>
File "/ssd7/skyeom/anaconda3/envs/py38/lib/python3.8/site-packages/torch/serialization.py", line 585, in load
return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
File "/ssd7/skyeom/anaconda3/envs/py38/lib/python3.8/site-packages/torch/serialization.py", line 740, in _legacy_load
return legacy_load(f)
File "/ssd7/skyeom/anaconda3/envs/py38/lib/python3.8/site-packages/torch/serialization.py", line 669, in legacy_load
args = pickle_module.load(f, **pickle_load_args)
_pickle.UnpicklingError: invalid load key, '\x00'.
Is the .pth file (resnet50-19c8e357.pth) that we can download from the link the correct one?
Many thanks.
Best regards,
Seul-Ki
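Not an official answer, but "invalid load key" from torch.load usually means the bytes on disk are not a checkpoint at all: a truncated download, or an HTML error page saved by a browser or a Google Drive redirect. A small sketch for inspecting the file's leading bytes (the classification strings are my own, not from torch):

```python
def sniff(head: bytes) -> str:
    """Classify a checkpoint file from its first few bytes."""
    if head[:2] == b"PK":
        return "zip-format checkpoint (torch.save default since PyTorch 1.6)"
    if head[:1] == b"\x80":
        return "legacy pickle checkpoint"
    if head[:1] == b"<":
        return "HTML page -- the download likely failed"
    return "unknown (possibly truncated or corrupted)"

# A load key of '\x00', as in the traceback above, falls in the last
# bucket: the bytes are not a pickle at all.
print(sniff(b"\x00\x00\x00\x00"))
# Usage against a real file (path illustrative):
# with open("resnet50-19c8e357.pth", "rb") as f:
#     print(sniff(f.read(4)))
```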
Hi,
I read your paper at CVPR this year and found this excellent work. Thanks!
I am curious how this method works on detection models like YOLO or Faster R-CNN. Do they achieve similar results?
By the way, could you share the model for detection?
Thanks
Hello!
Regarding Fig. 2 in your paper, is my understanding correct?
"The x-axis is the index of the feature maps output by the current convolutional layer after the BN layer and then ReLU; the y-axis is the 1st, 5th, 10th, 20th, 30th, 40th and 50th batch; and the color intensity represents the magnitude of the feature maps' rank."
Also, may I ask for the code you used to draw Fig. 2? Many thanks!!
Recently I saw this excellent work at CVPR. I would like to ask how to prune my own object-tracking model with your pruning algorithm. It would be great if there were a project for pruning a tracking algorithm. Thank you!
Hi,
Thanks for your awesome work. I was wondering: the per-layer pruning rates are hard-coded on the command line in your examples, but is there a non-manual way to determine them? Is there any code in your repo dedicated to searching for these per-layer pruning rates, or did you tune them manually?
Thanks.
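The repo question aside, one simple automated alternative can be sketched as a greedy search that keeps raising the rate of the currently most expensive layer until a FLOPs budget is met (all numbers illustrative; this is not code from the repo):

```python
# Per-layer FLOPs of an imaginary 4-layer net, and a target budget.
layer_flops = [40e6, 80e6, 60e6, 20e6]
rates = [0.0] * len(layer_flops)       # fraction pruned per layer
budget = 120e6

def total(rates):
    """Remaining FLOPs under the current per-layer pruning rates."""
    return sum(f * (1 - r) for f, r in zip(layer_flops, rates))

while total(rates) > budget:
    # Pick the layer that currently costs the most and prune it harder.
    i = max(range(len(rates)),
            key=lambda j: layer_flops[j] * (1 - rates[j]))
    rates[i] = min(rates[i] + 0.1, 0.9)  # cap the per-layer rate

print([round(r, 1) for r in rates], total(rates) / 1e6)
```

Rates found this way would still need fine-tuning runs to validate, since FLOPs alone says nothing about accuracy.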
@lmbxmu
Hi, thank you for your work, it's great!
I want to ask: can detection models (YOLO, SSD) use your code as-is? Does it work the same way?
Why is it that for a model I trained myself, when I use the hook in your code to compute ranks, the feature-map ranks of all channels come out nearly identical, rather than showing the clear high-to-low ordering seen in the .npy files you saved? What could be causing this?
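For context, the rank statistic can be reproduced outside the hook with numpy (a sketch with illustrative shapes, not the repo's code). An h x w map's rank is capped at min(h, w), so if most channels produce dense activations that saturate the cap, the per-channel averages will all look alike:

```python
import numpy as np

def channel_ranks(batch):
    """Average per-channel feature-map rank over a batch [n, c, h, w]."""
    n, c, h, w = batch.shape
    return np.array([
        np.mean([np.linalg.matrix_rank(batch[i, j]) for i in range(n)])
        for j in range(c)
    ])

rng = np.random.default_rng(0)
fmap = rng.standard_normal((8, 4, 16, 16))  # dense random activations
print(channel_ranks(fmap))  # dense maps saturate the cap: all 16.0
```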
Hi~Thanks for your great work!
In main.py, the code:
if args.arch == 'resnet_50':
    skip_list = [1, 5, 8, 11, 15, 18, 21, 24, 28, 31, 34, 37, 40, 43, 47, 50, 53]
    if cov_id + 1 not in skip_list:
        continue
    else:
        pruned_checkpoint = torch.load(
            args.job_dir + "/pruned_checkpoint/" + args.arch + "_cov" + str(53) + '.pt')
        net.load_state_dict(pruned_checkpoint['state_dict'])
Why does the model load cov53.pt?
Hello, I'd like to ask: when I specify --gpu 0,1 in the terminal, why does
checkpoint = torch.load(args.resume, map_location='cuda'+args.gpu)
throw the exception: invalid literal for int() with base 10: '0,1'?
How should I use this argument to run the file on two GPUs?
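A likely fix (a sketch; `args.gpu` and the DataParallel wiring are assumptions based on the error message, not the repo's exact code): map the checkpoint onto the first listed device and pass the full list to DataParallel. Note that 'cuda'+args.gpu also lacks the ':' separator, so it would not be a valid device string even with a single GPU:

```python
gpu = "0,1"                         # stand-in for args.gpu
first = gpu.split(",")[0]
map_location = f"cuda:{first}"      # valid single-device string
device_ids = [int(g) for g in gpu.split(",")]
print(map_location, device_ids)     # cuda:0 [0, 1]

# Then (standard torch APIs, shown as usage):
# checkpoint = torch.load(args.resume, map_location=map_location)
# net = torch.nn.DataParallel(net, device_ids=device_ids)
```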
Hello, could you upload the pre-trained models you provide to Baidu Cloud? I'm in mainland China and cannot log in to Google, so I can't download them. If that's too much trouble, could you send me the ResNet-56 model first? Thanks!
During pruning and fine-tuning, do you keep the learning-rate setting consistent with the baseline? For example, is it 0.01, as in normal VGG-16 training?
Hello!
For VGG-16, the setting given in your README is compress_rate [0.95]+[0.5]*6+[0.9]*4+[0.8]*2, and the FLOPs and params I compute with your script match the numbers in the README.
But when I verified it by hand (counting the convolutional and fully connected layers), the FLOPs I got were much lower (about 42M, roughly 60M less). I'm not sure where the discrepancy comes from; could I trouble you to check it by hand or with another tool?
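As a cross-check, a hand count of the pruned net's conv MACs can be sketched like this (assumptions: each README rate is the fraction of filters pruned in that conv layer, kept channels are rounded with int(), and FC/BN/bias terms are ignored; the repo's exact rounding and counting conventions may differ, which alone can shift the total):

```python
# CIFAR VGG-16 conv config (32x32 input) with the README's rates applied.
cfg = [64, 64, "M", 128, 128, "M", 256, 256, 256, "M",
       512, 512, 512, "M", 512, 512, 512]
rates = [0.95] + [0.5] * 6 + [0.9] * 4 + [0.8] * 2  # one per conv layer

flops, in_c, hw, i = 0, 3, 32, 0
for v in cfg:
    if v == "M":
        hw //= 2                               # 2x2 max-pool
        continue
    out_c = v - int(v * rates[i])              # channels kept (assumed rounding)
    flops += hw * hw * out_c * in_c * 3 * 3    # 3x3 conv MACs
    in_c, i = out_c, i + 1
print(flops / 1e6)  # conv-only MACs of the pruned net, in millions
```

Comparing this figure against both your hand result and the script's output should show whether the gap comes from the rounding convention or from terms (FC, BN, bias) that one count includes and the other omits.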