zhengchuanpan / gman

GMAN: A Graph Multi-Attention Network for Traffic Prediction (GMAN, https://fanxlxmu.github.io/publication/aaai2020/) was accepted by AAAI-2020.

License: Apache License 2.0

Python 100.00%
gman traffic-prediction aaai2020

gman's Introduction

GMAN: A Graph Multi-Attention Network for Traffic Prediction (AAAI-2020)

This is the implementation of Graph Multi-Attention Network in the following paper:
Chuanpan Zheng, Xiaoliang Fan*, Cheng Wang, and Jianzhong Qi. "GMAN: A Graph Multi-Attention Network for Traffic Prediction", Proceedings of the Thirty-Fourth AAAI Conference on Artificial Intelligence (AAAI-20), 2020, 34(01): 1234-1241.

Data

The datasets are available at Google Drive or Baidu Yun, provided by DCRNN, and should be put into the corresponding data/ folder.

Requirements

Python 3.7.10, tensorflow 1.14.0, numpy 1.16.4, pandas 0.24.2

Results

Third-party re-implementations

A PyTorch implementation by VincLee8188 is available at GMAN-Pytorch.

Citation

If you find this repository useful in your research, please cite the following paper:

@inproceedings{ GMAN-AAAI2020,
  author     = "Chuanpan Zheng and Xiaoliang Fan and Cheng Wang and Jianzhong Qi",
  title      = "GMAN: A Graph Multi-Attention Network for Traffic Prediction",
  booktitle  = "AAAI",
  pages      = "1234--1241",
  year       = "2020"
}

gman's People

Contributors

fanxlxmu, zhengchuanpan

gman's Issues

Help wanted: how to resolve AttributeError: 'numpy.bytes_' object has no attribute 'delta'

In utils.py, the line timeofday = (Time.hour * 3600 + Time.minute * 60 + Time.second) // Time.freq.delta.total_seconds() raises an error:

Traceback (most recent call last):
File "/Users/crowd/PycharmProjects/GMAN/METR/train.py", line 55, in
mean, std) = utils.loadData(args)
File "/Users/crowd/PycharmProjects/GMAN/METR/utils.py", line 73, in loadData
timeofday = (Time.hour * 3600 + Time.minute * 60 + Time.second) // Time.freq.delta.total_seconds()
AttributeError: 'numpy.bytes_' object has no attribute 'delta'

I have not modified the authors' source code. How did others solve this problem?
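One possible workaround, offered as a sketch rather than a confirmed fix: compute the time-of-day index from an explicitly supplied sampling interval instead of reading it from Time.freq. The helper name time_of_day and the 300-second default (a 5-minute sampling interval) are assumptions, not part of the repository. Rebuilding the index from its raw values is intended to drop any stale freq metadata stored in the file.

```python
import numpy as np
import pandas as pd


def time_of_day(index, interval_seconds=300):
    """Return a (num_steps, 1) array of time-of-day slot indices.

    interval_seconds is an assumed sampling interval (5 minutes);
    it is passed in explicitly instead of being read from index.freq.
    """
    # Rebuild the index from raw datetime64 values so no freq is carried over.
    index = pd.DatetimeIndex(np.asarray(index))
    seconds = np.asarray(index.hour * 3600 + index.minute * 60 + index.second)
    return (seconds // interval_seconds).reshape(-1, 1)


# Usage on a synthetic 5-minute index (illustrative data only).
idx = pd.date_range("2012-03-01 00:00", periods=4, freq="5min")
tod = time_of_day(idx)
print(tod.ravel().tolist())
```

If the dataset uses a different sampling interval, interval_seconds has to be changed to match it.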

Validation error nan

I've been trying to run the METR example, and from the first iteration I'm receiving a validation error of "nan". As a consequence, the model stops learning after 10 iterations. Is there a problem with the code?

ZeroDivisionError: float division by zero

When generating SE, the line normalized_probs = [float(u_prob)/norm_const for u_prob in unnormalized_probs] in preprocess_transition_probs() fails with: ZeroDivisionError: float division by zero
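A division by zero in this line means norm_const, the sum of the unnormalized probabilities, is zero. One plausible cause (an assumption, since the input graph is not shown) is a node whose outgoing edge weights are all zero. A minimal sketch of a guarded normalization follows; the function name normalize_probs and the uniform fallback are illustrative choices, not the repository's code.

```python
def normalize_probs(unnormalized_probs):
    """Normalize a list of non-negative weights into probabilities.

    Falls back to a uniform distribution when the weights sum to zero.
    The fallback is an assumption; dropping such nodes or fixing the
    adjacency file may be more appropriate for a given dataset.
    """
    norm_const = float(sum(unnormalized_probs))
    if norm_const == 0.0:
        n = len(unnormalized_probs)
        return [1.0 / n] * n if n else []
    return [float(p) / norm_const for p in unnormalized_probs]


print(normalize_probs([1, 3]))  # [0.25, 0.75]
print(normalize_probs([0, 0]))  # [0.5, 0.5]
```

Checking which node triggers the fallback can help locate the problem in the edge list.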

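Has anyone reproduced the results on PeMS?

I ran the code on TensorFlow 2 in compatibility mode and also set patience to 20. The average test MAE is 1.66, which falls short of the reported level.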

testing time: 36.1s
MAE RMSE MAPE
train 1.32 2.87 2.78%
val 1.59 3.72 3.61%
test 1.66 3.82 3.74%
performance in each prediction step
step: 01 0.99 1.88 1.96%
step: 02 1.21 2.47 2.50%
step: 03 1.38 2.97 2.93%
step: 04 1.52 3.35 3.30%
step: 05 1.62 3.65 3.61%
step: 06 1.71 3.90 3.86%
step: 07 1.78 4.09 4.08%
step: 08 1.85 4.25 4.26%
step: 09 1.90 4.38 4.42%
step: 10 1.95 4.49 4.55%
step: 11 1.99 4.59 4.67%
step: 12 2.03 4.67 4.78%
average: 1.66 3.72 3.74%
total time: 3.4min

Reproducing the results

Hello,
Thank you very much for sharing your code with the community.

After many attempts with different hyperparameters, we have not been able to reproduce any results from the paper (or even get close). Has anyone been able to reproduce the results, or do the authors have any pointers on how to achieve this?
Thank you.

Time Features - Ordinality

Doesn't the way time features were encoded introduce ordinality?

For example, if Sunday is encoded as 1 and Thursday is encoded as 5, doesn't that lead the model to treat Thursday as more important than Sunday?

Is this understanding correct? If yes, could you help to understand why that decision was taken during model design?
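An integer index only implies an ordering if it is fed to the model as a scalar value. A common way to avoid that is one-hot encoding, where each day becomes its own dimension and all days are equally far apart. The sketch below illustrates the idea with NumPy; the helper one_hot is hypothetical and is not taken from the repository.

```python
import numpy as np


def one_hot(indices, depth):
    """Encode integer category indices as one-hot rows of length depth."""
    indices = np.asarray(indices, dtype=int).ravel()
    out = np.zeros((indices.size, depth), dtype=np.float32)
    out[np.arange(indices.size), indices] = 1.0
    return out


# Two different days of the week, encoded with depth 7.
days = one_hot([0, 4], depth=7)
# Every pair of distinct days has the same distance, so no day outranks another.
print(np.linalg.norm(days[0] - days[1]))
```

With this representation the numeric value of the original index no longer carries magnitude information.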

The length of the PEMS data

The paper mentions "traffic speed prediction on the PeMS dataset (Li et al. 2018b), which contains 6 months of data recorded by 325 traffic sensors ranging from January 1st, 2017 to June 30th, 2017 in the Bay Area." However, the referenced paper says the data was collected from January 1st, 2017 to May 31st, 2017. Can you provide the 6-month data instead?

The data shape is different from DCRNN and Graph WaveNet.

The data shapes in the previous works are:
train shape X(36465, 12, 325, 2) Y(36465, 12, 325, 2)
val shape X(5209, 12, 325, 2) Y(5209, 12, 325, 2)
test shape X(10419, 12, 325, 2) Y(10419, 12, 325, 2)
Yours are:
trainX: (36458, 12, 325) trainY: (36458, 12, 325)
valX: (5189, 12, 325) valY: (5189, 12, 325)
testX: (10400, 12, 325) testY: (10400, 12, 325)
I'm confused about it.
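The sample counts above are consistent with two different orders of operations on the same series: building sliding windows over the whole series and then splitting the samples, versus splitting the time steps first and then building windows inside each split. The sketch below reproduces both sets of numbers under stated assumptions: a total of 52116 time steps (inferred from the shapes above, not read from the data), a 70/10/20 train/val/test ratio, and 12 input plus 12 output steps. The function names are hypothetical. The extra trailing dimension of size 2 in the first set of shapes is a separate difference (likely an additional input feature next to the speed value) and is not modeled here.

```python
def num_windows(num_steps, P=12, Q=12):
    """Number of (P input, Q output) sliding windows over num_steps steps."""
    return num_steps - P - Q + 1


def split_then_window(total_steps, train_ratio=0.7, test_ratio=0.2, P=12, Q=12):
    """Split the time steps first, then build windows inside each split."""
    train = round(train_ratio * total_steps)
    test = round(test_ratio * total_steps)
    val = total_steps - train - test
    return tuple(num_windows(n, P, Q) for n in (train, val, test))


def window_then_split(total_steps, train_ratio=0.7, test_ratio=0.2, P=12, Q=12):
    """Build windows over the whole series, then split the samples."""
    samples = num_windows(total_steps, P, Q)
    train = round(train_ratio * samples)
    test = round(test_ratio * samples)
    val = samples - train - test
    return train, val, test


TOTAL_STEPS = 52116  # assumption inferred from the reported shapes
print(split_then_window(TOTAL_STEPS))  # (36458, 5189, 10400)
print(window_then_split(TOTAL_STEPS))  # (36465, 5209, 10419)
```

Splitting first loses P + Q - 1 windows in each of the three splits, which accounts for the smaller counts.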

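Data

I am very interested in this work. Could you provide the complete data, including SE?

Why does Time = df.index in loadData() raise an error?

I downloaded the METR.h5 file from DCRNN and read it with pandas. When generating the time embedding, the line Time = df.index in the code fails as follows: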
ssh://[email protected]:22/home/tank/anaconda3/envs/lpb/bin/python3.6 -u /home/tank/lxl/GMAN/GMAN-master/METR/analyzeData.py
Traceback (most recent call last):
File "/home/tank/lxl/GMAN/GMAN-master/METR/analyzeData.py", line 37, in
print(df.index)
File "/home/tank/anaconda3/envs/lpb/lib/python3.6/site-packages/pandas/core/indexes/base.py", line 852, in __repr__
attrs = self._format_attrs()
File "/home/tank/anaconda3/envs/lpb/lib/python3.6/site-packages/pandas/core/indexes/datetimelike.py", line 381, in _format_attrs
freq = self.freqstr
File "/home/tank/anaconda3/envs/lpb/lib/python3.6/site-packages/pandas/core/indexes/extension.py", line 54, in fget
result = getattr(self.data, name)
File "/home/tank/anaconda3/envs/lpb/lib/python3.6/site-packages/pandas/core/arrays/datetimelike.py", line 1104, in freqstr
return self.freq.freqstr
AttributeError: 'numpy.bytes_' object has no attribute 'freqstr'

Process finished with exit code 1

Is my dataset wrong, or is my pandas version (1.1.4) the problem? Why can't the index be retrieved? Many thanks.

Help: NotImplementedError: reshaping is not supported for Index objects

Traceback (most recent call last):
File "D:/GitHub源代码/GMAN-master/GMAN-master/METR/train.py", line 55, in
mean, std) = utils.loadData(args)
File "D:\GitHub源代码\GMAN-master\GMAN-master\METR\utils.py", line 72, in loadData
dayofweek = np.reshape(Time.weekday, newshape = (-1, 1))
File "D:\Software\Anaconda3\envs\tensorflow\lib\site-packages\numpy\core\fromnumeric.py", line 232, in reshape
return _wrapfunc(a, 'reshape', newshape, order=order)
File "D:\Software\Anaconda3\envs\tensorflow\lib\site-packages\numpy\core\fromnumeric.py", line 57, in _wrapfunc
return getattr(obj, method)(*args, **kwds)
File "D:\Software\Anaconda3\envs\tensorflow\lib\site-packages\pandas\core\indexes\base.py", line 1149, in reshape
raise NotImplementedError("reshaping is not supported "
NotImplementedError: reshaping is not supported for Index objects
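The traceback shows np.reshape being dispatched to a pandas Index, which refuses to reshape. A likely workaround (an assumption, not a fix confirmed by the authors) is to convert the Index to a NumPy array before reshaping. The synthetic index below stands in for df.index.

```python
import numpy as np
import pandas as pd

# Synthetic stand-in for Time = df.index (illustrative data only).
Time = pd.date_range("2012-03-01", periods=3, freq="D")

# Convert the pandas Index to a plain ndarray first, then reshape.
dayofweek = np.reshape(np.asarray(Time.weekday), (-1, 1))
print(dayofweek.shape)
```

The same conversion applies to any other Index-valued expression that is passed to np.reshape.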

about model performance

Great work!
I have a question about the computation of the attention coefficients. Did you run experiments to compare the model performance with and without the STE block?

Masking in Loss function

I have seen masking applied in several places in the code, yet it isn't mentioned in the paper. In particular, masking is applied in mae_loss(). What is its purpose?
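For context, masking in a traffic-prediction loss is commonly used to exclude missing sensor readings from the error. The sketch below shows a generic masked MAE in NumPy, assuming missing readings are stored as 0. That convention, the function name masked_mae, and the null_value parameter are assumptions made for illustration; this is not the repository's mae_loss().

```python
import numpy as np


def masked_mae(pred, label, null_value=0.0):
    """Mean absolute error over entries whose label is not null_value.

    Assumes null_value (0 by default) marks a missing reading.
    """
    pred = np.asarray(pred, dtype=np.float64)
    label = np.asarray(label, dtype=np.float64)
    mask = (label != null_value).astype(np.float64)
    if mask.sum() == 0:
        return 0.0  # nothing valid to score
    # Rescale so the mean over all entries equals the mean over valid ones.
    mask /= mask.mean()
    return float(np.mean(np.abs(pred - label) * mask))


# The middle label is treated as missing and does not contribute.
print(masked_mae([1.0, 2.0, 3.0], [2.0, 0.0, 5.0]))  # 1.5
```

Without the mask, a missing reading stored as 0 would be scored as if the true speed were zero.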

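Group spatial attention

The paper says the vertices are split into groups, which in theory reduces the computational complexity. Where in the code that computes spatial attention is the grouped computation implemented?

Some questions about GMAN

At line 142 of model.py, shouldn't x and y have the same shape? Also, which TensorFlow version are you using? Thanks.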

Inconsistencies with the paper

Hello, first I would like to thank you for sharing the code. I was looking at the Spatial Attention component (line 56 in model.py) and noticed some differences from what is presented in the paper:

  1. I cannot find where the vertices are split into G partitions (with intra-group and inter-group attention). As far as I can tell, the spatialAttention function performs only intra-group spatial attention, without any restrictions.
  2. After eq. 7 is computed (line 86 in model.py), the output is projected again through 2 FC layers, which are not described in the paper. What is the reason for this?
  3. According to eq. 7, the input of function f3 is the previous hidden representation, whereas your code also uses the static graph embeddings (e_{v,tj}).

Looking forward to your reply.

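How is the dimension mismatch handled when evaluating the model without transformAttention?

To study the effect of each component, the authors removed the transformAttention part and evaluated the model. If I simply comment out that part of the code, the line X = tf.concat((X, STE), axis = -1) in the decoder's spatialAttention raises a dimension mismatch error. How did the authors handle the dimensions when evaluating the model with the transformAttention module removed?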

A question

Why is GMAN good at long-term forecasting, while its advantage is less obvious for short-term forecasting?
Is it because of the temporal attention? Looking forward to your answer, thanks.
