lam1360 / yolov3-model-pruning


Model pruning (network slimming) for YOLOv3 on the oxford hand dataset

License: MIT License

Shell 11.27% Python 88.73%
yolov3 model-pruning hand-detection channel-pruning object-detection

yolov3-model-pruning's Introduction

YOLOv3-model-pruning

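This project uses YOLOv3 to detect hands on oxford hand, an open-source hand detection dataset, and then prunes the trained model. On this dataset, channel pruning reduces YOLOv3's parameter count and model size by 80% and its FLOPs by 70%, and forward inference runs at 200% of the original speed, while mAP stays essentially unchanged.

Environment

Python 3.6, PyTorch 1.0 or later

The YOLOv3 implementation follows eriklindernoren's PyTorch-YOLOv3, so that repo can also be consulted for the dependencies.

Dataset preparation

  1. Download the dataset to get the archive file
  2. Extract the archive into the data directory to get the hand_dataset folder
  3. Run converter.py in the data directory to generate the images and labels folders and the train.txt and valid.txt files. The training set contains 4807 images and the test set contains 821 images

Normal training (Baseline)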

python train.py --model_def config/yolov3-hand.cfg

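About the pruning algorithm

The channel pruning algorithm in this code is an adapted implementation of the paper Learning Efficient Convolutional Networks Through Network Slimming (ICCV 2017); a similar implementation is yolov3-network-slimming. The algorithm in the original paper targets classification models and prunes channels based on the gamma coefficients of the BN layers.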
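For reference, the sparsity-regularized training objective from the Network Slimming paper can be written as below. The notation is the paper's; reading the `--s` flag of train.py as the weight λ is an assumption about this repo, not something the README states.

```latex
L = \sum_{(x, y)} l\big(f(x, W),\, y\big) + \lambda \sum_{\gamma \in \Gamma} |\gamma|
```

Here Γ is the set of BN scaling factors. The L1 term pushes many γ toward zero, and the channels whose γ ends up close to zero are the candidates for pruning.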

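Rough steps of the pruning algorithm

These are only the rough steps; in practice you also need to experiment with the s parameter, or prune iteratively.

  1. Run sparsity training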

    python train.py --model_def config/yolov3-hand.cfg -sr --s 0.01
  2. Prune with test_prune.py to obtain the pruned model

  3. Fine-tune the pruned model

    python train.py --model_def config/prune_yolov3-hand.cfg -pre checkpoints/prune_yolov3_ckpt.pth
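The pruning step (step 2 above) can be sketched in plain Python as follows. This is a simplified illustration of Network Slimming-style global thresholding, not the contents of test_prune.py: all function names and the gamma values are hypothetical, and real code would read the gammas from the model's BN layers.

```python
def highest_safe_threshold(layer_gammas):
    """Largest threshold that still leaves every layer at least one channel."""
    return min(max(abs(g) for g in gammas) for gammas in layer_gammas)


def global_threshold(layer_gammas, percent):
    """|gamma| value below which roughly `percent` of all channels fall."""
    all_gammas = sorted(abs(g) for gammas in layer_gammas for g in gammas)
    index = min(int(len(all_gammas) * percent), len(all_gammas) - 1)
    return all_gammas[index]


def channel_masks(layer_gammas, threshold):
    """1 keeps a channel, 0 prunes it (channels with |gamma| < threshold)."""
    return [[1 if abs(g) >= threshold else 0 for g in gammas]
            for gammas in layer_gammas]


# Hypothetical gamma values for three BN layers after sparsity training.
layer_gammas = [
    [0.90, 0.02, 0.75, 0.01],
    [0.03, 0.60, 0.04, 0.05],
    [0.80, 0.70, 0.01, 0.02],
]

threshold = global_threshold(layer_gammas, percent=0.5)
masks = channel_masks(layer_gammas, threshold)
kept = [sum(mask) for mask in masks]  # channels kept per layer
print(threshold, kept, highest_safe_threshold(layer_gammas))
```

The `highest_safe_threshold` bound is included because a single global threshold can otherwise wipe out an entire layer; choosing a threshold below it keeps every layer connected.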

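Comparison before and after pruning

  1. The figure below shows how the channel counts of some convolutional layers change after pruning:

    The channel counts of some convolutional layers drop substantially

  2. Metrics before and after pruning:

                      Params   Model size   FLOPs   Inference time (2070 TI)   mAP
    Baseline (416)    61.5M    246.4MB      32.8B   15.0 ms                    0.7692
    Prune (416)       10.9M    43.6MB       9.6B    7.7 ms                     0.7722
    Finetune (416)    same     same         same    same                       0.7750

    With the sparsity regularization term added, mAP actually went up (during the experiments, an mAP fluctuation of 0.02 turned out to be normal), so the mAP from sparsity training can be regarded as essentially the same as from normal training. Fine-tuning the pruned model brought no clear improvement, so the three pruning steps can be reduced to two. After pruning, the parameter count and model size drop to about 1/6 of the original, FLOPs drop to about 1/3, and forward inference runs about 2x as fast, while mAP stays essentially unchanged. Note that the results in the table apply only to this dataset; the same effect is not guaranteed on other datasets.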

  3. Testing the pruned model:

    The weights of the Prune model are on Baidu Netdisk (extraction code: gnzx) and can be tested with the following command:

    python test.py --model_def config/prune_yolov3-hand.cfg --weights_path weights/prune_yolov3_ckpt.pth --data_config config/oxfordhand.data --class_path data/oxfordhand.names --conf_thres 0.01
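The ratios quoted in the comparison above can be checked directly against the Baseline and Prune rows of the metrics table. The short sketch below only redoes that arithmetic; the dictionary keys and the `reduction` helper are illustrative names.

```python
# Numbers copied from the Baseline (416) and Prune (416) rows of the table.
baseline = {"params_M": 61.5, "size_MB": 246.4, "flops_B": 32.8, "latency_ms": 15.0}
pruned = {"params_M": 10.9, "size_MB": 43.6, "flops_B": 9.6, "latency_ms": 7.7}


def reduction(before, after):
    """Fractional reduction, e.g. 0.8 means 80% smaller."""
    return 1.0 - after / before


param_cut = reduction(baseline["params_M"], pruned["params_M"])
size_cut = reduction(baseline["size_MB"], pruned["size_MB"])
flops_cut = reduction(baseline["flops_B"], pruned["flops_B"])
speedup = baseline["latency_ms"] / pruned["latency_ms"]

print("params cut: {:.1%}".format(param_cut))
print("size cut:   {:.1%}".format(size_cut))
print("FLOPs cut:  {:.1%}".format(flops_cut))
print("speedup:    {:.2f}x".format(speedup))
```

With the table's numbers this gives roughly an 82% cut in parameters and model size, about 71% fewer FLOPs, and a speedup a little under 2x, which matches the rounded figures in the text.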

yolov3-model-pruning's Issues

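Sparse weights

Hi, thanks for your work. Could you share the weights from after sparsity training but before pruning?

Error when computing mAP: list index out of range

Hi, when training on my own dataset I get an IndexError while computing mAP. How can I fix this?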
Traceback (most recent call last):
File "train.py", line 205, in
ap_table += [[c, class_names[c], "%.5f" % AP[i]]]
IndexError: list index out of range

Help: problem when training on my own dataset. Is it caused by my labels?

RuntimeError: cuda runtime error (59) : device-side assert triggered at C:/w/1/s/tmp_conda_3.6_041836/conda/conda-bld/pytorch_1556684464974/work/aten/src\THC/THCTensorMathCompareT.cuh:69

The error is shown above. My label format is as follows:
0 0.54921875 0.3546875 0.009375000000000001 0.026041666666666668
0 0.66796875 0.4109375 0.009375000000000001 0.026041666666666668
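A device-side assert like this is often caused by an out-of-range index, so one thing worth ruling out is a class index that is not smaller than the number of classes in the cfg, or a coordinate outside [0, 1]. Below is a minimal sketch that checks label lines in the format shown above (`class x_center y_center width height`, normalized). `check_label_line` and `num_classes` are hypothetical names, not part of the repo, and this may not be the cause in this particular case.

```python
def check_label_line(line, num_classes):
    """Return a list of problems found in one label line (empty if none)."""
    fields = line.split()
    if len(fields) != 5:
        return ["expected 5 fields, got %d" % len(fields)]

    problems = []
    class_id = int(fields[0])
    if not 0 <= class_id < num_classes:
        problems.append("class id %d outside [0, %d)" % (class_id, num_classes))

    names = ("x_center", "y_center", "width", "height")
    for name, text in zip(names, fields[1:]):
        value = float(text)
        if not 0.0 <= value <= 1.0:
            problems.append("%s=%r outside [0, 1]" % (name, value))
    return problems


sample = "0 0.54921875 0.3546875 0.009375000000000001 0.026041666666666668"
print(check_label_line(sample, num_classes=1))
```

Running it over every line of every label file with the `classes` value from the cfg would flag any label the model cannot index.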

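Question about sparsity training

Hello, during sparsity training, precision gradually rose from 0 to about 0.5-0.6 (by the ninth epoch) and then gradually dropped to 0.3. Training is still running. Is this normal? Or, what does the precision curve normally look like during sparsity training? (Note: my dataset is fairly large, with nearly 20,000 images.)

Request for the cleaned-up code

Hi, thanks for sharing! When will the final cleaned-up code be released? The current code throws errors in several places, and it seems sparsity training cannot be run?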

Suitable prune rule

In step 2, you mention a suitable prune rule. Is there an example that shows it?
And what counts as a suitable prune rule?

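Sparsity training problem

Why is it that during my sparsity training the gamma coefficients of the BN layers oscillate from 1 to -1 and then from -1 back to 1 as training proceeds, with all gamma coefficients oscillating in sync, so that they never become sparse? Has anyone else run into this?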

Do we need to set the masks for the bias in BN layers?

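Does pruned_model also prune the BN bias? If only the weight is masked, the remaining bias would still have a large effect on the network, right? In that case the values still cannot be assigned directly to Compact_model. How should this be handled?

Which convolutional layers are pruned

Do you prune only the feature extraction layers, or also the convolutional layers of the yolo layers?

Help: a problem occurred when training on my own dataset. Is PNG unsupported, or is something else wrong?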

(pytorch) D:\zcy\YOLOv3-model-pruning>python train.py --model_def config/yolov3-light.cfg
Namespace(batch_size=16, checkpoint_interval=5, data_config='config/light.data', debug_file='debug', epochs=100, evaluation_interval=1, img_size=416, lr=0.001, model_def='config/yolov3-light.cfg', multiscale_training=False, n_cpu=4, pretrained_weights='weights/yolov3.weights', s=0.01, sr=False)
Traceback (most recent call last):
File "train.py", line 128, in
for batch_i, (_, imgs, targets) in enumerate(dataloader):
File "D:\ProgramFiles\anaconda3\envs\pytorch\lib\site-packages\torch\utils\data\dataloader.py", line 582, in next
return self._process_next_batch(batch)
File "D:\ProgramFiles\anaconda3\envs\pytorch\lib\site-packages\torch\utils\data\dataloader.py", line 608, in _process_next_batch
raise batch.exc_type(batch.exc_msg)
FileNotFoundError: Traceback (most recent call last):
File "D:\ProgramFiles\anaconda3\envs\pytorch\lib\site-packages\torch\utils\data_utils\worker.py", line 99, in _worker_loop
samples = collate_fn([dataset[i] for i in batch_indices])
File "D:\ProgramFiles\anaconda3\envs\pytorch\lib\site-packages\torch\utils\data_utils\worker.py", line 99, in
samples = collate_fn([dataset[i] for i in batch_indices])
File "D:\zcy\YOLOv3-model-pruning\utils\datasets.py", line 86, in getitem
img = Image.open(img_path).convert('RGB')
File "D:\ProgramFiles\anaconda3\envs\pytorch\lib\site-packages\PIL\Image.py", line 2770, in open
fp = builtins.open(filename, "rb")
FileNotFoundError: [Errno 2] No such file or directory: 'data\images\train\dayClip8--00625.png'

Can you provide a visualize function?

When I try to plot images with predicted results, it shows:
[image]
(tip: blue represents the gt box, red represents the predicted box)
Then I referred to converter.py and datasets.py:

  • extract real gt box (x1, y1, x2, y2) from *.mat
  • calculate (x_ctr, y_ctr, width, height)
  • calculate normalized (x_ctr, y_ctr, width, height) and write it to *.txt
  • do padding on image -> (500, 500) pad pixel is 0
  • calculate padded (x1, y1, x2, y2)
  • calculate normalized (x_ctr, y_ctr, w, h) according to padded (x1, y1, x2, y2)
  • do augmentation
  • use ToTensor() pixel range (0, 255) -> (0, 1)
  • convert normalized (x_ctr, y_ctr, w, h) to tensor
  • get img_path, img, targets
It is tedious. Can you provide a demo.py to visualize the predicted results?
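The two coordinate conversions in the list above (corner boxes to normalized center format, and back again for plotting) can be sketched as follows. The function names are illustrative and not taken from converter.py or datasets.py; padding and augmentation are left out.

```python
def xyxy_to_normalized_xywh(box, img_w, img_h):
    """(x1, y1, x2, y2) in pixels -> normalized (x_ctr, y_ctr, w, h)."""
    x1, y1, x2, y2 = box
    return (
        (x1 + x2) / 2.0 / img_w,
        (y1 + y2) / 2.0 / img_h,
        (x2 - x1) / float(img_w),
        (y2 - y1) / float(img_h),
    )


def normalized_xywh_to_xyxy(box, img_w, img_h):
    """Normalized (x_ctr, y_ctr, w, h) -> (x1, y1, x2, y2) in pixels."""
    x_ctr, y_ctr, w, h = box
    return (
        (x_ctr - w / 2.0) * img_w,
        (y_ctr - h / 2.0) * img_h,
        (x_ctr + w / 2.0) * img_w,
        (y_ctr + h / 2.0) * img_h,
    )


# Round trip on a made-up box in a 400x200 image.
normalized = xyxy_to_normalized_xywh((100, 50, 300, 150), img_w=400, img_h=200)
restored = normalized_xywh_to_xyxy(normalized, img_w=400, img_h=200)
print(normalized, restored)
```

For visualization, the second function is the one that turns stored labels or predictions back into pixel corners that a plotting library can draw.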

pruning code

When will the key pruning source code be released? Can the master branch not train or test without the pruning part?

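Using multiple GPUs for training

After adding DataParallel I get an error that load_darknet_weights is missing; after I added that function to DataParallel, other errors followed. Has anyone managed to train with multiple GPUs?

Low mAP from training

I followed the author's steps one by one, but when training baseline_model, mAP reached about 0.73 after 3 or 4 epochs and was only about 0.69 after 100 epochs. How did you train, and why would this happen in my case?

About the baseline model's 0.76

Hello, my baseline model reached an mAP of about 0.72 at best, while your table shows 0.76. Did you use any tricks when training the baseline model?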

invalid syntax: f"conv_{module_i}"

When I run python3 test.py --model_def config/prune_yolov3-hand.cfg --weights_path /home/yehao/Downloads/prune_yolov3_ckpt.pth --data_config config/oxfordhand.data --class_path data/oxfordhand.names --conf_thres 0.01.
It fails with this error:
[image]
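An `invalid syntax` error on an f-string usually means the interpreter is older than Python 3.6, the version that introduced f-strings (the README lists Python 3.6 as the environment). If upgrading is not an option, the same string can be built with `str.format`. The value of `module_i` below is only an example.

```python
module_i = 7  # example layer index

# Works on any Python 3 version.
name_with_format = "conv_{}".format(module_i)

# Requires Python 3.6 or later.
name_with_fstring = f"conv_{module_i}"

print(name_with_format, name_with_fstring)
```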

Error when running test_prune.py to prune

Error output

=====

Detecting objects: 100% 42/42 [01:02<00:00, 1.14s/it]
Computing AP: 100% 3/3 [00:00<00:00, 303.82it/s]
Threshold should be less than 0.3998.
The corresponding prune ratio is 0.769.
Channels with Gamma value less than 0.4403 are pruned!
Detecting objects: 100% 42/42 [01:01<00:00, 1.11s/it]
Traceback (most recent call last):
File "test_prune.py", line 78, in
threshold = prune_and_eval(model, sorted_bn, percent)
File "test_prune.py", line 69, in prune_and_eval
mAP = eval_model(model_copy)[2].mean()
File "test_prune.py", line 27, in
nms_thres=0.5, img_size=model.img_size, batch_size=24)
File "/YOLOv3-model-pruning/test.py", line 55, in evaluate
assert sample_metrics != []
AssertionError

=======
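My own dataset has classes=3. Any advice would be appreciated.

model.py raises an error at load_darknet_weights during training

Hello, I get the following error during training: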

File "train.py", line 77, in
model.load_darknet_weights(opt.pretrained_weights)
File "/home/cct/YOLOv3-model-pruning/models.py", line 317, in load_darknet_weights
conv_w = torch.from_numpy(weights[ptr: ptr + num_w]).view_as(conv_layer.weight)
RuntimeError: shape '[256, 128, 3, 3]' is invalid for input of size 160590

How can this be solved?

Help: error during sparsity training

Traceback (most recent call last):
File "train.py", line 129, in
for batch_i, (_, imgs, targets) in enumerate(dataloader):
File "C:\Users\86151\Anaconda3\lib\site-packages\torch\utils\data\dataloader.py", line 582, in next
return self._process_next_batch(batch)
File "C:\Users\86151\Anaconda3\lib\site-packages\torch\utils\data\dataloader.py", line 608, in _process_next_batch
raise batch.exc_type(batch.exc_msg)
ValueError: Traceback (most recent call last):
File "C:\Users\86151\Anaconda3\lib\site-packages\torch\utils\data_utils\worker.py", line 99, in _worker_loop
samples = collate_fn([dataset[i] for i in batch_indices])
File "C:\Users\86151\Anaconda3\lib\site-packages\torch\utils\data_utils\worker.py", line 99, in
samples = collate_fn([dataset[i] for i in batch_indices])
File "D:\BaiduNetdiskDownload\YOLOv3-model-pruning-master\YOLOv3-model-pruning-master\utils\datasets.py", line 128, in getitem
img, boxes = augment(img, boxes)
File "D:\BaiduNetdiskDownload\YOLOv3-model-pruning-master\YOLOv3-model-pruning-master\utils\augmentations.py", line 29, in augment
augmented = aug(image=image, bboxes=boxes_coord, category_id=labels)
File "C:\Users\86151\Anaconda3\lib\site-packages\albumentations\core\composition.py", line 189, in call
convert_bboxes_to_albumentations, data)
File "C:\Users\86151\Anaconda3\lib\site-packages\albumentations\core\composition.py", line 249, in data_preprocessing
data[data_name] = convert_fn(data[data_name], params['format'], rows, cols, check_validity=True)
File "C:\Users\86151\Anaconda3\lib\site-packages\albumentations\augmentations\bbox_utils.py", line 158, in convert_bboxes_to_albumentations
return [convert_bbox_to_albumentations(bbox, source_format, rows, cols, check_validity) for bbox in bboxes]
File "C:\Users\86151\Anaconda3\lib\site-packages\albumentations\augmentations\bbox_utils.py", line 158, in
return [convert_bbox_to_albumentations(bbox, source_format, rows, cols, check_validity) for bbox in bboxes]
File "C:\Users\86151\Anaconda3\lib\site-packages\albumentations\augmentations\bbox_utils.py", line 118, in convert_bbox_to_albumentations
check_bbox(bbox)
File "C:\Users\86151\Anaconda3\lib\site-packages\albumentations\augmentations\bbox_utils.py", line 184, in check_bbox
value=value,
ValueError: Expected x_max for bbox [0.9494795, 0.5703125, 1.0000005, 0.6145835000000001, 0.0] to be in the range [0.0, 1.0], got 1.0000005.
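The failing value, 1.0000005, is only marginally past the allowed range, which suggests rounding in the label coordinates. One possible workaround, not necessarily the author's fix, is to clamp the box corners into [0, 1] before they are passed to the augmentation. `clip_bbox` is a hypothetical helper; where to call it in utils/augmentations.py is left open.

```python
def clip_bbox(bbox):
    """Clamp (x_min, y_min, x_max, y_max, ...) into [0, 1].

    Any extra fields after the four coordinates (such as a label)
    are passed through unchanged.
    """
    clipped = [min(max(value, 0.0), 1.0) for value in bbox[:4]]
    return clipped + list(bbox[4:])


# The box from the error message above.
bad_box = [0.9494795, 0.5703125, 1.0000005, 0.6145835000000001, 0.0]
print(clip_bbox(bad_box))
```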

Question

How do you handle the residual structure here?

No module named 'albumentations'

When I run the test.py code, the error occurs:

Traceback (most recent call last):
  File "test.py", line 5, in <module>
    from utils.datasets import *
  File "YOLOv3-model-pruning/utils/datasets.py", line 10, in <module>
    from utils.augmentations import augment
  File "YOLOv3-model-pruning/utils/augmentations.py", line 4, in <module>
    import albumentations as A
ModuleNotFoundError: No module named 'albumentations'

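Which convolutions can be pruned

Hello!
I'd like to ask which convolutions can be pruned. As I see it, not every conv2d can be pruned: the two convolutions tied to each shortcut in yolov3 (-1, -3) should not be prunable, right? (A shortcut is an addition, so after pruning the channel counts would no longer match.)
By my count, the convolutions with BN in yolov3 have 26304 channels in total, of which 12544 are tied to shortcuts and cannot be pruned, which is nearly half.
Since your repo does not include prune.py, I don't know whether you pruned the shortcut-related convolutions. You say you pruned 80%, so how did you prune, how did you deal with the shortcut-related convolutions, and did you prune them?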
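The concern can be made concrete with a short sketch that lists which convolutional layers feed a shortcut. It assumes the cfg has been parsed into a list of dicts with a `type` key and, for shortcut layers, a `from` key holding a relative offset, as PyTorch-YOLOv3-style cfg parsers produce. `shortcut_related_convs` is a hypothetical helper, and this is not necessarily how the repo decides what to prune.

```python
def shortcut_related_convs(module_defs):
    """Indices of convolutional layers whose output is summed by a shortcut."""
    related = set()
    for i, module in enumerate(module_defs):
        if module["type"] != "shortcut":
            continue
        # A shortcut adds the previous layer (i - 1) and the layer at i + from.
        for j in (i - 1, i + int(module["from"])):
            if 0 <= j < len(module_defs) and module_defs[j]["type"] == "convolutional":
                related.add(j)
    return sorted(related)


# Toy layer list, not the real yolov3 cfg.
module_defs = [
    {"type": "convolutional"},           # 0
    {"type": "convolutional"},           # 1
    {"type": "convolutional"},           # 2
    {"type": "shortcut", "from": "-3"},  # 3: adds the outputs of layers 2 and 0
]
print(shortcut_related_convs(module_defs))
```

In this toy example layers 0 and 2 must keep matching channel counts, while layer 1 could be pruned independently.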

Error when pruning with test_prune.py, help needed

RuntimeError:
An attempt has been made to start a new process before the
current process has finished its bootstrapping phase.

    This probably means that you are not using fork to start your
    child processes and you have forgotten to use the proper idiom
    in the main module:

        if __name__ == '__main__':
            freeze_support()
            ...

    The "freeze_support()" line can be omitted if the program
    is not going to be frozen to produce an executable.

Detecting objects: 0%| | 0/37 [00:04<?, ?it/s]
Detecting objects: 0%| | 0/37 [00:00<?, ?it/s]

I could not find where to change this.
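The error message itself points at the usual fix on Windows: put the script's top-level code behind an `if __name__ == '__main__':` guard, so that worker processes started by the DataLoader can import the module without re-running it. A minimal sketch of the idiom follows; `main` is a placeholder standing in for the existing top-level body of test_prune.py.

```python
def main():
    # Placeholder for the script's existing top-level code
    # (model loading, DataLoader creation, pruning, evaluation).
    return "done"


if __name__ == "__main__":
    # Spawned worker processes re-import this module; the guard keeps
    # them from running main() a second time.
    result = main()
    print(result)
```

Setting the DataLoader's `num_workers` to 0 avoids worker processes altogether and is another common way to sidestep this error, at the cost of slower data loading.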

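AssertionError

When running train.py, execution reaches evaluate and the assertion fails because sample_metrics is empty. What should I do, and why does this happen?
Sometimes outputs from outputs = model(imgs) is not empty, but after outputs = non_max_suppression(outputs, conf_thres=conf_thres, nms_thres=nms_thres) it becomes empty.

Question about the pruned model

Hello, I pruned yolov3 in the darknet framework following the paper's method, but after pruning by gamma coefficient and then fine-tuning, the model's parameter count went back to its original size. Is this caused by the cfg file settings? Also, could you provide a simple program that predicts bounding boxes on images with the pruned model you released?

Help: error when running sparsity training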

(pytorch-yolov3) C:\YOLOv3-model-pruning-master>python train.py --model_def config/yolov3-hand.cfg
Namespace(batch_size=16, checkpoint_interval=5, data_config='config/oxfordhand.data', debug_file='debug', epochs=100, evaluation_interval=1, img_size=416, lr=0.001, model_def='config/yolov3-hand.cfg', multiscale_training=False, n_cpu=4, pretrained_weights='weights/yolov3.weights', s=0.01, sr=False)
Traceback (most recent call last):
File "C:\Users\923\Anaconda3\envs\pytorch-yolov3\lib\site-packages\tensorboardX\record_writer.py", line 40, in directory_check
factory = REGISTERED_FACTORIES[prefix]
KeyError: 'logs\0821-01'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "train.py", line 50, in
logger = Logger("logs")
File "C:\YOLOv3-model-pruning-master\utils\logger.py", line 27, in init
self.writer = SummaryWriter(os.path.join(log_dir, timestamp))
File "C:\Users\923\Anaconda3\envs\pytorch-yolov3\lib\site-packages\tensorboardX\writer.py", line 257, in init
self._get_file_writer()
File "C:\Users\923\Anaconda3\envs\pytorch-yolov3\lib\site-packages\tensorboardX\writer.py", line 321, in _get_file_writer
**self.kwargs)
File "C:\Users\923\Anaconda3\envs\pytorch-yolov3\lib\site-packages\tensorboardX\writer.py", line 93, in init
logdir, max_queue, flush_secs, filename_suffix)
File "C:\Users\923\Anaconda3\envs\pytorch-yolov3\lib\site-packages\tensorboardX\event_file_writer.py", line 104, in init
directory_check(self._logdir)
File "C:\Users\923\Anaconda3\envs\pytorch-yolov3\lib\site-packages\tensorboardX\record_writer.py", line 44, in directory_check
os.makedirs(path)
File "C:\Users\923\Anaconda3\envs\pytorch-yolov3\lib\os.py", line 220, in makedirs
mkdir(name, mode)
NotADirectoryError: [WinError 267] The directory name is invalid: 'logs\0821-01:30'
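The directory name in the error contains a colon (`01:30`), which Windows does not allow in file or directory names. A minimal sketch of a timestamp format without a colon is shown below. `make_log_dir_name` is a hypothetical helper, and the assumption is that utils/logger.py builds its timestamp with a format similar to `%m%d-%H:%M`.

```python
import datetime
import os


def make_log_dir_name(log_dir, now=None):
    """Build a log directory path whose timestamp is valid on Windows."""
    now = now or datetime.datetime.now()
    timestamp = now.strftime("%m%d-%H%M")  # no ':' in the name
    return os.path.join(log_dir, timestamp)


example = make_log_dir_name("logs", datetime.datetime(2019, 8, 21, 1, 30))
print(example)
```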

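Is there code for live hand detection?

I tried writing live hand detection code, but the results are wrong. Have you written any? I want to connect a camera, detect hands live, draw boxes, and output the box positions.

Import error

ModuleNotFoundError: No module named 'utils.prune_utils'
There is no prune_utils file in the utils folder.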

How opt.alpha works?

Hi Lam,
Thank you for the great work.
I haven't found much of an introduction to alpha, and I have no idea how it works.

requirements:

  1. numpy>=1.13
  2. tensorboardX
    git clone https://github.com/lanpa/tensorboardX && cd tensorboardX && python setup.py install
  3. albumentations
    conda install -c conda-forge imgaug
    conda install albumentations -c albumentations
  4. terminaltables
    pip install terminaltables
  5. tqdm
  6. torch
  7. random
  8. matplotlib
  9. .......

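Question about the pruning method

Hello, first of all thank you for your pruning work on yolo3.
I'd like to ask: when pruning darknet's convolutional layers, do you set the weights of the layers to be pruned to 0, or do you remove those convolutional layers from the network?
Thanks!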

Dataset question

Hi, thank you in advance.
I am confused about the Oxford hand dataset. It is split into train, val, and test sets, but the train set includes all the test images. Is that normal?
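One way to check for such an overlap in the generated split files is to compare image file names between the two lists. The sketch below assumes train.txt and valid.txt hold one image path per line, as the dataset preparation steps suggest; `overlapping_images` is a hypothetical helper and the paths are made up.

```python
import os


def overlapping_images(train_paths, valid_paths):
    """File names that appear in both lists of image paths."""
    train_names = {os.path.basename(p.strip()) for p in train_paths if p.strip()}
    valid_names = {os.path.basename(p.strip()) for p in valid_paths if p.strip()}
    return sorted(train_names & valid_names)


# Made-up example lines; real use would pass open("train.txt").readlines() etc.
train_lines = ["data/images/train/a.jpg\n", "data/images/train/b.jpg\n"]
valid_lines = ["data/images/valid/b.jpg\n", "data/images/valid/c.jpg\n"]
print(overlapping_images(train_lines, valid_lines))
```

Matching on file name alone is a simplification; identical images stored under different names would need a content hash instead.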

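non_max_suppression takes too long and ends in an infinite loop

I filled in the missing prune_util.py myself in order to train, but found that non_max_suppression runs extremely slowly and falls into an infinite loop at 99%, that is, detections.size(0) stays unchanged. Have you run into this?
Someone discussed this issue in eriklindernoren's project, but no solution was found.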
