
Official code of the paper "Anchor pruning for object detection"

Home Page: https://doi.org/10.1016/j.cviu.2022.103445



Anchor pruning for object detection [CVIU 2022] [arXiv]

By Maxim Bonnaerens, Matthias Freiberger and Joni Dambre.

Abstract

This paper proposes anchor pruning for object detection in one-stage anchor-based detectors. While pruning techniques are widely used to reduce the computational cost of convolutional neural networks, they tend to focus on optimizing the backbone networks, where most of the computation happens. In this work we demonstrate an additional pruning technique, specifically for object detection: anchor pruning. With more efficient backbone networks and a growing trend of deploying object detectors on embedded systems, where post-processing steps such as non-maximum suppression can be a bottleneck, the impact of the anchors used in the detection head is becoming increasingly important. In this work, we show that many anchors in the object detection head can be removed without any loss in accuracy. With additional retraining, anchor pruning can even lead to improved accuracy. Extensive experiments on SSD and MS COCO show that the detection head can be made up to 44% more efficient while simultaneously increasing accuracy. Further experiments on RetinaNet and PASCAL VOC show the general effectiveness of our approach. We also introduce 'overanchorized' models that can be used together with anchor pruning to eliminate hyperparameters related to the initial shape of anchors.
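To make the idea concrete, here is a toy greedy pruning loop: repeatedly drop the anchor whose removal costs the least accuracy, stopping once the drop versus the full anchor set exceeds a tolerance. This is an illustrative sketch only, not the paper's actual search procedure, and `evaluate` is a stand-in for running real COCO evaluation:

```python
# Illustrative greedy anchor-pruning loop (NOT the paper's exact search
# procedure): repeatedly remove the anchor whose removal hurts a score
# function least, until the drop versus the full set exceeds a tolerance.

def evaluate(anchors):
    """Stand-in for running detection evaluation with a given anchor set.
    Here: a toy score with diminishing returns per extra anchor, so the
    loop has something to trade off. A real run would return e.g. COCO AP."""
    return sum(1.0 / (i + 1) for i in range(len(anchors)))

def prune_anchors(anchors, tolerance=0.3):
    anchors = list(anchors)
    baseline = evaluate(anchors)
    while len(anchors) > 1:
        # Try removing each anchor in turn; keep the removal that hurts least.
        scores = [(evaluate(anchors[:i] + anchors[i + 1:]), i)
                  for i in range(len(anchors))]
        best_score, best_i = max(scores)
        if baseline - best_score > tolerance:
            break  # any further removal degrades the score too much
        anchors.pop(best_i)
    return anchors

# Anchor shapes as (width, height) in pixels (hypothetical values).
full_set = [(30, 30), (60, 60), (60, 30), (30, 60), (100, 100)]
pruned = prune_anchors(full_set)
print(len(pruned), "anchors kept out of", len(full_set))  # 4 anchors kept out of 5
```

In the paper, the equivalent of `evaluate` is detection accuracy measured on the validation set, and each pruned configuration can additionally be retrained to recover (or even improve) accuracy.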

Citation

```bibtex
@article{bonnaerens2022anchor,
  title={Anchor pruning for object detection},
  author={Bonnaerens, Maxim and Freiberger, Matthias and Dambre, Joni},
  journal={Computer Vision and Image Understanding},
  pages={103445},
  year={2022},
  publisher={Elsevier},
  doi={10.1016/j.cviu.2022.103445},
}
```

Results and models of SSD and RetinaNet

| Anchor Configuration | AP .50:.95 | FLOPS (head) | BBoxes | Config | Download |
| --- | --- | --- | --- | --- | --- |
| SSD Baseline | 25.6 | 4231M | 8732 | config | model |
| SSD Configuration-A retrained | 25.4 | 3607M | 7814 | config | model |
| SSD Configuration-B retrained | 25.6 | 2476M | 4926 | config | model |
| SSD Configuration-C retrained | 25.2 | 1628M | 3121 | config | model |
| SSD Configuration-D retrained | 22.8 | 774M | 1291 | config | model |
| RetinaNet Baseline | 36.5 | 129B | – | config | model |
| RetinaNet Pruned | 34.8 | 31B | – | config | model |

The above results are on the COCO validation set, while the results in the paper are on the COCO test set.
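The BBoxes column is the number of predicted boxes per image, which follows directly from the feature-map sizes and the number of anchor shapes per cell. For the SSD baseline this can be checked with the standard SSD300 head layout:

```python
# Predicted-box count for SSD300: each feature-map cell predicts one box per
# anchor shape attached to that level (standard SSD300 head layout).
feature_map_sizes = [38, 19, 10, 5, 3, 1]  # spatial size of each head level
anchors_per_cell = [4, 6, 6, 6, 4, 4]      # anchor shapes per cell per level

total_boxes = sum(s * s * a for s, a in zip(feature_map_sizes, anchors_per_cell))
print(total_boxes)  # 8732, matching the SSD Baseline row
```

Pruning anchor shapes at a level lowers the corresponding entry of `anchors_per_cell`, which is why the pruned configurations show both fewer BBoxes and fewer head FLOPS.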

Results plot from paper

Installation

This repository builds upon MMDetection.

See the MMDetection documentation for installation instructions. The last confirmed working versions are mmdet v2.25.0 with mmcv-full v1.4.8:

```shell
pip install openmim
mim install mmcv-full==1.4.8
git clone https://github.com/open-mmlab/mmdetection.git
cd mmdetection
git switch --detach v2.25.0
pip install -v -e .
```

Next, clone our repository and install the anchor pruning package:

```shell
git clone https://github.com/Mxbonn/anchor_pruning.git
cd anchor_pruning
pip install -e .
```

Getting started

Please see Tutorial.ipynb for a general guide on how to do anchor pruning.

To evaluate the pretrained models listed above, run

```shell
python tools/mmdet_test.py configs/ssd/configuration_B.py pretrained_models/configuration_B.pth --eval bbox
```

after modifying the paths to the mmdet base config in configuration_X.py and linking your dataset directory to data/, as required for mmdetection.
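For orientation, a hypothetical sketch of what such a configuration file looks like in mmdet's config system (this is not the repository's actual configuration_B.py; the base-config path and field values are placeholders you adjust to your setup):

```python
# Hypothetical mmdet-style config sketch: inherit an SSD base config and
# override the anchor generator with a pruned anchor set.
_base_ = '../../mmdetection/configs/ssd/ssd300_coco.py'  # adjust to your mmdet checkout

model = dict(
    bbox_head=dict(
        anchor_generator=dict(
            type='SSDAnchorGenerator',
            # pruned per-level anchor settings would go here
        )))

data_root = 'data/coco/'  # dataset directory symlinked as in mmdetection
```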


anchor_pruning's Issues

RuntimeError: max(): Expected reduction dim to be specified for input.numel() == 0. Specify the reduction dim with the 'dim' argument.

```
Traceback (most recent call last):
  File "tools/mmdet_train.py", line 239, in
    main()
  File "tools/mmdet_train.py", line 235, in main
    meta=meta)
  File "/home/featurize/mmdetection-2.25.0/mmdet/apis/train.py", line 244, in train_detector
    runner.run(data_loaders, cfg.workflow)
  File "/environment/miniconda3/lib/python3.7/site-packages/mmcv/runner/epoch_based_runner.py", line 127, in run
    epoch_runner(data_loaders[i], **kwargs)
  File "/environment/miniconda3/lib/python3.7/site-packages/mmcv/runner/epoch_based_runner.py", line 50, in train
    self.run_iter(data_batch, train_mode=True, **kwargs)
  File "/environment/miniconda3/lib/python3.7/site-packages/mmcv/runner/epoch_based_runner.py", line 30, in run_iter
    **kwargs)
  File "/environment/miniconda3/lib/python3.7/site-packages/mmcv/parallel/data_parallel.py", line 75, in train_step
    return self.module.train_step(*inputs[0], **kwargs[0])
  File "/home/featurize/mmdetection-2.25.0/mmdet/models/detectors/base.py", line 248, in train_step
    losses = self(**data)
  File "/environment/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/environment/miniconda3/lib/python3.7/site-packages/mmcv/runner/fp16_utils.py", line 98, in new_func
    return old_func(*args, **kwargs)
  File "/home/featurize/mmdetection-2.25.0/mmdet/models/detectors/base.py", line 172, in forward
    return self.forward_train(img, img_metas, **kwargs)
  File "/home/featurize/mmdetection-2.25.0/mmdet/models/detectors/single_stage.py", line 84, in forward_train
    gt_labels, gt_bboxes_ignore)
  File "/home/featurize/mmdetection-2.25.0/mmdet/models/dense_heads/base_dense_head.py", line 335, in forward_train
    losses = self.loss(*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
  File "/environment/miniconda3/lib/python3.7/site-packages/mmcv/runner/fp16_utils.py", line 186, in new_func
    return old_func(*args, **kwargs)
  File "/home/featurize/mmdetection-2.25.0/mmdet/models/dense_heads/anchor_head.py", line 519, in loss
    num_total_samples=num_total_samples)
  File "/home/featurize/mmdetection-2.25.0/mmdet/core/utils/misc.py", line 30, in multi_apply
    return tuple(map(list, zip(*map_results)))
  File "/home/featurize/mmdetection-2.25.0/mmdet/models/dense_heads/anchor_head.py", line 434, in loss_single
    cls_score, labels, label_weights, avg_factor=num_total_samples)
  File "/environment/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/featurize/mmdetection-2.25.0/mmdet/models/losses/focal_loss.py", line 240, in forward
    avg_factor=avg_factor)
  File "/home/featurize/mmdetection-2.25.0/mmdet/models/losses/focal_loss.py", line 140, in sigmoid_focal_loss
    alpha, None, 'none')
  File "/environment/miniconda3/lib/python3.7/site-packages/mmcv/ops/focal_loss.py", line 56, in forward
    input, target, weight, output, gamma=ctx.gamma, alpha=ctx.alpha)
RuntimeError: max(): Expected reduction dim to be specified for input.numel() == 0. Specify the reduction dim with the 'dim' argument.
```
