
GroupNet: Multiscale Hypergraph Neural Networks for Trajectory Prediction with Relational Reasoning

Official PyTorch code for CVPR'22 paper "GroupNet: Multiscale Hypergraph Neural Networks for Trajectory Prediction with Relational Reasoning".

Abstract: Demystifying the interactions among multiple agents from their past trajectories is fundamental to precise and interpretable trajectory prediction. However, previous works only consider pair-wise interactions with limited relational reasoning. To promote more comprehensive interaction modeling for relational reasoning, we propose GroupNet, a multiscale hypergraph neural network, which is novel in terms of both interaction capturing and representation learning. From the aspect of interaction capturing, we propose a trainable multiscale hypergraph to capture both pair-wise and group-wise interactions at multiple group sizes. From the aspect of interaction representation learning, we propose a three-element format that can be learned end-to-end and explicitly reasons about relational factors, including interaction strength and category. We apply GroupNet to both a CVAE-based prediction system and previous state-of-the-art prediction systems for predicting socially plausible trajectories with relational reasoning. To validate the ability of relational reasoning, we experiment with synthetic physics simulations that test the ability to capture group behaviors and to reason about interaction strength and interaction category. To validate the effectiveness of prediction, we conduct extensive experiments on three real-world trajectory prediction datasets, NBA, SDD and ETH-UCY, and show that with GroupNet, the CVAE-based prediction system outperforms state-of-the-art methods. We also show that adding GroupNet further improves the performance of previous state-of-the-art prediction systems.
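To picture what a multiscale hypergraph over agents looks like, here is a small sketch. It is not the GroupNet implementation: it assumes, for illustration only, that at each scale every agent seeds one hyperedge made of itself and its k - 1 highest-affinity neighbors, with affinity taken as the inner product of per-agent feature vectors. The function `multiscale_hyperedges` and its `group_sizes` parameter are hypothetical names.

```python
import numpy as np


def multiscale_hyperedges(features, group_sizes):
    """Return {k: incidence matrix H_k} for each group size k (hypothetical helper).

    features: (N, D) array, one embedding per agent.
    H_k has shape (N, N); H_k[i, e] == 1 iff agent i belongs to hyperedge e,
    and hyperedge e is seeded by agent e (so each column has exactly k ones).
    Assumption: a hyperedge = seed agent + its k - 1 highest-affinity agents.
    """
    feats = np.asarray(features, dtype=float)
    n = feats.shape[0]
    affinity = feats @ feats.T          # (N, N) inner-product affinity
    np.fill_diagonal(affinity, np.inf)  # a seed always ranks itself first
    order = np.argsort(-affinity, axis=1)  # agents by descending affinity, per seed
    incidences = {}
    for k in group_sizes:
        if not 1 < k <= n:
            raise ValueError(f"group size {k} must satisfy 1 < k <= {n}")
        h = np.zeros((n, n), dtype=np.int8)
        for seed in range(n):
            h[order[seed, :k], seed] = 1  # seed plus its k - 1 top neighbours
        incidences[k] = h
    return incidences
```

Varying `group_sizes` is what makes this construction multiscale: k = 2 gives pair-wise groups, while larger k yields group-wise hyperedges, up to k = N covering every agent.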

An example of trajectories predicted by our model, together with the corresponding ground truth on the NBA dataset, is shown below:

News (2022.6.28)

Our extended version, "DynGroupNet", is available on arXiv (Link). Code will be released afterwards.

Update

We noticed a mistake in how the losses were scaled and have fixed it. We have also updated the pretrained model, whose performance is further improved.

Requirement

Recommended Environment

  • Tested on: Linux with an RTX 3090 GPU
  • Python == 3.7.11
  • PyTorch == 1.8.1+cu111

Dependencies

Install the dependencies from the requirements.txt:

pip install -r requirements.txt

Data preparation

You can directly use the preprocessed NBA SportVU data in datasets/nba, where train.npy and test.npy serve as the training and testing data. If you want to sample trajectories yourself, download the data files from 'NBA-Player-Movements', put them in datasets/nba/source, and then run:

python generate_dataset.py
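The README does not spell out the layout of train.npy and test.npy, so the following is only a sketch for inspecting them. It assumes each file holds one numeric array whose axis `time_axis` indexes timesteps, and splits that array into an observed past and a future to predict. The helper `split_past_future` and any `past_len` value you pass are hypothetical; check `data.shape` against generate_dataset.py before relying on either.

```python
import numpy as np


def split_past_future(data, past_len, time_axis=1):
    """Split a trajectory array into observed past and future along time_axis.

    Hypothetical helper. Assumption: the first `past_len` steps along
    `time_axis` are history and the remaining steps are the prediction target.
    """
    data = np.asarray(data)
    num_steps = data.shape[time_axis]
    if not 0 < past_len < num_steps:
        raise ValueError(f"past_len must satisfy 0 < past_len < {num_steps}")
    past, future = np.split(data, [past_len], axis=time_axis)
    return past, future
```

For example, `split_past_future(np.load("datasets/nba/train.npy"), past_len=5)`, where both the time axis and `past_len=5` are placeholders to confirm against the data.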

Training

To train a GroupNet model on the NBA dataset, simply run:

python train_hyper_nba.py --gpu {your_gpu_id}

Trained models are saved in 'saved_models/nba'.

Evaluating

To evaluate model performance, run:

python test_nba.py --gpu {your_gpu_id} --model_names {your_model_name}

We provide a pretrained model at 'saved_models/nba/pretrain.p' that outperforms the results reported in our paper.

To evaluate the pretrained model, run:

python test_nba.py --gpu {your_gpu_id} --model_names pretrain
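NBA results for this task are reported as ADE and FDE. The sketch below computes their best-of-K variants (minADE/minFDE), a common convention for stochastic trajectory predictors. It is not taken from test_nba.py: the function name `min_ade_fde` is hypothetical, and whether the script picks the best sample per agent (as here) or per scene is an assumption worth checking.

```python
import numpy as np


def min_ade_fde(preds, gt):
    """Best-of-K average and final displacement errors (hypothetical helper).

    preds: (K, N, T, 2) array of K sampled futures for N agents.
    gt:    (N, T, 2) ground-truth futures.
    Returns (minADE, minFDE), each averaged over agents, with the best of the
    K samples chosen independently for every agent.
    """
    dist = np.linalg.norm(preds - gt[None], axis=-1)  # (K, N, T) per-step L2 error
    ade = dist.mean(axis=-1)  # (K, N) error averaged over timesteps
    fde = dist[..., -1]       # (K, N) error at the final timestep
    return ade.min(axis=0).mean(), fde.min(axis=0).mean()
```

Selecting the best sample per agent gives lower numbers than selecting one joint sample per scene, so the two conventions should not be compared directly.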

Acknowledgement

We thank 'NBA-Player-Movements' for providing the NBA data and processing code. Part of our code is adapted from NRI and NMMP (GitHub repositories: NRI code and NMMP code); we thank their authors for releasing it.

Citation

If you use this code, please cite our paper:

@InProceedings{xu2022GroupNet,
author = {Xu, Chenxin and Li, Maosen and Ni, Zhenyang and Zhang, Ya and Chen, Siheng},
title = {GroupNet: Multiscale Hypergraph Neural Networks for Trajectory Prediction with Relational Reasoning},
booktitle = {The IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
year = {2022}
}

groupnet's Issues

Wrong loss calculation

Dear authors,

There seems to be an error in the loss calculation in your code. You averaged the losses (pred, recover, diverse) over both the batch size and pred.shape[0], but pred.shape[0] already includes the batch size. However, the sequence dimension was not averaged. Also, the KL loss was multiplied by agent_num. The performance can be further improved after correcting this.
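To make the scaling problem concrete, here is a small reconstruction from the issue's description, not the repository's code. It assumes `pred` and `target` have shape (batch_size * num_agents, T, 2), so `pred.shape[0]` already counts the batch. `loss_buggy` divides the summed error by both `batch_size` and `pred.shape[0]` and never averages over T; `loss_fixed` averages over agents and timesteps. Both function names are hypothetical.

```python
import numpy as np


def loss_buggy(pred, target, batch_size):
    """Scaling pattern described in the issue (hypothetical reconstruction).

    pred.shape[0] already equals batch_size * num_agents, so dividing by both
    counts the batch twice, and the time dimension T is never averaged.
    """
    return np.sum((pred - target) ** 2) / batch_size / pred.shape[0]


def loss_fixed(pred, target):
    """Squared L2 error per point, averaged over agents and timesteps."""
    per_point = np.sum((pred - target) ** 2, axis=-1)  # (batch_size * num_agents, T)
    return per_point.mean()
```

With unit error on every coordinate, `loss_fixed` returns 2.0 regardless of batch size or horizon, whereas `loss_buggy` returns 2 * T / batch_size, so its scale shifts whenever either changes.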

picture in paper

Hello, authors. What you have done in the paper is really impressive, and it is very useful for my study. I am wondering how you made the network figure in your paper, if it is convenient for you to share. Thanks!

Question about Table 2

Hello, thank you for your great research!
I have questions about Table 2, where you evaluated ADE and FDE scores on the NBA dataset for several models (e.g. Social-LSTM, Social-GAN, Social-STGCNN, STGAT, ..., PECNet, NMMP).

  1. What data did you use for the Social-* models? Did you follow the Social-GAN dataset format?
  2. What code did you use to evaluate models?

Each model has a different input shape, input format, and output shape, so I'm confused about which data I should use.

Could you please help me with these questions?
Thank you,
JiyouSeo.

Code for other experiment results

Hi, thank you very much for your contribution to the trajectory prediction community. Your GroupNet is wonderful. Regarding the reusability of this code, I wonder when you plan to publish the code for the other two datasets, ETH/UCY and SDD. In other issues, you mentioned that you would release the relevant code after DynGroupNet was accepted. I've read your paper "Dynamic-group-aware networks for multi-agent trajectory prediction with relational reasoning", published in Neural Networks on 7 November 2023, but there is still no sign of the code. So I'm wondering when you plan to release the code for the ETH/UCY and SDD datasets, whether for GroupNet or DynGroupNet.

Handling variable number of pedestrians

Hello,
This is really exciting work and thanks for open-sourcing the code. I wanted to use the interaction module on other datasets (e.g. ETH/UCY, TrajNet++) which have a variable number of pedestrians in each sample. Example: Sample 1 can have five pedestrians, and Sample 2 can have four pedestrians. How can I use the MS_HGNN module in this case?

I believe the current implementation assumes that the number of pedestrians is the same for every sample across the entire dataset. This is true for NBA but not for the other datasets.

Thanks!
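One common workaround, not part of the released MS_HGNN module, is to zero-pad every scene to the largest agent count in the batch and carry a boolean mask. The sketch below assumes scenes are arrays of shape (num_agents, T, 2) with a shared T; `pad_scenes` and `masked_affinity` are hypothetical helpers. Setting affinities that involve padded agents to -inf keeps them out of any top-k neighbor selection.

```python
import numpy as np


def pad_scenes(scenes):
    """Stack scenes with different agent counts into one zero-padded batch.

    scenes: list of arrays shaped (num_agents_i, T, 2) sharing the same T.
    Returns batch (B, max_agents, T, 2) and mask (B, max_agents), where
    mask[b, i] is True iff agent i is a real agent in scene b.
    """
    tail = scenes[0].shape[1:]
    if any(s.shape[1:] != tail for s in scenes):
        raise ValueError("all scenes must share the same (T, 2) trailing shape")
    max_agents = max(s.shape[0] for s in scenes)
    batch = np.zeros((len(scenes), max_agents) + tail, dtype=float)
    mask = np.zeros((len(scenes), max_agents), dtype=bool)
    for b, scene in enumerate(scenes):
        batch[b, : scene.shape[0]] = scene
        mask[b, : scene.shape[0]] = True
    return batch, mask


def masked_affinity(features, mask):
    """Pairwise inner-product affinity with padded agents excluded.

    features: (B, N, D) per-agent embeddings; mask: (B, N) from pad_scenes.
    Any pair involving a padded agent gets -inf, so a top-k neighbour search
    over a real agent's row never picks a padded agent while real ones remain.
    """
    aff = np.einsum("bnd,bmd->bnm", features, features)
    valid = mask[:, :, None] & mask[:, None, :]
    return np.where(valid, aff, -np.inf)
```

Two details still need care under this scheme: a group size k must be capped at each scene's real agent count, and outputs for padded rows should be discarded using `mask`. Processing one scene per batch avoids padding altogether at the cost of throughput.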

NFL dataset

Hi, thanks for your excellent work. May I ask where you downloaded the NFL football dataset?
