khaled-alkilane / gsta

GSTA: Gated Spatial-Temporal Attention Approach for Travel Time Prediction

License: MIT License

Jupyter Notebook 100.00%
travel-time-prediction deep-learning attention-mechanism gated-attention gated-neural-network feature-selection


GSTA

[Figure: GSTA architecture]

Spatial-temporal Attention

[Figure: spatial-temporal attention]

Data

  • A sample of 81K trips is provided for each of the NYC and Chengdu taxi datasets, in the folders "NYC Data" and "Chengdu Data".
  • The data samples are already pre-processed (data cleaning, feature engineering, etc.) and randomly split into training (X_train, Y_train), validation (X_val, Y_val), and test (X_test, Y_test) sets.
  • The prediction model for each dataset is implemented in a separate Jupyter notebook (GSTA on NYC.ipynb and GSTA on Chengdu.ipynb).
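A minimal sketch for reading one split outside the notebooks, using only the standard library. It assumes each split is stored as `X_<split>.csv` and `Y_<split>.csv` inside the dataset folder, with the target as the last column of the Y file; these file names and the column layout are assumptions, so check the actual folder contents (the notebooks themselves may load the data with pandas).

```python
import csv
from pathlib import Path


def _parse(cell):
    # Numeric cells become floats; anything else (e.g. a raw geohash string,
    # if present) is kept as text.
    try:
        return float(cell)
    except ValueError:
        return cell


def load_split(folder, split):
    """Load one split ("train", "val" or "test") from a dataset folder.

    Hypothetical layout: <folder>/X_<split>.csv holds the feature columns and
    <folder>/Y_<split>.csv holds the target (trip_duration) as its last column.
    """
    folder = Path(folder)
    with open(folder / f"X_{split}.csv", newline="") as f:
        X = [{name: _parse(cell) for name, cell in row.items()} for row in csv.DictReader(f)]
    with open(folder / f"Y_{split}.csv", newline="") as f:
        reader = csv.reader(f)
        next(reader)  # skip the header row
        y = [float(row[-1]) for row in reader if row]  # ignore blank lines
    if len(X) != len(y):
        raise ValueError(f"{split}: {len(X)} feature rows but {len(y)} targets")
    return X, y
```

For example, `X_train, y_train = load_split("NYC Data", "train")` would return one dict per trip plus the list of target values, if the assumed layout matches.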

Each data sample is a CSV file. The key features are:

  • Location Features: 'pickup_longitude', 'pickup_latitude', 'dropoff_longitude', 'dropoff_latitude', 'center_latitude', 'center_longitude', 'dropoff_pca0', 'dropoff_pca1', 'pickup_pca0', 'pickup_pca1'
  • Cluster Features: 'pickup_cluster', 'dropoff_cluster', 'pickup_counts_on_clusterid', 'dropoff_counts_on_clusterid'
  • Geo-hash Features: 'pickup_geohash', 'dropoff_geohash'
  • Date/Time Features: 'DayofMonth_sin', 'DayofMonth_cos', 'Hour_sin', 'Hour_cos', 'dayofweek_sin', 'dayofweek_cos', 'Weekend_day', 'Work_day'
  • Weather Features: 'tempm', 'dewptm', 'hum', 'rain', 'snow', 'wdird', 'vism', 'fog', 'thunder', 'tornado', 'conds_Clear', 'conds_Haze', 'conds_Heavy Rain', 'conds_Heavy Snow', 'conds_Light Rain', 'conds_Light Snow'
  • Distance: 'distance_haversine', 'distance_dummy_manhattan'
  • Direction: 'direction'
  • Speed: 'avg_speed_KMperHour'
  • Public Holiday: 'Public_Holiday'
  • Peak period: 'Peak_Hour'
  • Travel Time: 'trip_duration'
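Several of these feature names point to common encodings: sin/cos pairs for periodic time fields, and great-circle distances and a heading between pickup and dropoff. The sketch below shows one standard way to compute them. How the repository actually built these columns is not stated here, so the formulas are assumptions: in particular, the "dummy Manhattan" distance is taken to be an east-west haversine leg plus a north-south haversine leg, and `direction` is taken to be the initial compass bearing.

```python
import math

EARTH_RADIUS_KM = 6371.0  # mean Earth radius


def cyclic_encode(value, period):
    """Map a periodic value (hour 0-23, weekday 0-6, ...) onto the unit circle,
    so that e.g. hour 23 and hour 0 end up close together."""
    angle = 2 * math.pi * value / period
    return math.sin(angle), math.cos(angle)


def haversine_km(lat1, lon1, lat2, lon2):
    """Great-circle distance in kilometres between two points given in degrees."""
    phi1, phi2 = math.radians(lat1), math.radians(lat2)
    dphi = phi2 - phi1
    dlmb = math.radians(lon2 - lon1)
    a = math.sin(dphi / 2) ** 2 + math.cos(phi1) * math.cos(phi2) * math.sin(dlmb / 2) ** 2
    return 2 * EARTH_RADIUS_KM * math.asin(math.sqrt(a))


def dummy_manhattan_km(lat1, lon1, lat2, lon2):
    """Assumed definition: east-west leg plus north-south leg, each measured
    with the haversine formula (an L1-style distance on the sphere)."""
    return haversine_km(lat1, lon1, lat1, lon2) + haversine_km(lat1, lon1, lat2, lon1)


def bearing_deg(lat1, lon1, lat2, lon2):
    """Initial compass bearing from pickup to dropoff, in degrees in (-180, 180]
    (0 = north, 90 = east)."""
    phi1, phi2 = math.radians(lat1), math.radians(lat2)
    dlmb = math.radians(lon2 - lon1)
    x = math.sin(dlmb) * math.cos(phi2)
    y = math.cos(phi1) * math.sin(phi2) - math.sin(phi1) * math.cos(phi2) * math.cos(dlmb)
    return math.degrees(math.atan2(x, y))
```

Under these assumptions, `cyclic_encode(hour, 24)` would produce a (`Hour_sin`, `Hour_cos`) pair and `cyclic_encode(weekday, 7)` a (`dayofweek_sin`, `dayofweek_cos`) pair. The sin/cos pair is used instead of the raw integer so that the model sees 23:00 and 00:00 as neighbours.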

Parameters:

(These parameters were tuned on the whole dataset; you can change them manually.)
The default parameters are:

  • optimizer=Adam(lr=0.001)
  • loss='mean_absolute_error'
  • metrics=['mae','mape']
  • Dropout=0.2
  • epochs = 50
  • batch_size = 256
  • kernel_regularizer=l2(0.001)
  • Activation('elu')
  • BatchNormalization(epsilon=1e-06, momentum=0.98)
  • kernel_initializer="he_uniform"
  • num_heads = 4 (the number of heads in the multi-head attention)
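The loss and metrics above are Keras' built-in mean absolute error and mean absolute percentage error. For reference, here is a plain-Python sketch of what they compute over a batch. It mirrors the documented Keras formulas rather than code from this repository: Keras reports MAPE as a percentage and clips the denominator to a small epsilon (1e-7 by default) to avoid division by zero.

```python
def mae(y_true, y_pred):
    """Mean absolute error, in the same units as the target (trip_duration)."""
    if len(y_true) != len(y_pred) or not y_true:
        raise ValueError("inputs must be non-empty and of equal length")
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


def mape(y_true, y_pred, epsilon=1e-7):
    """Mean absolute percentage error in percent; the denominator is clipped
    to `epsilon`, as Keras does, so a zero target does not divide by zero."""
    if len(y_true) != len(y_pred) or not y_true:
        raise ValueError("inputs must be non-empty and of equal length")
    ratios = (abs(t - p) / max(abs(t), epsilon) for t, p in zip(y_true, y_pred))
    return 100.0 * sum(ratios) / len(y_true)
```

Because MAE is in the target's own units, it reads directly as an average error per trip. MAPE divides each error by the true value, so the same absolute error counts more heavily on short trips than on long ones.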

Results

[Figures: predictions on NYC, predictions on Chengdu, and NYC predictions under abnormal weather]

Best Model

The best model for each dataset during training is saved to the "Models" folder as an HDF5 file.

Dependencies:

Keras 2.4.3, TensorFlow 2.3.0, Bokeh 2.2.1, NumPy 1.19.3, Pandas 1.1.5, scikit-learn.

BibTeX Citation

If you use our paper in a scientific publication, we would appreciate it if you cite it as follows:

@article{Khaled2021,
author = {Khaled, Alkilane and Elsir, Alfateh M Tag and Shen, Yanming},
doi = {10.1007/s00521-021-06560-z},
issn = {1433-3058},
journal = {Neural Computing and Applications},
title = {{GSTA: gated spatial–temporal attention approach for travel time prediction}},
url = {https://doi.org/10.1007/s00521-021-06560-z},
year = {2021}
}
