This project forked from div99/w-stereo-disp

(NeurIPS 2020) Wasserstein Distances for Stereo Disparity Estimation

Home Page: https://div99.github.io/W-Stereo-Disp/

License: MIT License

Wasserstein Distances for Stereo Disparity Estimation

Accepted at NeurIPS 2020 as a Spotlight.

by Divyansh Garg, Yan Wang, Bharath Hariharan, Mark Campbell, Kilian Q. Weinberger and Wei-Lun Chao

Figure

Citation

@inproceedings{div2020wstereo,
  title={Wasserstein Distances for Stereo Disparity Estimation},
  author={Garg, Divyansh and Wang, Yan and Hariharan, Bharath and Campbell, Mark and Weinberger, Kilian and Chao, Wei-Lun},
  booktitle={NeurIPS},
  year={2020}
}

Introduction

Existing approaches to depth or disparity estimation output a distribution over a set of pre-defined discrete values. This leads to inaccurate results when the true depth or disparity does not match any of these values. The fact that this distribution is usually learned indirectly through a regression loss causes further problems in ambiguous regions around object boundaries. We address these issues using a new neural network architecture that is capable of outputting arbitrary depth values, and a new loss function that is derived from the Wasserstein distance between the true and the predicted distributions. We validate our approach on a variety of tasks, including stereo disparity and depth estimation, and the downstream 3D object detection. Our approach drastically reduces the error in ambiguous regions, especially around object boundaries that greatly affect the localization of objects in 3D, achieving the state-of-the-art in 3D object detection for autonomous driving.
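
To make the boundary problem concrete, here is a minimal sketch with made-up numbers (they are not taken from the paper). It assumes the network's output is reduced to a single disparity by taking the mean of the predicted distribution, which is the usual soft-argmax readout. For a pixel that sits between a foreground and a background surface, the mean lands on neither.

```python
def expected_disparity(values, probs):
    """Soft-argmax style readout: the mean of the predicted distribution."""
    return sum(v * p for v, p in zip(values, probs))


def mode_disparity(values, probs):
    """Disparity value with the highest predicted probability."""
    return max(zip(values, probs), key=lambda pair: pair[1])[0]


# Hypothetical boundary pixel: the prediction is split between a background
# surface at disparity 10 and a foreground surface at disparity 40.
values = [10.0, 20.0, 30.0, 40.0]
probs = [0.45, 0.0, 0.0, 0.55]

mean_estimate = expected_disparity(values, probs)  # falls between the surfaces
mode_estimate = mode_disparity(values, probs)      # picks one real surface

print("mean:", mean_estimate)
print("mode:", mode_estimate)
```

The mean estimate corresponds to a point floating in empty space between the two surfaces, which is the kind of error that hurts 3D localization.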

Contents

Our Wasserstein loss modification, W_loss, can easily be plugged into existing stereo depth models to improve training and obtain better results.
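
The idea behind such a loss can be sketched for a single pixel. This is not the repository's W_loss implementation. It assumes the ground truth is a single value (a point mass) and that the network predicts a probability and an offset for each discrete disparity; the function and parameter names are hypothetical. When the target is a point mass, all predicted probability has to be transported to that one point, so the p-th power of the Wasserstein-p distance has a simple closed form.

```python
def wasserstein_to_point(values, offsets, probs, target, p=1):
    """W_p^p between a discrete prediction and a point mass at `target`.

    The prediction puts probability probs[i] at values[i] + offsets[i].
    Moving all of that mass onto `target` costs
    sum_i probs[i] * |values[i] + offsets[i] - target| ** p.
    """
    if abs(sum(probs) - 1.0) > 1e-6:
        raise ValueError("probs must sum to 1")
    return sum(
        prob * abs(value + offset - target) ** p
        for value, offset, prob in zip(values, offsets, probs)
    )


# Illustrative numbers only.
values = [10.0, 20.0, 30.0, 40.0]   # pre-defined discrete disparities
offsets = [0.0, 0.0, 0.0, 0.3]      # hypothetical per-value offsets
target = 40.3                       # true disparity, not on the grid

confident = wasserstein_to_point(values, offsets, [0.0, 0.0, 0.0, 1.0], target)
bimodal = wasserstein_to_point(values, offsets, [0.5, 0.0, 0.0, 0.5], target)

print("confident prediction:", confident)
print("bimodal prediction:", bimodal)
```

Unlike a loss on the mean alone, this cost penalizes probability mass placed far from the true value, so a bimodal prediction is charged even when its mean happens to be close to the target.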

We release the code for CDN-PSMNet and CDN-SDN models.

Requirements

  1. Python 3.7
  2. PyTorch 1.2.0+
  3. CUDA
  4. pip install -r ./requirements.txt
  5. SceneFlow
  6. KITTI
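
A quick way to check the software requirements before training is sketched below. It is not part of the repository. It treats "Python 3.7" as "3.7 or newer", which is an assumption, and the helper names are hypothetical.

```python
import importlib.util
import re
import sys


def version_tuple(version):
    """Turn a version string such as '1.2.0+cu92' into (1, 2, 0)."""
    numbers = re.findall(r"\d+", version.split("+")[0])
    return tuple(int(n) for n in numbers[:3])


def check_environment():
    """Return a list of human-readable problems (empty if none were found)."""
    problems = []
    if sys.version_info < (3, 7):
        problems.append("Python 3.7 or newer is expected")
    if importlib.util.find_spec("torch") is None:
        problems.append("PyTorch is not installed")
    else:
        import torch

        if version_tuple(torch.__version__) < (1, 2, 0):
            problems.append("PyTorch 1.2.0 or newer is expected")
        if not torch.cuda.is_available():
            problems.append("CUDA is not available to PyTorch")
    return problems


if __name__ == "__main__":
    for problem in check_environment():
        print("WARNING:", problem)
```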

Pretrained Models

TO BE ADDED.

Datasets

Download the SceneFlow and KITTI datasets. The expected directory structures are shown below.

SceneFlow Dataset Structure

SceneFlow
    | monkaa
        | frames_cleanpass
        | disparity
    | driving
        | frames_cleanpass
        | disparity
    | flyingthings3d
        | frames_cleanpass 
        | disparity

KITTI Object Detection Dataset Structure

KITTI
    | training
        | calib
        | image_2
        | image_3
        | velodyne
    | testing
        | calib
        | image_2
        | image_3
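
Before running the preprocessing scripts, it can help to confirm that the downloaded data matches the two layouts above. The following sketch is not part of the repository; the function and constant names are hypothetical, and it only checks that the listed directories exist.

```python
from pathlib import Path

# Sub-directories taken from the two structures listed above.
SCENEFLOW_LAYOUT = {
    subset: ["frames_cleanpass", "disparity"]
    for subset in ("monkaa", "driving", "flyingthings3d")
}
KITTI_LAYOUT = {
    "training": ["calib", "image_2", "image_3", "velodyne"],
    "testing": ["calib", "image_2", "image_3"],
}


def missing_dirs(root, layout):
    """Return the expected sub-directories that do not exist under `root`."""
    root = Path(root)
    return [
        f"{parent}/{child}"
        for parent, children in layout.items()
        for child in children
        if not (root / parent / child).is_dir()
    ]


if __name__ == "__main__":
    # Replace the placeholders with your dataset locations.
    print("SceneFlow missing:", missing_dirs("path-to-SceneFlow", SCENEFLOW_LAYOUT))
    print("KITTI missing:", missing_dirs("path-to-KITTI", KITTI_LAYOUT))
```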

Generate soft links to the SceneFlow dataset. The results will be saved in the ./sceneflow folder. Before running the script, replace the placeholder path-to-SceneFlow with the location of your SceneFlow dataset.

python sceneflow.py --path path-to-SceneFlow --force

Convert the KITTI Velodyne ground truths to depth maps. Before running the script, replace the placeholder path-to-KITTI with the location of your KITTI dataset.

python ./src/preprocess/generate_depth_map.py --data_path path-to-KITTI/ --split_file ./split/trainval.txt
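
For readers unfamiliar with this conversion, the core step can be sketched as follows. This is not the code in generate_depth_map.py. It assumes the points have already been transformed from the Velodyne frame into the rectified camera frame using the calibration files, and the function name is hypothetical. Each point is projected with a 3x4 camera matrix, and the closest depth is kept when several points fall on the same pixel.

```python
def project_to_depth_map(points_cam, P, height, width):
    """Rasterise 3D points in the camera frame into a sparse depth map.

    points_cam: iterable of (x, y, z) in metres, with z pointing forward.
    P: 3x4 projection matrix given as nested lists.
    Pixels that receive no point keep the value 0.0.
    """
    depth = [[0.0] * width for _ in range(height)]
    for x, y, z in points_cam:
        if z <= 0:  # point is behind the camera
            continue
        # Homogeneous projection: [u*w, v*w, w] = P @ [x, y, z, 1]
        uw = P[0][0] * x + P[0][1] * y + P[0][2] * z + P[0][3]
        vw = P[1][0] * x + P[1][1] * y + P[1][2] * z + P[1][3]
        w = P[2][0] * x + P[2][1] * y + P[2][2] * z + P[2][3]
        if w <= 0:
            continue
        u, v = int(round(uw / w)), int(round(vw / w))
        if 0 <= u < width and 0 <= v < height:
            # Keep the closest surface when several points hit one pixel.
            if depth[v][u] == 0.0 or z < depth[v][u]:
                depth[v][u] = z
    return depth


# Toy camera and points, for illustration only.
P_toy = [[100.0, 0.0, 2.0, 0.0], [0.0, 100.0, 2.0, 0.0], [0.0, 0.0, 1.0, 0.0]]
points = [(0.0, 0.0, 5.0), (0.0, 0.0, 3.0), (0.0, 0.0, -1.0), (10.0, 0.0, 1.0)]
toy_depth = project_to_depth_map(points, P_toy, height=4, width=4)
```

The resulting map is sparse, because LiDAR returns cover only a fraction of the image pixels.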

Optionally, download the KITTI2015 dataset to evaluate stereo disparity models.

Training and Inference

Pretrained models are listed under Pretrained Models above. If you only want to generate predictions, you can go directly to step 3.

The default settings require four GPUs for training. If you have fewer GPUs, reduce the batch sizes btrain and bval.

We provide code for both stereo disparity and stereo depth models.
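
The two kinds of models are related by the standard stereo geometry: for a rectified pair, depth equals focal length times baseline divided by disparity. The sketch below illustrates the conversion. The helper names are hypothetical, and the focal length and baseline in the example are only approximate KITTI-like values; read the real ones from the calibration files.

```python
def disparity_to_depth(disparity, focal_px, baseline_m):
    """Depth in metres from disparity in pixels (rectified stereo pair)."""
    if disparity <= 0:
        raise ValueError("disparity must be positive")
    return focal_px * baseline_m / disparity


def depth_to_disparity(depth, focal_px, baseline_m):
    """Disparity in pixels from depth in metres (inverse of the above)."""
    if depth <= 0:
        raise ValueError("depth must be positive")
    return focal_px * baseline_m / depth


# Approximate, illustrative camera parameters.
FOCAL_PX = 721.5
BASELINE_M = 0.54

near = disparity_to_depth(40.0, FOCAL_PX, BASELINE_M)
far = disparity_to_depth(4.0, FOCAL_PX, BASELINE_M)
print("disparity 40 px ->", near, "m")
print("disparity  4 px ->", far, "m")
```

Because depth is inversely proportional to disparity, a fixed disparity error translates into a much larger depth error for distant objects.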

1 Train CDN-SDN from Scratch on SceneFlow Dataset

python ./src/main_depth.py -c src/configs/sceneflow_w1.config

The checkpoints are saved in ./results/stack_sceneflow_w1/.

Follow the same procedure to train a stereo disparity model, but use src/main_disp.py with a disparity config.

2 Train CDN-SDN on KITTI Dataset

python ./src/main_depth.py -c src/configs/kitti_w1.config \
    --pretrain ./results/sceneflow_w1/checkpoint.pth.tar --dataset  path-to-KITTI/training/

Before running, replace the placeholder path-to-KITTI/ with the location of your KITTI dataset. --pretrain is the path to the model pretrained on SceneFlow. The training results are saved in ./results/kitti_w1_train.

If you plan to evaluate CDN on the KITTI testing set, you may want to train CDN on the training + validation sets. The training results will be saved in ./results/sdn_kitti_trainval.

python ./src/main_depth.py -c src/configs/kitti_w1.config \
    --pretrain ./results/sceneflow_w1/checkpoint.pth.tar \
    --dataset  path-to-KITTI/training/ --split_train ./split/trainval.txt \
    --save_path ./results/sdn_kitti_trainval

The disparity models can also be trained on the KITTI2015 dataset using src/kitti2015_w1_disp.config.

3 Generate Predictions

Replace the placeholder path-to-KITTI with the location of your KITTI dataset. If you use one of our provided checkpoints, also set --resume to the checkpoint location.

  • a. Use the model trained on the KITTI training set to generate predictions on the training + validation sets.

python ./src/main_depth.py -c src/configs/kitti_w1.config \
    --resume ./results/sdn_kitti_train/checkpoint.pth.tar --datapath  path-to-KITTI/training/ \
    --data_list ./split/trainval.txt --generate_depth_map --data_tag trainval

The results will be saved in ./results/sdn_kitti_train/depth_maps_trainval/.

  • b. Use the model trained on the KITTI training + validation sets to generate predictions on the testing set. Use these predictions when you submit your results to the leaderboard.

# testing sets
python ./src/main_depth.py -c src/configs/kitti_w1.config \
    --resume ./results/sdn_kitti_trainval/checkpoint.pth.tar --datapath  path-to-KITTI/testing/ \
    --data_list=./split/test.txt --generate_depth_map --data_tag test

The results will be saved in ./results/sdn_kitti_trainval/depth_maps_test/.

4 Train 3D Detection with Pseudo-LiDAR

To train 3D object detection models, follow step 4 onward in the Pseudo-LiDAR_V2 repo: https://github.com/mileyan/Pseudo_Lidar_V2.
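
As background, the central idea of pseudo-LiDAR is to back-project each pixel of a predicted depth map into a 3D point using the pinhole camera model. The sketch below illustrates that step only. It is not the Pseudo-LiDAR_V2 code; the function name and the max_depth cutoff are hypothetical, and the points stay in the camera frame rather than being transformed to the Velodyne frame.

```python
def depth_map_to_points(depth, fx, fy, cx, cy, max_depth=80.0):
    """Back-project a depth map into 3D points in the camera frame.

    depth: nested lists indexed as depth[v][u], in metres.
    fx, fy, cx, cy: pinhole intrinsics in pixels.
    Pixels with non-positive depth or depth beyond max_depth are skipped.
    """
    points = []
    for v, row in enumerate(depth):
        for u, z in enumerate(row):
            if z <= 0 or z > max_depth:
                continue
            x = (u - cx) * z / fx
            y = (v - cy) * z / fy
            points.append((x, y, z))
    return points


# Toy 2x2 depth map and intrinsics, for illustration only.
toy_depth_map = [[0.0, 10.0], [20.0, 100.0]]
toy_points = depth_map_to_points(toy_depth_map, fx=10.0, fy=10.0, cx=0.0, cy=0.0)
print(toy_points)
```

Since every pixel becomes a 3D point, depth errors at object boundaries show up directly as stray points around objects in the resulting cloud.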

Results

Results on the Stereo Disparity

Figure

3D Object Detection Results on the KITTI Leaderboard

Figure

Questions

Please feel free to email us if you have any questions.

Divyansh Garg [email protected] Yan Wang [email protected] Wei-Lun Chao [email protected]
