Controllable Traffic Generation (CTG)

Codebase of Controllable Traffic Generation (CTG) and Controllable Traffic Generation Plus Plus (CTG++).

This repo is built mostly on top of traffic-behavior-simulation (tbsim). The diffusion model is built on top of the initial implementation in Diffuser. The repo also makes light use of STLCG.

Installation

Basic (mainly based on tbsim)

Create a conda environment. Note that nuplan-devkit needs python>=3.9, so create the environment with Python 3.9 rather than Python 3.8.

conda create -n bg3.9 python=3.9
conda activate bg3.9

Install CTG (this repo)

git clone https://github.com/NVlabs/CTG.git
cd CTG
pip install -e .

Install a customized version of trajdata

cd ..
git clone https://github.com/AIasd/trajdata.git
cd trajdata
pip install -r trajdata_requirements.txt
pip install -e .

Install Pplan

cd ..
git clone https://github.com/NVlabs/spline-planner.git Pplan
cd Pplan
pip install -e .

Potential Issue

You may need to install the following pinned package versions (the +cu113 builds target CUDA 11.3):

pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 torchmetrics==0.11.1 torchtext --extra-index-url https://download.pytorch.org/whl/cu113
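To see whether the pins above are already satisfied in the active environment, a small check along the following lines may help. It is a minimal sketch, not part of the repo: the helper names `public_version` and `find_mismatches` are hypothetical, and the comparison ignores local suffixes such as `+cu113`, so it does not confirm which CUDA build is installed.

```python
from importlib import metadata

# Pins copied from the pip command above. torchtext is omitted because
# the command does not give it a version.
PINS = {
    "torch": "1.11.0+cu113",
    "torchvision": "0.12.0+cu113",
    "torchaudio": "0.11.0",
    "torchmetrics": "0.11.1",
}


def public_version(version):
    """Drop a local suffix such as '+cu113' before comparing versions."""
    return version.split("+", 1)[0]


def find_mismatches(pins):
    """Return {package: (wanted, installed_or_None)} for unsatisfied pins."""
    mismatches = {}
    for name, wanted in pins.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            # Package is not installed in this environment.
            mismatches[name] = (wanted, None)
            continue
        if public_version(installed) != public_version(wanted):
            mismatches[name] = (wanted, installed)
    return mismatches


if __name__ == "__main__":
    for name, (wanted, installed) in find_mismatches(PINS).items():
        print(f"{name}: wanted {wanted}, found {installed or 'not installed'}")
```

If the script prints nothing, every pinned package is present at the expected public version.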

Extra Setup for STL (CTG)

Install STLCG and switch to dev branch

cd ..
git clone https://github.com/StanfordASL/stlcg.git
cd stlcg
pip install graphviz
pip install -e .
git checkout dev

Extra Setup for ChatGPT API (if using language interface of CTG++)

pip install openai
pip install tiktoken

Create a file openai_key.py and store your OpenAI API key in it, in a variable named openai_key.
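A minimal sketch of what such a file could contain is shown below. Only the file name and the variable name `openai_key` come from this README. Reading the value from the `OPENAI_API_KEY` environment variable is an assumption made here to keep the secret out of the file, and the README does not say which directory the file should be placed in.

```python
# openai_key.py
import os

# Assumption: take the key from the OPENAI_API_KEY environment variable when
# it is set. Otherwise fall back to a placeholder that you replace by hand.
# Keep this file out of version control either way.
openai_key = os.environ.get("OPENAI_API_KEY", "<your-openai-key>")
```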

Quick start

1. Obtain dataset(s)

We currently support the nuScenes dataset.

nuScenes

  • Download the nuScenes dataset (with the v1.3 map extension pack) and organize the dataset directory as follows:
    nuscenes/
    │   maps/
    │   v1.0-mini/
    │   v1.0-trainval/
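Before training, it can be useful to confirm that the dataset directory matches the layout above. The following is a minimal sketch: the helper name `missing_subdirs` is hypothetical, and whether both the mini and trainval splits are required for a given run is not stated here, so treat a reported gap as a hint rather than an error.

```python
import sys
from pathlib import Path

# Sub-directories listed in the layout above.
EXPECTED_SUBDIRS = ("maps", "v1.0-mini", "v1.0-trainval")


def missing_subdirs(dataset_root, expected=EXPECTED_SUBDIRS):
    """Return the expected sub-directories that do not exist under dataset_root."""
    root = Path(dataset_root)
    return [name for name in expected if not (root / name).is_dir()]


if __name__ == "__main__":
    # Usage: python check_layout.py <path-to-nuscenes-data-directory>
    root = sys.argv[1] if len(sys.argv) > 1 else "nuscenes"
    missing = missing_subdirs(root)
    print("layout OK" if not missing else f"missing under {root}: {missing}")
```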
    

2. Train a diffuser model

nuScenes dataset (note: remove the --debug flag for actual training and to enable wandb logging):

python scripts/train.py --dataset_path <path-to-nuscenes-data-directory> --config_name trajdata_nusc_diff --debug

A concrete example (CTG):

python scripts/train.py --dataset_path ../behavior-generation-dataset/nuscenes --config_name trajdata_nusc_diff --debug

A concrete example (CTG++):

python scripts/train.py --dataset_path ../behavior-generation-dataset/nuscenes --config_name trajdata_nusc_scene_diff --debug
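When launching training from a script, the commands above can be assembled programmatically so that the --debug flag is easy to toggle. This is a sketch only: the helper `build_train_command` is hypothetical, and it uses just the script path and flags that appear in the commands above.

```python
import shlex


def build_train_command(dataset_path, config_name, debug=True):
    """Assemble the training command as an argument list."""
    cmd = [
        "python", "scripts/train.py",
        "--dataset_path", dataset_path,
        "--config_name", config_name,
    ]
    if debug:
        # Leave this flag out for actual training and wandb logging.
        cmd.append("--debug")
    return cmd


if __name__ == "__main__":
    dataset = "../behavior-generation-dataset/nuscenes"
    # CTG uses trajdata_nusc_diff; CTG++ uses trajdata_nusc_scene_diff.
    for config_name in ("trajdata_nusc_diff", "trajdata_nusc_scene_diff"):
        print(shlex.join(build_train_command(dataset, config_name, debug=False)))
```

The returned list can be passed to `subprocess.run(cmd, check=True)` from the root of the CTG checkout.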

3. Run rollout of a trained model (closed-loop simulation)

Run Rollout

python scripts/scene_editor.py \
  --results_root_dir nusc_results/ \
  --num_scenes_per_batch 1 \
  --dataset_path <path-to-nuscenes-data-directory> \
  --env trajdata \
  --policy_ckpt_dir <path-to-checkpoint-dir> \
  --policy_ckpt_key <ckpt-file-identifier> \
  --eval_class <class-of-model-to-rollout> \
  --editing_source 'config' 'heuristic' \
  --registered_name 'trajdata_nusc_diff' \
  --render

The following is a concrete example for running CTG (when using the pre-trained model):

python scripts/scene_editor.py \
  --results_root_dir nusc_results/ \
  --num_scenes_per_batch 1 \
  --dataset_path ../behavior-generation-dataset/nuscenes \
  --env trajdata \
  --policy_ckpt_dir ../../summer_project/behavior-generation/trained_models_only_new/trajdata_nusc/ctg_original \
  --policy_ckpt_key iter70000.ckpt \
  --eval_class Diffuser \
  --editing_source 'config' 'heuristic' \
  --registered_name 'trajdata_nusc_diff' \
  --render

The following is a concrete example for running CTG++ (when using the pre-trained model):

python scripts/scene_editor.py \
  --results_root_dir nusc_results/ \
  --num_scenes_per_batch 1 \
  --dataset_path ../behavior-generation-dataset/nuscenes \
  --env trajdata \
  --policy_ckpt_dir ../../summer_project/behavior-generation/trained_models_only_new/trajdata_nusc/ctg++8_9,10edge \
  --policy_ckpt_key iter50000.ckpt \
  --eval_class SceneDiffuser \
  --editing_source 'config' 'heuristic' \
  --registered_name 'trajdata_nusc_scene_diff' \
  --render
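In the two concrete examples above, --eval_class and --registered_name change together: Diffuser goes with trajdata_nusc_diff (CTG) and SceneDiffuser goes with trajdata_nusc_scene_diff (CTG++). The sketch below encodes that pairing so the two flags cannot drift apart. The helper `build_rollout_command` and the checkpoint directory in the usage section are hypothetical; all flags are taken from the commands above.

```python
import shlex

# Pairs taken from the two concrete examples above:
# evaluation class -> registered config name.
REGISTERED_NAME = {
    "Diffuser": "trajdata_nusc_diff",             # CTG
    "SceneDiffuser": "trajdata_nusc_scene_diff",  # CTG++
}


def build_rollout_command(eval_class, dataset_path, ckpt_dir, ckpt_key,
                          results_root_dir="nusc_results/", render=True):
    """Assemble the scene_editor.py rollout command as an argument list."""
    if eval_class not in REGISTERED_NAME:
        raise ValueError(f"unknown eval_class: {eval_class!r}")
    cmd = [
        "python", "scripts/scene_editor.py",
        "--results_root_dir", results_root_dir,
        "--num_scenes_per_batch", "1",
        "--dataset_path", dataset_path,
        "--env", "trajdata",
        "--policy_ckpt_dir", ckpt_dir,
        "--policy_ckpt_key", ckpt_key,
        "--eval_class", eval_class,
        "--editing_source", "config", "heuristic",
        "--registered_name", REGISTERED_NAME[eval_class],
    ]
    if render:
        cmd.append("--render")
    return cmd


if __name__ == "__main__":
    # "checkpoints/ctg" is a hypothetical location for the downloaded model.
    cmd = build_rollout_command(
        "Diffuser",
        "../behavior-generation-dataset/nuscenes",
        "checkpoints/ctg",
        "iter70000.ckpt",
    )
    print(shlex.join(cmd))
```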

4. Parse results for rollout

python scripts/parse_scene_edit_results.py --results_dir <rollout_results_dir> --estimate_dist

Pre-trained models

We have provided checkpoints for the CTG and CTG++ models here. Note that the provided CTG model differs slightly from the one in the original CTG paper: the main difference is that the prediction horizon is 52 rather than 20. The pre-trained models are provided under the CC-BY-NC-SA-4.0 license.

Configurations

Check out class DiffuserConfig and class SceneDiffuserConfig in algo_config.py for algorithm configs, trajdata_nusc_config.py for dataset configs, and scene_edit_config.py for rollout configs (including changing the guidance used during denoising).

References

If you find this repo useful, please consider citing our relevant work:

@INPROCEEDINGS{10161463,
  author={Zhong, Ziyuan and Rempe, Davis and Xu, Danfei and Chen, Yuxiao and Veer, Sushant and Che, Tong and Ray, Baishakhi and Pavone, Marco},
  booktitle={2023 IEEE International Conference on Robotics and Automation (ICRA)}, 
  title={Guided Conditional Diffusion for Controllable Traffic Simulation}, 
  year={2023},
  volume={},
  number={},
  pages={3560-3566},
  doi={10.1109/ICRA48891.2023.10161463}}

@inproceedings{
zhong2023languageguided,
title={Language-Guided Traffic Simulation via Scene-Level Diffusion},
author={Ziyuan Zhong and Davis Rempe and Yuxiao Chen and Boris Ivanovic and Yulong Cao and Danfei Xu and Marco Pavone and Baishakhi Ray},
booktitle={7th Annual Conference on Robot Learning},
year={2023},
url={https://openreview.net/forum?id=nKWQnYkkwX}
}
