forked from Echo-Ji/ST-SSL

ST-SSL (STSSL): Spatio-Temporal Self-Supervised Learning for Traffic Flow Forecasting/Prediction


whysirier/ST-SSL

 
 


ST-SSL: Spatio-Temporal Self-Supervised Learning for Traffic Prediction

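This is a PyTorch implementation of ST-SSL from the following paper: Jiahao Ji, Jingyuan Wang, Chao Huang, Junjie Wu, Boren Xu, Zhenhe Wu, Junbo Zhang, and Yu Zheng. Spatio-Temporal Self-Supervised Learning for Traffic Flow Prediction. Proceedings of the AAAI Conference on Artificial Intelligence, 37(4), 4356-4364, 2023.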

[Figure: ST-SSL framework overview]

New (22/04/2023): A post about this paper was selected as a headline tweet by PaperWeekly and received nearly 7,000 reads. PaperWeekly is a leading AI academic platform in China.

New (09/02/2023): A video replay of the academic presentation at AAAI 2023 is available.

New (04/02/2023): J. Ji was invited to give a talk at the AAAI 2023 Beijing Pre-Conference on Spatio-Temporal Self-Supervised Learning for Traffic Flow Prediction.

Requirements

This project was built with Python 3.8 and the following packages:

numpy==1.21.2
pandas==1.3.5
PyYAML==6.0
torch==1.10.1
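To confirm that an existing environment matches these pins, a small check along the following lines may help. It is a sketch, not part of the repo: the helper name `check_pins` is hypothetical, and it assumes that a local build suffix such as `+cu113` on a torch version should be ignored when comparing.

```python
from importlib.metadata import PackageNotFoundError, version

# Pinned versions from the list above (pip distribution names).
PINS = {
    "numpy": "1.21.2",
    "pandas": "1.3.5",
    "PyYAML": "6.0",
    "torch": "1.10.1",
}


def check_pins(pins, get_version=version):
    """Return {package: (expected, installed_or_None)} for every mismatch.

    Hypothetical helper, not part of ST-SSL. A local build suffix such as
    "+cu113" is stripped before comparing against the pinned version.
    """
    mismatches = {}
    for name, expected in pins.items():
        try:
            installed = get_version(name)
        except PackageNotFoundError:
            installed = None  # package is not installed at all
        base = installed.split("+", 1)[0] if installed else None
        if base != expected:
            mismatches[name] = (expected, installed)
    return mismatches


if __name__ == "__main__":
    problems = check_pins(PINS)
    for name, (expected, installed) in problems.items():
        print(f"{name}: expected {expected}, found {installed}")
    if not problems:
        print("All pinned packages match.")
```

The version getter is injectable so the comparison logic can be exercised without touching the real environment.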

Datasets

The available datasets are NYCBike1, NYCBike2, NYCTaxi, and BJTaxi. You can download them from the GitHub repo, Beihang Cloud Drive, or Google Drive.

Each dataset is composed of 4 files, namely train.npz, val.npz, test.npz, and adj_mx.npz.

|----NYCBike1
|    |----train.npz
|    |----adj_mx.npz
|    |----test.npz
|    |----val.npz
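To inspect a downloaded dataset folder, the four files can be opened with `numpy.load`. The sketch below is a hypothetical helper (`load_dataset_dir` is not part of the repo); it lists whatever arrays each archive holds instead of assuming specific key names, and it assumes the folder layout shown above.

```python
import os

import numpy as np

SPLIT_FILES = ("train.npz", "val.npz", "test.npz", "adj_mx.npz")


def load_dataset_dir(root):
    """Load every .npz file of one dataset folder (e.g. the NYCBike1 folder).

    Returns {file_name: {array_name: np.ndarray}}. Hypothetical helper.
    """
    out = {}
    for fname in SPLIT_FILES:
        path = os.path.join(root, fname)
        # np.load on an .npz returns a lazy NpzFile; reading each key copies
        # the array into memory, so the archive can be closed afterwards.
        with np.load(path) as npz:
            out[fname] = {key: npz[key] for key in npz.files}
    return out
```

For example, `for key, arr in load_dataset_dir(path)["train.npz"].items(): print(key, arr.shape)` prints the name and shape of each array in the training split.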

Each of train.npz, val.npz, and test.npz contains 4 numpy.ndarray objects:

  • x: a 4D tensor of shape (#timeslots, #lookback_window, #nodes, #flow_types)

  • y: a 4D tensor of shape (#timeslots, #predict_horizon, #nodes, #flow_types). x and y are processed as a sliding window view.

  • x_offset: a tensor indicating the offsets of x's lookback window. Note that the lookback window of x is not contiguous in time.

  • y_offset: a tensor indicating offsets of y's predict horizon.
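How a sliding-window view with offsets can be built is illustrated by the minimal sketch below. It is not the repo's preprocessing code. It assumes a raw series of shape (T, #nodes, #flow_types), lookback offsets that are non-positive integers counted in time steps relative to an anchor step t, and horizon offsets that are positive integers; the actual format of x_offset and y_offset in the files may differ. The function name `make_xy` is hypothetical.

```python
import numpy as np


def make_xy(series, x_offsets, y_offsets):
    """Build (x, y) sliding windows from a raw series.

    series:    array of shape (T, num_nodes, num_flow_types)
    x_offsets: 1D ints <= 0 (lookback); may be non-contiguous
    y_offsets: 1D ints > 0 (prediction horizon)
    Returns x with shape (S, len(x_offsets), N, C) and
            y with shape (S, len(y_offsets), N, C).
    """
    x_offsets = np.asarray(x_offsets)
    y_offsets = np.asarray(y_offsets)
    num_steps = series.shape[0]
    # Keep only anchors t whose full lookback and horizon lie inside the series.
    t_min = -int(x_offsets.min())
    t_max = num_steps - int(y_offsets.max())  # exclusive upper bound
    anchors = np.arange(t_min, t_max)
    # (S, 1) + (L,) broadcasts to an (S, L) grid of time indices;
    # integer-array indexing then gathers along the time axis.
    x = series[anchors[:, None] + x_offsets[None, :]]
    y = series[anchors[:, None] + y_offsets[None, :]]
    return x, y
```

Because the offsets are arbitrary integers, the same gather handles both a contiguous recent window and the gapped lookback described below.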

For all datasets, the flows of the previous 2 hours, together with the flows around the same time of day on each of the previous 3 days, are used to predict the flows at the next time step.
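One way to express such a lookback as integer offsets is sketched below. The exact window used by the repo is not stated in this README, so the sampling interval (`steps_per_hour`) and the number of steps taken on each side of the same time of day on past days (`day_half_width`) are hypothetical parameters. Offsets are relative to the last observed step t, and the prediction target is assumed to be step t + 1. The function name `make_lookback_offsets` is also hypothetical.

```python
import numpy as np


def make_lookback_offsets(steps_per_hour, num_hours=2, num_days=3, day_half_width=1):
    """Return sorted, unique non-positive offsets relative to the last observed step t.

    Assumes the target is step t + 1. `day_half_width` (steps on each side of
    the target's time of day on previous days) is a hypothetical choice.
    """
    steps_per_day = 24 * steps_per_hour
    # Recent window: the last `num_hours` hours, ending at t (offset 0).
    recent = np.arange(-(num_hours * steps_per_hour) + 1, 1)
    # Daily windows: steps around the target time (t + 1) on each previous day.
    daily = [
        1 - d * steps_per_day + k
        for d in range(1, num_days + 1)
        for k in range(-day_half_width, day_half_width + 1)
    ]
    return np.unique(np.concatenate([recent, np.array(daily)]))
```

With 30-minute intervals (`steps_per_hour=2`) this yields 13 offsets in four separate blocks, which shows why a lookback window of this kind is not contiguous in time.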

adj_mx.npz stores a symmetric binary adjacency matrix whose entries are 0 or 1.
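A quick sanity check of these properties can be sketched as follows. The array key stored inside adj_mx.npz is not documented here, so the hypothetical loader simply takes the first array in the archive; both helper names (`load_adjacency`, `check_adjacency`) are illustrative and not part of the repo.

```python
import numpy as np


def load_adjacency(path):
    """Load the first array in an adj_mx.npz archive (key name assumed unknown)."""
    with np.load(path) as npz:
        return npz[npz.files[0]]


def check_adjacency(adj):
    """Verify the matrix is square, symmetric and binary; return node degrees."""
    adj = np.asarray(adj)
    if adj.ndim != 2 or adj.shape[0] != adj.shape[1]:
        raise ValueError(f"expected a square matrix, got shape {adj.shape}")
    if not np.array_equal(adj, adj.T):
        raise ValueError("adjacency matrix is not symmetric")
    if not np.isin(adj, (0, 1)).all():
        raise ValueError("adjacency matrix has entries other than 0 and 1")
    # Degree of each node = number of 1s in its row (a self-loop counts once).
    return adj.sum(axis=1)
```

Raising on the first violated property keeps the check strict; a failing dataset file surfaces immediately instead of silently producing a malformed graph.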

⚠️ Note that all datasets are processed as a sliding window view. The raw data of NYCBike1 and BJTaxi are collected from STResNet, and the raw data of NYCBike2 and NYCTaxi are collected from STDN.

Model Training and Evaluation

Once the environment is set up, run the following commands to train the model on one of the datasets {NYCBike1, NYCBike2, NYCTaxi, BJTaxi}:

>> cd ST-SSL
>> ./runme 0 NYCBike1   # 0 gives the GPU id

Note that this repo only includes the NYCBike1 data, since including all datasets would make the repo too large.

Cite

If you find the paper useful, please cite the following:

@article{ji2023modeling, 
  title={Spatio-Temporal Self-Supervised Learning for Traffic Flow Prediction}, 
  author={Ji, Jiahao and Wang, Jingyuan and Huang, Chao and Wu, Junjie and Xu, Boren and Wu, Zhenhe and Zhang, Junbo and Zheng, Yu}, 
  journal={Proceedings of the AAAI Conference on Artificial Intelligence}, 
  volume={37},
  number={4},
  pages={4356-4364},
  year={2023}
}
