tianjianl / jamstgcn

A PaddlePaddle implementation of STGCN with a few modifications in the model architecture in order to forecast traffic jam.

License: GNU General Public License v3.0

Python 99.65% Shell 0.35%

jamstgcn's Introduction

About

This repository contains the code of a PaddlePaddle 2.2 implementation of STGCN based on the paper Spatio-Temporal Graph Convolutional Networks: A Deep Learning Framework for Traffic Forecasting, with a few modifications in the model architecture to tackle traffic jam forecasting problems.

Forecasting traffic jams differs from forecasting traffic flow, for which abundant data is available. Jams occur only in the largest cities and often only during peak hours, resulting in an unbalanced dataset. To study jam patterns, we select roads in Haidian District, Beijing that have many more jams than others.

Related Papers

Spatio-Temporal Graph Convolutional Networks: A Deep Learning Framework for Traffic Forecasting https://arxiv.org/abs/1709.04875

Semi-Supervised Classification with Graph Convolutional Networks https://arxiv.org/abs/1609.02907 (GCN)

Inductive Representation Learning on Large Graphs https://arxiv.org/abs/1706.02216 (GraphSAGE)

Graph Attention Networks https://arxiv.org/abs/1710.10903 (GAT)

Bag of Tricks for Node Classification with Graph Neural Networks https://arxiv.org/pdf/2103.13355.pdf (BoT)

Attention Based Spatial-Temporal Graph Convolutional Networks for Traffic Flow Forecasting https://ojs.aaai.org//index.php/AAAI/article/view/3881 (ASTGCN)

Structural Modifications

Graph operation

The original STGCN model uses ChebConv (Chebyshev polynomial approximation) and its first-order approximation (GCN) as the graph operation. In our model we conducted experiments on one spectral method (GCN) and two spatial methods (GAT, GraphSAGE).

Residual connection in graph convolution layer

Graph neural networks often suffer from oversmoothing: as the model becomes deeper, node representations tend to become similar to each other because they are repeatedly aggregated. Adding a residual connection mitigates oversmoothing by adding the unsmoothed input features directly to the output of the graph convolution operation. The connection also helps against gradient instabilities.
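As a rough illustration of the residual idea, the sketch below applies one GCN-style propagation step and then adds the input features back. It is written in plain Python rather than PaddlePaddle to stay self-contained, and it is not taken from the repository: the function names, the ReLU placement, and the requirement that the weight matrix be square (so input and output shapes match) are all assumptions.

```python
import math


def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def normalize_adjacency(adj):
    """Symmetric GCN normalization: D^-1/2 (A + I) D^-1/2."""
    n = len(adj)
    # Add self-loops so every node keeps part of its own features.
    a_hat = [[adj[i][j] + (1.0 if i == j else 0.0) for j in range(n)] for i in range(n)]
    deg = [sum(row) for row in a_hat]
    return [[a_hat[i][j] / math.sqrt(deg[i] * deg[j]) for j in range(n)] for i in range(n)]


def residual_graph_conv(x, adj_norm, weight):
    """Compute ReLU(A_norm @ X @ W) + X (hypothetical layer, square W assumed)."""
    h = matmul(matmul(adj_norm, x), weight)
    h = [[max(0.0, v) for v in row] for row in h]
    # Residual connection: add the unsmoothed input back to the smoothed output.
    return [[hv + xv for hv, xv in zip(h_row, x_row)] for h_row, x_row in zip(h, x)]


# Two connected road segments with two features each (toy values).
adj_norm = normalize_adjacency([[0, 1], [1, 0]])
x = [[1.0, 2.0], [3.0, 4.0]]
out = residual_graph_conv(x, adj_norm, [[1.0, 0.0], [0.0, 1.0]])
```

Without the residual term both rows of the output would be identical here, because the normalized adjacency averages the two nodes; adding `x` back keeps them distinguishable.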

Incorporation of historical jam pattern

Jam status often follows daily patterns. To let the model learn historical patterns, we directly feed it historical jam data aligned to the same hour. For example, to predict the traffic status at 8 PM on 30 Nov 2021, we feed the 8 PM traffic status of the past 12 days through a graph convolution layer, then concatenate the result with the output of the S-T convolution blocks to form the input of the final classification layer.
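The same-hour alignment can be sketched as a simple index computation. This is only an illustration: the function name `same_hour_history` is hypothetical, and the hourly sampling rate (`steps_per_day=24`) is an assumption, since the sampling interval of the dataset is not stated here.

```python
def same_hour_history(series, target_idx, steps_per_day=24, num_days=12):
    """Return the frames observed at the same time of day on previous days.

    series: list of per-time-step frames (e.g. one jam status per road).
    target_idx: index of the time step to be predicted.
    steps_per_day: ASSUMPTION (hourly data); adjust to the real sampling rate.
    """
    # Oldest day first, most recent day last.
    indices = [target_idx - d * steps_per_day for d in range(num_days, 0, -1)]
    if indices[0] < 0:
        raise ValueError("not enough history before target_idx")
    return [series[i] for i in indices]


# Toy data: 13 days of hourly frames for 2 roads; each frame stores its hour of day.
series = [[t % 24, t % 24] for t in range(24 * 13)]
target_idx = 24 * 12 + 20  # 8 PM on the 13th day
history = same_hour_history(series, target_idx)
```

Here `history` holds 12 frames, all taken at hour 20 of the preceding days, which is the input that would go through the extra graph convolution layer.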

Multi-Step Prediction

The model is implemented to predict the jam status of several future time steps. First, we feed the input data into the model to generate a prediction for the first future time step. Then we concatenate the predicted status with the original input and feed it to the model to generate the prediction for the next time step, and so on.
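A minimal sketch of this autoregressive loop is shown below. The `rollout` function and the `toy_model` stand-in are hypothetical, and dropping the oldest frame to keep a fixed window length is an assumption; the text above only says that the prediction is concatenated with the original input.

```python
def rollout(model, window, num_steps):
    """Predict num_steps future frames by feeding predictions back as input.

    model: callable mapping a list of past frames to the next frame.
    window: list of past frames, oldest first.
    """
    window = list(window)  # copy so the caller's list is not modified
    predictions = []
    for _ in range(num_steps):
        next_frame = model(window)
        predictions.append(next_frame)
        # ASSUMPTION: slide the window, dropping the oldest frame.
        window = window[1:] + [next_frame]
    return predictions


def toy_model(window):
    """Stand-in for the real network: each road gets one level worse, capped at 4."""
    return [min(status + 1, 4) for status in window[-1]]


# Two past frames for two roads, using the 1-4 jam classes.
preds = rollout(toy_model, [[1, 1], [1, 2]], num_steps=3)
```

Because each step consumes earlier predictions, errors can accumulate over the horizon, which is worth keeping in mind when reading multi-step accuracy figures.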

Classification

The original STGCN model was a regression model optimizing a mean squared error loss. Our traffic jam status has four classes: 1 (smooth traffic), 2 (temperate jam), 3 (moderate jam), and 4 (heavy jam), so we changed it into a classification model using softmax with cross-entropy loss. Because traffic is smooth in most cases, label 1 dominates the others. We therefore use a weighted cross-entropy loss to penalize incorrect classifications of classes 2, 3, and 4 more severely.
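To make the weighting concrete, here is a small plain-Python sketch of a weighted cross-entropy loss. The class weights `[1.0, 3.0, 3.0, 3.0]` are made up for illustration and are not the values used in this repository; labels are shifted to 0-based indices (class 1 becomes index 0), and normalizing by the total weight is one common convention, assumed here.

```python
import math


def weighted_cross_entropy(logits, labels, class_weights):
    """Weighted mean of softmax cross-entropy over a batch.

    logits: list of per-sample score lists, one score per class.
    labels: 0-based class indices (jam class 1 -> index 0, ..., 4 -> index 3).
    class_weights: per-class weights; larger values penalize errors more.
    """
    total, weight_sum = 0.0, 0.0
    for row, label in zip(logits, labels):
        # Numerically stable log-sum-exp for the softmax denominator.
        m = max(row)
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        nll = log_sum_exp - row[label]
        w = class_weights[label]
        total += w * nll
        weight_sum += w
    return total / weight_sum


# Both samples are scored as "smooth", but the second one is actually a heavy jam.
logits = [[5.0, 0.0, 0.0, 0.0], [5.0, 0.0, 0.0, 0.0]]
labels = [0, 3]
unweighted = weighted_cross_entropy(logits, labels, [1.0, 1.0, 1.0, 1.0])
weighted = weighted_cross_entropy(logits, labels, [1.0, 3.0, 3.0, 3.0])
```

With the hypothetical weights, the missed heavy jam contributes three times as much to the loss as it would unweighted, so `weighted` comes out larger than `unweighted`.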

Requirements

You can use pip to install the requirements:

sh requirement.sh

Experiments

Right now I am still updating the experiments, adding new blocks, and trying out new ideas.
Here's an example of what a training epoch currently looks like; the numbers are the model's cross-entropy loss. The current model achieves over 80% accuracy when predicting traffic jams for the next 4 hours.

[Image: training output]

Future Directions

I think self-distillation methods have potential for improving long-term prediction: the teacher model first learns historical patterns to make a raw prediction of whether a road segment is jammed, and the student model then uses graph-structured data about nearby traffic conditions to make a more educated guess of the future jam status.

Pretraining techniques also have potential for generating spatial-temporal representations of road segments.
