
wjl198435 / centernet-tensorrt


This project forked from josephchenhub/centernet-tensorrt


This is a C++ implementation of CenterNet using TensorRT and CUDA

License: MIT License



This is a C++ implementation of CenterNet using TensorRT and CUDA. Thanks to the official implementation of CenterNet (Objects as Points)!

Dependencies:

  • Ubuntu 16.04
  • PyTorch 1.2.0 (for compatibility with TensorRT 5 on Jetson TX2)
  • CUDA 10.0 [required]
  • TensorRT-7.0.0.11 (for CUDA 10.0) [required]
  • cuDNN (for CUDA 10.0; may not be used) [required]
  • libtorch (CPU build of the PyTorch C++ library; the GPU build may conflict with the environment) [optional]
  • gtest (Google C++ testing framework) [optional]

Plugins of TensorRT:

  • MyUpsampling: F.interpolate / nn.UpsamplingBilinear2d
  • DCN: deformable convolutional network (DCNv2)
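The MyUpsampling plugin stands in for bilinear upsampling. PyTorch documents nn.UpsamplingBilinear2d as bilinear interpolation with align_corners=True, which maps output index i to input coordinate i * (in - 1) / (out - 1) so that corner pixels line up. The pure-Python sketch below computes that mapping on a single 2D channel to show what the plugin has to reproduce. Whether MyUpsampling (or a given F.interpolate call, which uses a different half-pixel mapping unless align_corners=True is passed) follows this convention is an assumption, and the function name is hypothetical.

```python
def upsample_bilinear_align_corners(img, out_h, out_w):
    """Bilinearly resize a 2D grid (list of rows) with align_corners=True semantics.

    Hypothetical helper for illustration; not part of this repository.
    """
    in_h, in_w = len(img), len(img[0])
    # With align_corners=True, output index i maps to input i * (in - 1) / (out - 1).
    scale_y = (in_h - 1) / (out_h - 1) if out_h > 1 else 0.0
    scale_x = (in_w - 1) / (out_w - 1) if out_w > 1 else 0.0
    out = []
    for i in range(out_h):
        y = i * scale_y
        y0 = int(y)                      # floor, since y >= 0
        y1 = min(y0 + 1, in_h - 1)       # clamp at the bottom edge
        dy = y - y0
        row = []
        for j in range(out_w):
            x = j * scale_x
            x0 = int(x)
            x1 = min(x0 + 1, in_w - 1)   # clamp at the right edge
            dx = x - x0
            # Interpolate along x on the two neighbouring rows, then along y.
            top = img[y0][x0] * (1 - dx) + img[y0][x1] * dx
            bottom = img[y1][x0] * (1 - dx) + img[y1][x1] * dx
            row.append(top * (1 - dy) + bottom * dy)
        out.append(row)
    return out
```

For example, upsampling [[0.0, 1.0], [2.0, 3.0]] to 3x3 keeps the four corners and fills the centre with 1.5.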

PyTorch to ONNX

Clone the CenterNet (Objects as Points) repo and download the models, then change the backbone's return statement from

return [ret]

to

if self.training:
    return [ret]
else:
    hm = ret['hm'].sigmoid_()
    hmax = nn.functional.max_pool2d(hm, (3, 3), stride=1, padding=1)
    keep = (hmax == hm).float()
    hm = hm * keep
    if len(self.heads) == 3:  # 2D object detection
        return hm, ret['wh'], ret['reg']
    elif len(self.heads) == 6:  # multi_pose
        wh, reg, hm_hp, hps, hp_offset = ret['wh'], ret['reg'], ret['hm_hp'], ret['hps'], ret['hp_offset']
        hm_hp = hm_hp.sigmoid_()
        hm_hp_max = nn.functional.max_pool2d(hm_hp, (3, 3), stride=1, padding=1)
        keep = (hm_hp_max == hm_hp).float()
        hm_hp = hm_hp * keep
        return hm, wh, reg, hps, hm_hp, hp_offset
    else:
        # TODO: other head layouts are not handled yet
        raise NotImplementedError("Only ctdet (3 heads) and multi_pose (6 heads) are supported")
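In the modified forward pass, the sigmoid followed by a 3x3 max-pool and the `hmax == hm` mask zeroes every score that is not a local maximum. This acts as a cheap substitute for box NMS, so the inference side only has to pick the top-K surviving peaks. A minimal pure-Python sketch of the same masking on one heatmap channel follows (the function name heatmap_peaks is hypothetical). As with PyTorch's max_pool2d, the padding behaves like negative infinity, so border windows are simply clipped, and values that tie with their local maximum are kept.

```python
def heatmap_peaks(hm):
    """Keep only scores that equal the maximum of their 3x3 neighbourhood.

    Pure-Python equivalent, for a single channel, of:
        hmax = max_pool2d(hm, (3, 3), stride=1, padding=1)
        hm = hm * (hmax == hm).float()
    """
    h, w = len(hm), len(hm[0])
    out = [[0.0] * w for _ in range(h)]
    for y in range(h):
        for x in range(w):
            # Padding acts like -inf, so the window is clipped at the borders.
            local_max = max(
                hm[yy][xx]
                for yy in range(max(0, y - 1), min(h, y + 2))
                for xx in range(max(0, x - 1), min(w, x + 2))
            )
            if hm[y][x] == local_max:
                out[y][x] = hm[y][x]
    return out
```

On a map with a strong peak of 0.9 and a separate weaker peak of 0.7 in a far corner, both survive while every neighbour of the 0.9 peak is zeroed.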

For 2D object detection, modify the process function in src/lib/detectors/ctdet.py:

with torch.no_grad():
    hm, wh, reg = self.model(images)

    # Export the network with its three ctdet outputs, then stop the demo script.
    torch.onnx.export(self.model, images, "ctdet-resdcn18.onnx", opset_version=9,
                      verbose=False, output_names=["hm", "wh", "reg"])
    quit()
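The exported hm, wh and reg tensors still have to be turned into boxes. The pure-Python sketch below follows the decoding used by CenterNet's ctdet_decode: take the top-K peaks over all classes, shift each peak by its reg offset, then expand by half the predicted width and height. It assumes a single image without the batch axis, channel 0 = x/width and channel 1 = y/height, and an output stride (down_ratio) of 4, CenterNet's default. k=100 mirrors CenterNet's default --K, while threshold=0.3 is only an illustrative cutoff. The function name is hypothetical, and mapping boxes back to the original image (undoing the preprocessing resize) is omitted.

```python
import heapq


def decode_ctdet(hm, wh, reg, k=100, threshold=0.3, down_ratio=4):
    """Decode peak-filtered ctdet outputs into boxes (hypothetical helper).

    hm:  [C][H][W] class scores after the sigmoid + 3x3 max-pool step
    wh:  [2][H][W] box width (channel 0) and height (channel 1), in output-map cells
    reg: [2][H][W] sub-pixel x (channel 0) and y (channel 1) offset of the centre
    Returns (score, class_id, x1, y1, x2, y2) tuples in network-input pixels.
    """
    num_classes, h, w = len(hm), len(hm[0]), len(hm[0][0])
    candidates = (
        (hm[c][y][x], c, y, x)
        for c in range(num_classes)
        for y in range(h)
        for x in range(w)
    )
    # Top-K over all classes and positions; nlargest yields descending scores.
    top = heapq.nlargest(k, candidates, key=lambda t: t[0])

    detections = []
    for score, c, y, x in top:
        if score < threshold:
            break  # every remaining candidate scores lower
        cx = x + reg[0][y][x]
        cy = y + reg[1][y][x]
        half_w = wh[0][y][x] / 2
        half_h = wh[1][y][x] / 2
        # Scale from output-map cells to network-input pixels.
        x1, y1, x2, y2 = (
            v * down_ratio for v in (cx - half_w, cy - half_h, cx + half_w, cy + half_h)
        )
        detections.append((score, c, x1, y1, x2, y2))
    return detections
```

Breaking out of the loop at the first low score is safe only because the candidates are already sorted in descending order.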

For human pose estimation, modify the process function in src/lib/detectors/multi_pose.py:

with torch.no_grad():
    hm, wh, reg, hps, hm_hp, hp_offset = self.model(images)
    names = ['hm', 'wh', 'reg', 'hps', 'hm_hp', 'hp_offset']
    torch.onnx.export(self.model, images, "pose.onnx", opset_version=9,
                      verbose=False, input_names=["input"], output_names=names)

Then replace CenterNet/src/lib/models/networks/DCNv2 with DCNv2.

To obtain the ONNX file, run:

 cd CenterNet/src &&\
 python3 demo.py ctdet --arch resdcn_18 --demo xxxxx.jpg --load_model ../models/ctdet_coco_resdcn18.pth --debug 4 --exp_id 1

Build & Run:

  1. Build the TensorRT plugins:
cd onnx-tensorrt/plugin/build &&\
cmake .. &&\
make -j

You may need to specify the paths of some libraries explicitly. To verify the correctness of the plugins, set Debug mode and build with GTEST in plugin/CMakeLists.txt.

  2. Build onnx-tensorrt with this command:
cd onnx-tensorrt/build &&\
cmake .. &&\
make -j

After building the tool successfully, convert the xxx.onnx file into a serialized TensorRT engine xxxx.trt (-d 16 builds an FP16 engine):

cd onnx-tensorrt &&\
./build/onnx2trt ctdet-resdcn18.onnx -d 16 -o ~/ctdet-resdcn18-fp16.trt

  3. Build the inference code:
cd centernet-tensorrt/build &&\
cmake .. &&\
make -j

Then run this command to see the detection results:

./build/ctdet_infer ~/ctdet-resdcn18-fp16.trt ./data/xxxx.jpg

For pose estimation, run the command:

./build/pose_infer xxxxx.trt xxxx.jpg
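For pose, each pair of hps channels holds a joint's offset from the person centre. The sketch below covers only that regression step, modelled on CenterNet's multi_pose_decode: the offsets are added to the integer peak cell and scaled by the output stride (assumed to be 4). CenterNet's decoder additionally snaps each joint to a nearby hm_hp peak refined by hp_offset; that step is omitted here. The interleaved x0, y0, x1, y1, ... channel order follows CenterNet, whether pose_infer uses the same layout is an assumption, and decode_pose_keypoints is a hypothetical name.

```python
def decode_pose_keypoints(hps, y, x, down_ratio=4):
    """Regress joint positions for one person centre at output-map cell (y, x).

    Hypothetical helper. hps: [2 * J][H][W] offsets from the centre,
    interleaved as x0, y0, x1, y1, ...
    Returns J (x, y) pairs in network-input pixels.
    """
    num_joints = len(hps) // 2
    joints = []
    for j in range(num_joints):
        jx = x + hps[2 * j][y][x]      # offset added to the integer peak column
        jy = y + hps[2 * j + 1][y][x]  # offset added to the integer peak row
        joints.append((jx * down_ratio, jy * down_ratio))
    return joints
```

The (y, x) cell would come from a person peak in hm, for example one returned by a top-K decode like the ctdet case.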

Analysis

  1. Inference speed:

#TODO

Contributors: josephchenhub
