ml-lab / pytorch-retinanet-1

This project forked from yhenon/pytorch-retinanet

PyTorch implementation of RetinaNet object detection.

License: Apache License 2.0

Python 89.04% Shell 0.60% Cuda 3.70% C 6.47% C++ 0.19%

pytorch-retinanet

PyTorch implementation of RetinaNet object detection as described in Focal Loss for Dense Object Detection by Tsung-Yi Lin, Priya Goyal, Ross Girshick, Kaiming He and Piotr Dollár.

This implementation is primarily designed to be easy to read and simple to modify.

Results

Currently, this repo achieves 33.7% mAP at 600px resolution with a ResNet-50 backbone, compared with 34.0% mAP in the published results. The difference is likely due to the use of the Adam optimizer instead of SGD with weight decay.

Installation

  1. Clone this repo

  2. Install the required system packages:

apt-get install tk-dev python-tk
  3. Install the Python packages:
pip install cffi

pip install pandas

pip install pycocotools

pip install cython

pip install opencv-python

pip install requests

  4. Build the NMS extension:
cd pytorch-retinanet/lib
bash build.sh
cd ../

Note that you may have to edit line 14 of build.sh to change the Python version the extension is built for.
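Once step 3 has run, a quick way to confirm it worked is to check that every package can be imported by the interpreter you will train with. The script below is a minimal sketch and is not part of the repository; the module names are an assumption derived from the pip commands above (opencv-python is imported as cv2, cython as Cython), and it does not check the apt packages from step 2.

```python
import importlib.util

# Import names for the pip packages in step 3 (assumed mapping: opencv-python
# is imported as cv2 and cython as Cython; the rest match their pip names).
REQUIRED_MODULES = ["cffi", "pandas", "pycocotools", "Cython", "cv2", "requests"]


def missing_modules(names):
    """Return the subset of top-level module `names` that cannot be found."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules(REQUIRED_MODULES)
    if missing:
        print("Missing Python packages:", ", ".join(missing))
    else:
        print("All required Python packages are importable.")
```

Run it with the same python executable that build.sh targets, so the check and the compiled extension refer to the same environment.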

Training

The network can be trained using the train.py script. Two dataloaders are currently available: COCO and CSV. For training on COCO, use:

python train.py --dataset coco --coco_path ../coco --depth 50

For training using a custom dataset, with annotations in CSV format (see below), use

python train.py --dataset csv --csv_train <path/to/train_annots.csv>  --csv_classes <path/to/train/class_list.csv>  --csv_val <path/to/val_annots.csv>

The --csv_val argument is optional; if it is omitted, no validation is performed.
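The --csv_train and --csv_val arguments take separate annotation files. If all annotations are in one CSV, one way to create the two files is to split by image, so that every box of an image (and any negative-example row) ends up in the same file. split_by_image and split_csv below are hypothetical helpers, not part of this repository; this is a sketch that assumes the annotation format described under CSV datasets.

```python
import csv
import random
from collections import OrderedDict


def split_by_image(rows, val_fraction=0.2, seed=0):
    """Split annotation rows into (train_rows, val_rows), grouping by image path.

    All rows sharing an image path (column 0) go to the same split, so no
    image appears in both the training and the validation file.
    """
    groups = OrderedDict()
    for row in rows:
        if row:  # skip blank lines
            groups.setdefault(row[0], []).append(row)

    images = list(groups)
    random.Random(seed).shuffle(images)  # seeded shuffle: reproducible split
    n_val = round(len(images) * val_fraction)
    val_images = set(images[:n_val])

    train_rows, val_rows = [], []
    for image, image_rows in groups.items():
        (val_rows if image in val_images else train_rows).extend(image_rows)
    return train_rows, val_rows


def split_csv(src_path, train_path, val_path, val_fraction=0.2, seed=0):
    """Read one annotation CSV and write train/val annotation CSVs."""
    with open(src_path, newline="") as f:
        train_rows, val_rows = split_by_image(csv.reader(f), val_fraction, seed)
    for path, rows in ((train_path, train_rows), (val_path, val_rows)):
        with open(path, "w", newline="") as f:
            # An empty negative row is written back as "path,,,,,".
            csv.writer(f, lineterminator="\n").writerows(rows)
```

For example, split_csv("all_annots.csv", "train_annots.csv", "val_annots.csv") (hypothetical file names) sends about 20% of the images to the validation file; keeping seed fixed makes the split reproducible across runs.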

Pre-trained model

A pre-trained model is available at:

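A saved state dict can be loaded into a newly constructed model using:

import torch
import model  # model.py in the root of this repo
retinanet = model.resnet50(num_classes=dataset_train.num_classes())  # dataset_train: the training dataset, as built in train.py
retinanet.load_state_dict(torch.load(PATH_TO_WEIGHTS))  # PATH_TO_WEIGHTS: path to the saved state dict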

Alternatively, a complete saved PyTorch model (not just its state dict) can be loaded directly using:

retinanet = torch.load(PATH_TO_MODEL)

Visualization

To visualize the network's detections, use visualize.py:

python visualize.py --dataset coco --coco_path ../coco --model <path/to/model.pt>

This draws the detected bounding boxes on images from the validation set. To visualize with a CSV dataset, use:

python visualize.py --dataset csv --csv_classes <path/to/train/class_list.csv>  --csv_val <path/to/val_annots.csv> --model <path/to/model.pt>

Model

The RetinaNet model uses a ResNet backbone. You can set the depth of the ResNet with the --depth argument, which must be one of 18, 34, 50, 101 or 152. Deeper models are generally more accurate, but they are slower and use more memory.
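The allowed depths are not arbitrary: each one names a standard ResNet architecture from the original ResNet paper, defined by its residual block type and the number of blocks in each of its four stages. The sketch below tabulates those standard configurations and shows how the depth number arises from them; RESNET_CONFIGS, resnet_config and conv_layer_count are illustrative names, not this repository's API.

```python
# Standard ResNet configurations (He et al., "Deep Residual Learning for Image
# Recognition"): residual block type and number of blocks in each of 4 stages.
RESNET_CONFIGS = {
    18: ("BasicBlock", [2, 2, 2, 2]),
    34: ("BasicBlock", [3, 4, 6, 3]),
    50: ("Bottleneck", [3, 4, 6, 3]),
    101: ("Bottleneck", [3, 4, 23, 3]),
    152: ("Bottleneck", [3, 8, 36, 3]),
}


def resnet_config(depth):
    """Return (block_type, blocks_per_stage) for a supported --depth value."""
    if depth not in RESNET_CONFIGS:
        raise ValueError(
            f"Unsupported depth {depth}; must be one of {sorted(RESNET_CONFIGS)}"
        )
    return RESNET_CONFIGS[depth]


def conv_layer_count(depth):
    """Count weighted layers: stem conv + convs in all residual blocks + final fc."""
    block, stages = resnet_config(depth)
    convs_per_block = 2 if block == "BasicBlock" else 3  # 3x3+3x3 vs 1x1+3x3+1x1
    return 1 + convs_per_block * sum(stages) + 1
```

The count includes the final fully connected layer of the classification network, which a detection backbone does not use; it is included only because that is how the ResNet names are derived. Depths 50 and up switch to bottleneck blocks, which is part of why their cost grows faster than the depth number alone suggests.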

CSV datasets

The CSV dataloader provides an easy way to define your own datasets. It uses two CSV files: one containing the annotations and one containing the mapping from class names to IDs.

Annotations format

The CSV file with annotations should contain one annotation per line. Images with multiple bounding boxes should use one row per bounding box. Note that indexing for pixel values starts at 0. The expected format of each line is:

path/to/image.jpg,x1,y1,x2,y2,class_name

Some images may not contain any labeled objects. To add these images to the dataset as negative examples, add an annotation where x1, y1, x2, y2 and class_name are all empty:

path/to/image.jpg,,,,,

A full example:

/data/imgs/img_001.jpg,837,346,981,456,cow
/data/imgs/img_002.jpg,215,312,279,391,cat
/data/imgs/img_002.jpg,22,5,89,84,bird
/data/imgs/img_003.jpg,,,,,

This defines a dataset with 3 images. img_001.jpg contains a cow. img_002.jpg contains a cat and a bird. img_003.jpg contains no interesting objects/animals.

Class mapping format

The class name to ID mapping file should contain one mapping per line. Each line should use the following format:

class_name,id

Indexing for classes starts at 0. Do not include a background class as it is implicit.

For example:

cow,0
cat,1
bird,2
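A matching sketch for the class mapping file is below. read_classes is a hypothetical helper; rejecting duplicate class names follows from the one-mapping-per-line format, while requiring the IDs to be exactly 0..N-1 is an added assumption based on indexing starting at 0 with no background class.

```python
import csv


def read_classes(lines):
    """Parse `class_name,id` lines into (name -> id, id -> name) dictionaries."""
    name_to_id = {}
    for line_no, row in enumerate(csv.reader(lines), start=1):
        if not row:
            continue  # tolerate blank lines
        if len(row) != 2:
            raise ValueError(f"line {line_no}: expected 'class_name,id'")
        name, class_id = row[0], int(row[1])
        if name in name_to_id:
            raise ValueError(f"line {line_no}: duplicate class name {name!r}")
        name_to_id[name] = class_id
    id_to_name = {v: k for k, v in name_to_id.items()}
    # Assumption: ids must be exactly 0..N-1 (no gaps, no repeats); a repeated
    # id shrinks id_to_name, so it also fails this check.
    if sorted(id_to_name) != list(range(len(name_to_id))):
        raise ValueError("class ids must be unique and cover 0..N-1")
    return name_to_id, id_to_name
```

Every class_name used in the annotations file should appear in this mapping; checking that set inclusion before training is a cheap way to catch typos in either file.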

Acknowledgements
