memonet's Introduction

Remember Intentions: Retrospective-Memory-based Trajectory Prediction

Official PyTorch code for CVPR'22 paper "Remember Intentions: Retrospective-Memory-based Trajectory Prediction".

[Paper] [Zhihu]

(Figure: system design overview)

Abstract: To realize trajectory prediction, most previous methods adopt the parameter-based approach, which encodes all the seen past-future instance pairs into model parameters. However, in this way, the model parameters come from all seen instances, which means a huge number of irrelevant seen instances might also be involved in predicting the current situation, disturbing the performance. To provide a more explicit link between the current situation and the seen instances, we imitate the mechanism of retrospective memory in neuropsychology and propose MemoNet, an instance-based approach that predicts the movement intentions of agents by looking for similar scenarios in the training data. In MemoNet, we design a pair of memory banks to explicitly store representative instances from the training set, acting as the prefrontal cortex in the neural system, and a trainable memory addresser to adaptively search the memory bank for instances similar to the current situation, acting like the basal ganglia. During prediction, MemoNet recalls previous memory by using the memory addresser to index related instances in the memory bank. We further propose a two-step trajectory prediction system, where the first step is to leverage MemoNet to predict the destination and the second step is to fulfill the whole trajectory according to the predicted destinations. Experiments show that the proposed MemoNet improves the FDE by 20.3%/10.2%/28.3% over the previous best method on the SDD/ETH-UCY/NBA datasets. Experiments also show that MemoNet can trace predictions back to specific instances, promoting interpretability.
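The memory-addressing step described above can be pictured as a nearest-neighbour lookup: encode the observed past, score it against the stored past-encodings in the memory bank, and retrieve the future intentions (destinations) paired with the best matches. The following is only a minimal illustration of that idea, not the repository's actual implementation; the toy encodings, bank contents, and cosine-similarity measure are placeholders.

```python
import numpy as np

def address_memory(query, memory_keys, memory_values, k=3):
    """Retrieve the k stored future intentions whose past-encodings are
    most similar (by cosine similarity) to the query encoding."""
    q = query / np.linalg.norm(query)
    keys = memory_keys / np.linalg.norm(memory_keys, axis=1, keepdims=True)
    scores = keys @ q                       # cosine similarity to each key
    top_k = np.argsort(scores)[::-1][:k]    # indices of the k best matches
    return memory_values[top_k]

# Toy memory bank: 4 past-encodings paired with 2-D destinations.
keys = np.array([[1.0, 0.0], [0.0, 1.0], [0.7, 0.7], [-1.0, 0.0]])
dests = np.array([[5.0, 0.0], [0.0, 5.0], [3.5, 3.5], [-5.0, 0.0]])
print(address_memory(np.array([0.9, 0.1]), keys, dests, k=1))  # → [[5. 0.]]
```

In the paper's two-step system, each retrieved destination would then seed the fulfillment encoder-decoder, which completes the full future trajectory toward it.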

We give an example of trajectories predicted by our model and the corresponding ground truth as follows:

(Figure: predicted trajectories vs. ground truth)

Below is an example of prediction interpretability, where the first column shows the current agent and the last three columns show the memory instances retrieved for that agent. (Figure: interpretability example)

[2022/09] Update: ETH's code & model are available!

You can find the code and the instructions in the ETH folder.

Installation

Environment

  • Tested OS: Linux / RTX 3090
  • Python == 3.7.9
  • PyTorch == 1.7.1+cu110

Dependencies

Install the dependencies from requirements.txt:

pip install -r requirements.txt

Pretrained Models

We provide a complete set of pre-trained models including:

  • intention encoder-decoder
  • learnable addresser
  • generated memory bank
  • fulfillment encoder-decoder

You can download the pretrained models/data from here.

File Structure

After the preparation work, the whole project should have the following structure:

./MemoNet
├── ReadMe.md
├── data                            # datasets
│   ├── test_all_4096_0_100.pickle
│   └── train_all_512_0_100.pickle
├── models                          # core models
│   ├── layer_utils.py
│   ├── model_AIO.py
│   └── ...
├── requirements.txt
├── run.sh
├── sddloader.py                    # sdd dataloader
├── test_MemoNet.py                 # testing code
├── train_MemoNet.py                # training code
├── trainer                         # core operations to train the model
│   ├── evaluations.py
│   ├── test_final_trajectory.py
│   └── trainer_AIO.py
└── training                        # saved models/memory banks
    ├── saved_memory
    │   ├── sdd_social_filter_fut.pt
    │   ├── sdd_social_filter_past.pt
    │   └── sdd_social_part_traj.pt
    ├── training_ae
    │   └── model_encdec
    ├── training_selector
    │   ├── model_selector
    │   └── model_selector_warm_up
    └── training_trajectory
        └── model_encdec_trajectory

Training

Important configurations.

  • --mode: specifies the current training stage (intention, addressor_warm, addressor, or trajectory),
  • --model_ae: path to the pretrained model,
  • --info: directory name under which to store the models,
  • --gpu: index of the device used to run the code,

Training commands.

bash run.sh
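For orientation, run.sh chains the training stages using the flags listed above. A hypothetical sequence might look like the following; the stage names follow the --mode options, but the checkpoint paths and --info names are illustrative, not necessarily the exact ones in the repository:

```shell
# Hypothetical staged training schedule (paths and names are illustrative).
python train_MemoNet.py --mode intention      --info stage_intention  --gpu 0
python train_MemoNet.py --mode addressor_warm --model_ae ./training/training_ae/model_encdec --info stage_addressor  --gpu 0
python train_MemoNet.py --mode addressor      --model_ae ./training/training_ae/model_encdec --info stage_addressor  --gpu 0
python train_MemoNet.py --mode trajectory     --model_ae ./training/training_ae/model_encdec --info stage_trajectory --gpu 0
```

Consult run.sh itself for the exact order and checkpoint paths used in the released code.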

Reproduce

To reproduce the reported results, run:

python test_MemoNet.py --reproduce True --info reproduce --gpu 0

The code will output:

./training/training_trajectory/model_encdec_trajectory
Test FDE_48s: 12.659514427185059 ------ Test ADE: 8.563031196594238
----------------------------------------------------------------------------------------------------

Acknowledgement

Thanks to Marchetz/MANTRA-CVPR20, the source code of the published work MANTRA (CVPR 2020), from which we borrow the framework and interface.

We also thank the authors of PECNet for the pre-processed data (paper, code).

Citation

If you use this code, please cite our paper:

@InProceedings{MemoNet_2022_CVPR,
  author    = {Xu, Chenxin and Mao, Weibo and Zhang, Wenjun and Chen, Siheng},
  title     = {Remember Intentions: Retrospective-Memory-based Trajectory Prediction},
  booktitle = {The IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
  year      = {2022}
}


memonet's Issues

Columns of ETH dataset

Hello! I have been trying to find information on the columns of the ETH dataset and what they represent to build out my own segmentation model to feed into MemoNet. Do you happen to know what the columns represent, or somewhere to find that information?

Thank you!
(i.e. 0.0 1.0 Pedestrian -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 1.41 -1.0 -5.68 -1.0 for biwi_hotel_train.txt)

Return predicted trajectory coordinates to the map

Hi, @WBMao @sjtuxcx
There is a function visualize_data() in map.py to get past and future trajectory coordinates back onto the map. Could you also explain how to return predicted trajectory coordinates to the map?
PS: this conversion would be the inverse of the mapping from past to past_normalized.

	past = torch.stack(data['pre_motion_3D']).cuda()
	future = torch.stack(data['fut_motion_3D']).cuda()
	last_frame = past[:, -1:]
	past_normalized = past - last_frame
	fut_normalized = future - last_frame
	

	past_abs = past.unsqueeze(0).repeat(past.size(0), 1, 1, 1)
	past_centroid = past[:, -1:, :].unsqueeze(1)
	past_abs = past_abs - past_centroid

	scale = 1
	if self.cfg.scale.use:
		scale = torch.mean(torch.norm(past_normalized[:, 0], dim=1)) / 3
		if scale<self.cfg.scale.threshold:
			scale = 1
		else:
			if self.cfg.scale.type == 'divide':
				scale = scale / self.cfg.scale.large
			elif self.cfg.scale.type == 'minus':
				scale = scale - self.cfg.scale.large
		if self.cfg.scale.type=='constant':
			scale = self.cfg.scale.value
		past_normalized = past_normalized / scale
		past_abs = past_abs / scale

	if self.cfg.rotation:
		past_normalized, fut_normalized, past_abs = self.rotate_traj(past_normalized, fut_normalized, past_abs)
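For what it's worth, inverting the preprocessing quoted above should amount to undoing the division by scale and the translation by the last observed frame. A minimal NumPy sketch, where pred_normalized stands in for the model's output in the normalized frame (this is an assumption from the quoted code, not the authors' answer):

```python
import numpy as np

def to_map_coordinates(pred_normalized, last_frame, scale=1.0):
    """Invert the normalization above: predictions live in a frame that was
    translated by the last observed position and divided by `scale`, so
    multiply the scale back in and re-add the offset."""
    return pred_normalized * scale + last_frame

# Toy example: last observed position (3, 4), scale 2.
pred = np.array([[1.0, 1.0], [2.0, 0.5]])
print(to_map_coordinates(pred, np.array([3.0, 4.0]), scale=2.0))
```

Note that if self.cfg.rotation is enabled, the rotation applied by rotate_traj would also need to be inverted before this step.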

Slow inference time

When running the model on my RTX6000 ADA, the inference time is near 200ms for a batch of 1. What would cause such a large deviation from the reported 55ms on an RTX3090?

SDD data with only Pedestrians?

Hi author, thanks for sharing your work!

For SDD experiment, since you use the pre-processed data from PECNet (the same author of Y-Net), can I assume that your training data only contains Pedestrians?

The following is from the CVPR 2022 work "End-to-End Trajectory Distribution Prediction Based on Occupancy Grid Maps":

Most of our tests are conducted on the Stanford Drone Dataset (SDD) [40] provides top-down RGB videos captured on the Stanford University campus by drones at 60 different scenes, containing annotated trajectories of more than 20,000 targets such as pedestrians, bicyclists and cars. Early works [5, 23, 43] consider all trajectories in SDD and subsequent works [27–29, 56] focus on pedestrian trajectories using the TrajNet benchmark [42]. On these two splits, we report the results of predicting the 12-step future with the 8-step history with 0.4 seconds step interval.

You can see how Y-Net filter out non-Pedestrians here:
https://github.com/HarshayuGirase/Human-Path-Prediction/blob/a345ada1557f53dc2acb4c07f71d340d366980da/ynet/utils/preprocessing.py#L43

Thanks for your clarification!

mode

What do the modes 'intention', 'addressor_warm', 'addressor', and 'trajectory' each refer to in the training process?

Generate Memory On Custom Dataset

It seems that there is no code for memory bank generation. I want to train MemoNet on custom trajectory dataset, is it possible for you to release a version including memory generation? It will be greatly appreciated!

Thanks for your great work!

Other datasets

Hi thanks for sharing your work and acknowledging our MANTRA repository! We are glad you found it useful and that you managed to improve upon it!
I noticed that you shared the dataloader for SDD, are you planning on sharing also the ones for ETH/UCY and NBA with any related code/model?

Thank you very much

Federico
