ll-c8 / groupnet

This project forked from mediabrain-sjtu/groupnet

[CVPR22] GroupNet: Multiscale Hypergraph Neural Networks for Trajectory Prediction with Relational Reasoning

License: MIT License

Python 100.00%

GroupNet: Multiscale Hypergraph Neural Networks for Trajectory Prediction with Relational Reasoning

Official PyTorch code for CVPR'22 paper "GroupNet: Multiscale Hypergraph Neural Networks for Trajectory Prediction with Relational Reasoning".

Abstract: Demystifying the interactions among multiple agents from their past trajectories is fundamental to precise and interpretable trajectory prediction. However, previous works only consider pair-wise interactions with limited relational reasoning. To promote more comprehensive interaction modeling for relational reasoning, we propose GroupNet, a multiscale hypergraph neural network, which is novel in terms of both interaction capturing and representation learning. From the aspect of interaction capturing, we propose a trainable multiscale hypergraph to capture both pair-wise and group-wise interactions at multiple group sizes. From the aspect of interaction representation learning, we propose a three-element format that can be learnt end-to-end and explicitly reasons about relational factors, including interaction strength and interaction category. We integrate GroupNet into both a CVAE-based prediction system and previous state-of-the-art prediction systems to predict socially plausible trajectories with relational reasoning. To validate the ability of relational reasoning, we experiment with synthetic physics simulations that test the ability to capture group behaviors and to reason about interaction strength and interaction category. To validate the effectiveness of prediction, we conduct extensive experiments on three real-world trajectory prediction datasets, NBA, SDD and ETH-UCY, and show that with GroupNet, the CVAE-based prediction system outperforms state-of-the-art methods. We also show that adding GroupNet further improves the performance of previous state-of-the-art prediction systems.
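To make the idea of hyperedges at multiple group sizes concrete, here is a minimal, non-trainable sketch in plain Python. It assumes each agent already has a feature vector, that agent affinity is measured by cosine similarity, and that a hyperedge at group size k is formed by grouping each agent with its k - 1 most similar agents. The function names and the use of cosine similarity are illustrative assumptions; this is not the repository's implementation, where the hypergraph is learned end-to-end.

```python
import math
from typing import Dict, List, Sequence


def cosine_affinity(features: Sequence[Sequence[float]]) -> List[List[float]]:
    """Pairwise cosine similarity between agent feature vectors (assumed affinity)."""
    # Fall back to 1.0 for zero vectors so we never divide by zero.
    norms = [math.sqrt(sum(v * v for v in f)) or 1.0 for f in features]
    n = len(features)
    return [
        [
            sum(a * b for a, b in zip(features[i], features[j])) / (norms[i] * norms[j])
            for j in range(n)
        ]
        for i in range(n)
    ]


def multiscale_hyperedges(
    features: Sequence[Sequence[float]], group_sizes: Sequence[int]
) -> Dict[int, List[List[int]]]:
    """Build one set of hyperedges per group size k (hypothetical helper).

    Each agent i is grouped with its k - 1 highest-affinity neighbours;
    identical groups produced by different agents are merged.
    """
    affinity = cosine_affinity(features)
    n = len(features)
    scales: Dict[int, List[List[int]]] = {}
    for k in group_sizes:
        if not 2 <= k <= n:
            raise ValueError(f"group size {k} must be between 2 and {n}")
        edges = set()
        for i in range(n):
            # Rank the other agents by affinity to agent i (ties broken by index).
            neighbours = sorted(
                (j for j in range(n) if j != i),
                key=lambda j: (-affinity[i][j], j),
            )
            edges.add(frozenset([i, *neighbours[: k - 1]]))
        # Sorted lists make the output deterministic and easy to compare.
        scales[k] = sorted(sorted(edge) for edge in edges)
    return scales


# Two clearly separated pairs of agents: {0, 1} and {2, 3}.
features = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]]
print(multiscale_hyperedges(features, group_sizes=[2, 4]))
# {2: [[0, 1], [2, 3]], 4: [[0, 1, 2, 3]]}
```

At group size 2 the sketch recovers the two pairs, while group size 4 yields a single hyperedge covering the whole scene, i.e. the coarsest scale.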

We give an example of trajectories predicted by our model, together with the corresponding ground truth, on the NBA dataset below:

News (2022.6.28)

Our extension, "DynGroupNet", is available on arXiv. Code will be released later.

Requirements

Recommended Environment

  • Tested on: Linux with an NVIDIA RTX 3090 GPU
  • Python == 3.7.11
  • PyTorch == 1.8.1+cu111
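To see how a local setup compares with the tested environment, a small standard-library script such as the one below can help. It is a hypothetical helper, not part of the repository; it only reports the versions it finds and does not enforce them.

```python
import sys
from importlib.util import find_spec

# Versions from the recommended environment; other versions are untested, not necessarily broken.
RECOMMENDED_PYTHON = "3.7.11"
RECOMMENDED_TORCH = "1.8.1+cu111"


def environment_report() -> dict:
    """Compare the running interpreter (and PyTorch, if installed) with the tested versions."""
    python_version = ".".join(str(part) for part in sys.version_info[:3])
    report = {
        "python": python_version,
        "python_matches": python_version == RECOMMENDED_PYTHON,
        "torch": None,
        "torch_matches": False,
        "cuda_available": False,
    }
    # Import torch only if it is installed, so the check also runs on a fresh machine.
    if find_spec("torch") is not None:
        import torch

        report["torch"] = str(torch.__version__)
        report["torch_matches"] = report["torch"] == RECOMMENDED_TORCH
        report["cuda_available"] = torch.cuda.is_available()
    return report


if __name__ == "__main__":
    for key, value in environment_report().items():
        print(f"{key}: {value}")
```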

Dependencies

Install the dependencies listed in requirements.txt:

pip install -r requirements.txt

Data preparation

You can directly use the preprocessed NBA SportVU data in datasets/nba, which includes train.npy and test.npy as the training and testing data. If you want to sample trajectories yourself, download the data files from 'NBA-Player-Movements', put them in datasets/nba/source, and then run:

python generate_dataset.py
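Trajectory datasets are commonly built by cutting long recordings into fixed-length windows of observed (past) and predicted (future) frames. The sketch below shows that idea for a single agent's trajectory of (x, y) points. The function name, window lengths and stride are hypothetical; the actual values and array layout used by generate_dataset.py are not documented here.

```python
from typing import List, Sequence, Tuple

Point = Tuple[float, float]
Window = Tuple[List[Point], List[Point]]


def sliding_windows(
    trajectory: Sequence[Point], past_len: int, future_len: int, stride: int = 1
) -> List[Window]:
    """Cut one trajectory into (past, future) pairs of fixed length.

    past_len, future_len and stride are hypothetical parameters for illustration.
    """
    if past_len < 1 or future_len < 1 or stride < 1:
        raise ValueError("past_len, future_len and stride must be positive")
    window_len = past_len + future_len
    windows: List[Window] = []
    # The last valid start index leaves exactly window_len frames.
    for start in range(0, len(trajectory) - window_len + 1, stride):
        chunk = list(trajectory[start:start + window_len])
        windows.append((chunk[:past_len], chunk[past_len:]))
    return windows


track = [(float(t), 0.0) for t in range(8)]  # 8 frames moving along the x axis
for past, future in sliding_windows(track, past_len=3, future_len=2, stride=2):
    print([p[0] for p in past], "->", [f[0] for f in future])
```

A trajectory shorter than past_len + future_len yields no windows. In a multi-agent scene, the same frame range would be sliced for every agent so that past and future stay time-aligned across agents.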

Training

To train a GroupNet model on the NBA dataset, simply run:

python train_hyper_nba.py --gpu {your_gpu_id}

Trained models will be saved in 'saved_models/nba'.

Evaluating

To evaluate model performance, run:

python test_nba.py --gpu {your_gpu_id} --model_names {your_model_name}

We provide a pretrained model in 'saved_models/nba/pretrain.p', which performs slightly better than the results reported in our paper.

To evaluate the pretrained model, run:

python test_nba.py --gpu {your_gpu_id} --model_names pretrain
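The metrics printed by test_nba.py are not listed in this README. Multi-modal trajectory predictors are commonly evaluated with best-of-K displacement errors, so the sketch below shows minADE and minFDE over K sampled futures. It is an assumption about what such an evaluation computes, with hypothetical function names, not a description of the script.

```python
import math
from typing import Sequence, Tuple

Point = Tuple[float, float]


def _check(pred: Sequence[Point], gt: Sequence[Point]) -> None:
    if len(pred) != len(gt) or not gt:
        raise ValueError("pred and gt must be non-empty and equally long")


def ade(pred: Sequence[Point], gt: Sequence[Point]) -> float:
    """Average Euclidean displacement over all future time steps."""
    _check(pred, gt)
    return sum(math.hypot(p[0] - g[0], p[1] - g[1]) for p, g in zip(pred, gt)) / len(gt)


def fde(pred: Sequence[Point], gt: Sequence[Point]) -> float:
    """Euclidean displacement at the final time step."""
    _check(pred, gt)
    return math.hypot(pred[-1][0] - gt[-1][0], pred[-1][1] - gt[-1][1])


def min_ade_fde(samples: Sequence[Sequence[Point]], gt: Sequence[Point]) -> Tuple[float, float]:
    """Best-of-K errors, each metric minimised independently over the K samples.

    Other conventions exist (e.g. choosing one best sample for both metrics).
    """
    if not samples:
        raise ValueError("need at least one sampled future")
    return min(ade(s, gt) for s in samples), min(fde(s, gt) for s in samples)


gt = [(0.0, 0.0), (1.0, 0.0), (2.0, 0.0)]
samples = [
    [(0.0, 1.0), (1.0, 1.0), (2.0, 1.0)],  # offset by 1 everywhere: ADE 1.0, FDE 1.0
    [(0.0, 2.0), (1.0, 0.0), (2.0, 0.0)],  # wrong start, exact end: ADE 2/3, FDE 0.0
]
print(min_ade_fde(samples, gt))
```

Minimising over K samples rewards a model that covers the true future with at least one of its predictions, which suits stochastic (e.g. CVAE-based) predictors.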

Acknowledgement

We thank 'NBA-Player-Movements' for providing the NBA data and processing code. Part of our code builds on NRI and NMMP (GitHub repositories: NRI code and NMMP code); we thank their authors for releasing their code.

Citation

If you use this code, please cite our paper:

@InProceedings{xu2022GroupNet,
author = {Xu, Chenxin and Li, Maosen and Ni, Zhenyang and Zhang, Ya and Chen, Siheng},
title = {GroupNet: Multiscale Hypergraph Neural Networks for Trajectory Prediction with Relational Reasoning},
booktitle = {The IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
year = {2022}
}

Contributors

sjtuxcx
