hslol / rotation-retinanet-pytorch

A Rotation RetinaNet Pytorch Implementation on HRSC and SSDD Dataset.

Python 94.92% Makefile 0.03% Cython 5.04%
rotation rotation-detection retinanet object-detection sensing-images

rotation-retinanet-pytorch's Introduction

🚀RetinaNet Oriented Detector Based on PyTorch

This is an implementation of the oriented detector Rotation-RetinaNet on optical and SAR ship datasets.

🌟Performance of the implemented Rotation-RetinaNet Detector

Detection Performance on HRSC Dataset.

Detection Performance on SSDD Dataset.

🎯Experiment

Dataset  Backbone   Input Size  Batch Size  Trick    mAP@0.5  Config
SSDD     ResNet-50  512 x 512   16          None     78.96    config file
SSDD     ResNet-50  512 x 512   16          Augment  85.6     config file
HRSC     ResNet-50  512 x 512   16          None     70.71    config file
HRSC     ResNet-50  512 x 512   4           None     74.22    config file
HRSC     ResNet-50  512 x 512   16          Augment  80.20    config file

💥Get Started

Installation

A. Install requirements:

conda create -n rotate python=3.7
conda activate rotate
conda install pytorch==1.7.0 torchvision==0.8.0 torchaudio==0.7.0 cudatoolkit=11.0 -c pytorch
pip install -r requirements.txt

Note: the OpenCV version must be greater than 4.5.1.
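The version requirement can be checked before training. The following is a minimal sketch, not part of the repository: the helper names `parse_version` and `opencv_is_new_enough` are made up for illustration, and the comparison is strict because the note says the version must be greater than 4.5.1.

```python
def parse_version(text):
    """Turn a version string such as '4.5.5' or '4.5.5-dev' into a tuple of ints."""
    parts = []
    for piece in text.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break  # stop at suffixes such as '-dev'
            digits += ch
        parts.append(int(digits or 0))
    return tuple(parts)


def opencv_is_new_enough(version_text, minimum=(4, 5, 1)):
    """Return True if version_text is strictly newer than `minimum` (hypothetical helper)."""
    return parse_version(version_text) > minimum
```

With OpenCV installed, it could be called as `opencv_is_new_enough(cv2.__version__)` after `import cv2`.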

B. Install the rotation_nms and rotation_overlaps modules:

This takes a single step:
make

Demo

A. Set the project's data path

You should first set the project's data path in the yml file.

# .yml file
# Note: all paths should be absolute paths.
data_path = r'/$ROOT_PATH/SSDD_data/'  # absolute data root path
output_path = r'/$ROOT_PATH/Output/'  # absolute model output path

# For example
$ROOT_PATH
    -HRSC/
        -train/  # train set
            -Annotations/
                -*.xml
            -images/
                -*.jpg
        -test/  # test set
            -Annotations/
                -*.xml
            -images/
                -*.jpg
        -ground-truth/
            -*.txt  # gt labels in txt format (for the VOC evaluation method)

    -SSDD/
        -train/  # train set
            -Annotations/
                -*.xml
            -images/
                -*.jpg
        -test/  # test set
            -Annotations/
                -*.xml
            -images/
                -*.jpg
        -ground-truth/
            -*.txt  # gt labels in txt format (for the VOC evaluation method)

    -Output/
        -checkpoints/
            - where chkpt files are saved
        -tensorboard/
            - where tensorboard event files are saved
        -evaluate/
            - where model detection results are saved for evaluation (VOC method)
        -log.log  # saves the loss and eval results
        -yml file  # config file
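Before training, it can help to confirm that a dataset folder matches the layout above. The following is a minimal sketch, not part of the repository: the helper name `missing_dirs` and the constant `EXPECTED_SUBDIRS` are made up for illustration, and the sub-folder names are taken from the tree shown above.

```python
from pathlib import Path

# Sub-folders each dataset root (HRSC/ or SSDD/) is expected to contain,
# as listed in the example layout.
EXPECTED_SUBDIRS = (
    "train/Annotations",
    "train/images",
    "test/Annotations",
    "test/images",
    "ground-truth",
)


def missing_dirs(dataset_root):
    """Return the expected sub-folders that do not exist under dataset_root."""
    root = Path(dataset_root)
    return [sub for sub in EXPECTED_SUBDIRS if not (root / sub).is_dir()]
```

An empty list means every expected folder is present. The check only looks at folder names; it does not validate the xml, jpg, or txt files inside them.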

B. Run show.py

# for SSDD dataset
python show.py --config_file ./configs/retinanet_r50_fpn_ssdd.yml --chkpt {chkpt.file} --result_path show_result/RSSDD --pic_name demo1.jpg

# for HRSC dataset
python show.py --config_file ./configs/retinanet_r50_fpn_hrsc.yml --chkpt {chkpt.file} --result_path show_result/HRSC --pic_name demo1.jpg

Train

A. Prepare dataset

You should structure your dataset files as shown above.

B. Manually set the project's hyperparameters

You should manually set the project's hyperparameters in the config file.

1. Data file structure (must be set!)
   As shown above.

2. Other settings (optional)
   If you want to follow my experiment, don't change anything.

C. Train Rotation-RetinaNet on the SSDD or HRSC dataset with ResNet-50 from scratch

C.1 Download the pre-trained ResNet-50 pth file

You should first download the pre-trained ResNet-50 pth file and put it in the resnet_pretrained_pth/ folder.

C.2 Train the Rotation-RetinaNet detector on the SSDD or HRSC dataset with the pre-trained pth file

# train model on SSDD dataset from scratch
python train.py --config_file ./configs/retinanet_r50_fpn_ssdd.yml --resume None

# train model on HRSC dataset from scratch
python train.py --config_file ./configs/retinanet_r50_hrsc.yml --resume None

D. Resume training Rotation-RetinaNet detector on SSDD or HRSC dataset

# train model on SSDD dataset from specific epoch
python train.py --config_file ./configs/retinanet_r50_fpn_ssdd.yml --resume {epoch}_{step}.pth

# train model on HRSC dataset from specific epoch
python train.py --config_file ./configs/retinanet_r50_hrsc.yml --resume {epoch}_{step}.pth

Evaluation

A. Evaluate model performance on the SSDD or HRSC val set.

python eval.py --Dataset SSDD --config_file ./configs/retinanet_r50_fpn_ssdd.yml --evaluate True --chkpt {epoch}_{step}.pth
python eval.py --Dataset HRSC --config_file ./configs/retinanet_r50_fpn_hrsc.yml --evaluate True --chkpt {epoch}_{step}.pth

💡References

Thanks to these great works.
https://github.com/open-mmlab/mmrotate
https://github.com/ming71/Rotated-RetinaNet

⏩Zhihu Link

zhihu article

rotation-retinanet-pytorch's Issues

Running on CPU

Hello, can this be trained on a CPU, and how should it be adjusted? If it is possible, could you update the code to support it?

Problem with make

Hello, after running the make command, the training script still reports ModuleNotFoundError: No module named 'utils.rotation_overlaps.rbox_overlaps'. How can this be resolved?
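One way to narrow this down is to ask Python directly whether it can locate the module, without running the training script. The following is a minimal sketch and not a fix for the build itself: the helper name `is_importable` is made up for illustration, and it assumes the check is run from the repository root so that the `utils` package is on the import path.

```python
import importlib.util


def is_importable(module_name):
    """Return True if Python can locate module_name on the current import path."""
    try:
        return importlib.util.find_spec(module_name) is not None
    except (ImportError, ValueError):
        # find_spec raises ModuleNotFoundError (an ImportError) when a parent
        # package of a dotted name cannot be imported.
        return False


if __name__ == "__main__":
    name = "utils.rotation_overlaps.rbox_overlaps"  # module named in the error
    print(name, "importable:", is_importable(name))
```

If this prints False right after make, the compiled module is not where Python looks for it, which points at the build output location or at the Python interpreter used for the build rather than at the training script.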

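SSDD accuracy question

Hello, I ran SSDD experiments with mmrotate and the mAP only reaches 80, while your experiments reach above 84. What is the difference between your implementation and the RetinaNet in mmrotate?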

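Is there a problem with the horizontal_overlaps computation?

Hello, I read your code and it is very well suited to beginners. I have one question. The anchors produced by the network are in the form [xc, yc, w, h, angle], and the gt boxes are also in the form [xc, yc, w, h, angle], so when computing horizontal_overlaps you convert the gt to hbb form with obb2hbb_oc. The anchors are not converted, though, so how is horizontal_overlaps computed?

horizontal_overlaps = bbox_overlaps(
    anchor.clone(),  # generate anchor data copy
    obb2hbb_oc(bbox_annotation[:, :-1]))

My understanding is that it should be:
horizontal_overlaps = bbox_overlaps(
    obb2hbb_oc(anchor.clone()),  # generate anchor data copy
    obb2hbb_oc(bbox_annotation[:, :-1]))
Is my understanding off?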
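For readers following this question, the general idea of comparing oriented boxes through their horizontal bounding boxes can be sketched as below. This is an illustration under stated assumptions, not the repository's obb2hbb_oc or bbox_overlaps: the function names `enclosing_hbb` and `hbb_iou` are made up, the angle is assumed to be in radians, and the output format (x1, y1, x2, y2) is chosen for the example only.

```python
import math


def enclosing_hbb(xc, yc, w, h, angle):
    """Axis-aligned box (x1, y1, x2, y2) that encloses a rotated box.

    Hypothetical helper; `angle` is in radians.
    """
    cos_a, sin_a = abs(math.cos(angle)), abs(math.sin(angle))
    half_w = (w * cos_a + h * sin_a) / 2
    half_h = (w * sin_a + h * cos_a) / 2
    return xc - half_w, yc - half_h, xc + half_w, yc + half_h


def hbb_iou(a, b):
    """IoU of two axis-aligned boxes given as (x1, y1, x2, y2)."""
    inter_w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_w * inter_h
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0
```

The sketch only shows that both inputs to a horizontal IoU need to be in the same horizontal format. Whether the anchors in the repository are already in a form that bbox_overlaps can consume at that point is exactly what the question asks, and this example does not settle it.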

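Question about the poly2obb_np_oc docstring

Hello, and thank you for your code. I got it running on my own machine and the results are very good. While studying the code, I ran into a small question.
In your code, poly2obb_np_oc is written like this: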


def poly2obb_np_oc(poly):
    """ Modified !!
    Convert polygons to oriented bounding boxes.
    Args:
        polys (ndarray): [x0,y0,x1,y1,x2,y2,x3,y3]
    Returns:
        obbs (ndarray): [x_ctr,y_ctr,w,h,angle] modified -> [x_ctr, y_ctr, h, w, angle(radian)]
    """
    bboxps = np.array(poly).reshape((4, 2))
    rbbox = cv2.minAreaRect(bboxps)
    x, y, h, w, a = rbbox[0][0], rbbox[0][1], rbbox[1][0], rbbox[1][1], rbbox[2]
    # assert 0 < a <= 90, f'error from poly2obb_np_oc function.'
    if w < 2 or h < 2:
        return
    while not 0 < a <= 90:
        if a == -90:
            a += 180
        else:
            a += 90
            w, h = h, w
    a = a / 180 * np.pi
    assert 0 < a <= np.pi / 2
    return x, y, h, w, a

I tested this in an environment where cv2.__version__ is 4.5.5. I think
x, y, h, w, a = rbbox[0][0], rbbox[0][1], rbbox[1][0], rbbox[1][1], rbbox[2]
should be
x, y, w, h, a = rbbox[0][0], rbbox[0][1], rbbox[1][0], rbbox[1][1], rbbox[2]
that is, w should be rbbox[1][0] and h should be rbbox[1][1]. I also found the same function in mmrotate:


def poly2obb_np_oc(poly):
    """Convert polygons to oriented bounding boxes.
    Args:
        polys (ndarray): [x0,y0,x1,y1,x2,y2,x3,y3]
    Returns:
        obbs (ndarray): [x_ctr,y_ctr,w,h,angle]
    """
    bboxps = np.array(poly).reshape((4, 2))
    rbbox = cv2.minAreaRect(bboxps)
    x, y, w, h, a = rbbox[0][0], rbbox[0][1], rbbox[1][0], rbbox[1][1], rbbox[
        2]
    if w < 2 or h < 2:
        return
    while not 0 < a <= 90:
        if a == -90:
            a += 180
        else:
            a += 90
            w, h = h, w
    a = a / 180 * np.pi
    assert 0 < a <= np.pi / 2
    return x, y, w, h, a

(The mmrotate implementation of this function.)
So I am confused by this line in your docstring:
obbs (ndarray): [x_ctr,y_ctr,w,h,angle] modified -> [x_ctr, y_ctr, h, w, angle(radian)].
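A small experiment can separate the naming question from the behavior. The sketch below copies the normalization logic of the two functions quoted above but takes the rectangle as an argument instead of calling cv2.minAreaRect, so it runs without OpenCV. The function names and the sample rectangles are made up for illustration; the input mimics the ((cx, cy), (size0, size1), angle_in_degrees) tuple that cv2.minAreaRect returns.

```python
import math


def normalize_repo_style(rect):
    """Logic of the first quoted function: rect[1][0] is named h."""
    (x, y), (h, w), a = rect
    if w < 2 or h < 2:
        return None
    while not 0 < a <= 90:
        if a == -90:
            a += 180
        else:
            a += 90
            w, h = h, w
    a = a / 180 * math.pi
    return x, y, h, w, a


def normalize_mmrotate_style(rect):
    """Logic of the second quoted function: rect[1][0] is named w."""
    (x, y), (w, h), a = rect
    if w < 2 or h < 2:
        return None
    while not 0 < a <= 90:
        if a == -90:
            a += 180
        else:
            a += 90
            w, h = h, w
    a = a / 180 * math.pi
    return x, y, w, h, a


# Hypothetical sample rectangles, not taken from either dataset.
samples = [
    ((10.0, 20.0), (30.0, 5.0), -90.0),
    ((10.0, 20.0), (30.0, 5.0), -30.0),
    ((10.0, 20.0), (30.0, 5.0), 0.0),
    ((10.0, 20.0), (30.0, 5.0), 45.0),
]
for rect in samples:
    assert normalize_repo_style(rect) == normalize_mmrotate_style(rect)
```

Because the swap `w, h = h, w` is symmetric, both versions return the same tuple position by position: index 2 always traces back to rbbox[1][0] and index 3 to rbbox[1][1], with the same swaps applied. The two versions therefore differ only in which name each position is given. Whether the rest of the code treats index 2 as width or as height is a separate question that this experiment does not answer.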

How to predict multi-scale outputs?

I really appreciate your work. In your code, RetinaNet only predicts bboxes at a single scale, so how would one design a multi-scale output (like the YOLO series)? I am mainly concerned with how to implement the loss and the decoder in models/model.py.
