This is a PyTorch implementation of ST-SSL, proposed in the following paper:
- J. Ji, J. Wang, C. Huang, et al. "Spatio-Temporal Self-Supervised Learning for Traffic Flow Prediction". In AAAI 2023.
22/04/2023: A post about this paper was selected as a headline tweet by PaperWeekly, a leading AI academic platform in China, and received nearly 7,000 reads.
09/02/2023: The video replay of the academic presentation at AAAI 2023 is available.
04/02/2023: J. Ji was invited to give a talk on Spatio-Temporal Self-Supervised Learning for Traffic Flow Prediction at the AAAI 2023 Beijing Pre-Conference.
This project is built with Python 3.8 and the following packages:
numpy==1.21.2
pandas==1.3.5
PyYAML==6.0
torch==1.10.1
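As a quick sanity check of the environment, the pinned versions above can be compared with what is installed. The following is a minimal sketch using only the standard library; the helper name `check_versions` is hypothetical and not part of this repo.

```python
from importlib import metadata  # in the standard library since Python 3.8

# Pinned versions copied from the package list above.
PINNED = {
    "numpy": "1.21.2",
    "pandas": "1.3.5",
    "PyYAML": "6.0",
    "torch": "1.10.1",
}


def check_versions(pinned=PINNED):
    """Return {package: (expected, installed_or_None)} for every mismatch."""
    mismatches = {}
    for name, expected in pinned.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None  # package is not installed at all
        # torch wheels may carry a local suffix such as "+cu113";
        # compare only the public part of the version string.
        if installed is None or installed.split("+")[0] != expected:
            mismatches[name] = (expected, installed)
    return mismatches


if __name__ == "__main__":
    for name, (expected, installed) in check_versions().items():
        print(f"{name}: expected {expected}, found {installed}")
```

An empty result means every pinned package is installed at the listed version. Other versions may also work; the list only records what the project was built with.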
The available datasets are {NYCBike1, NYCBike2, NYCTaxi, BJTaxi}. You can download them from the GitHub repo, Beihang Cloud Drive, or Google Drive.
Each dataset consists of 4 files: train.npz, val.npz, test.npz, and adj_mx.npz.
|----NYCBike1\
| |----train.npz
| |----adj_mx.npz
| |----test.npz
| |----val.npz
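After downloading a dataset, it can help to confirm that the folder matches the layout above. This is a small sketch; the function name `missing_files` and the example path are assumptions for illustration, not part of this repo.

```python
from pathlib import Path

# File names taken from the dataset layout above.
EXPECTED_FILES = ("train.npz", "val.npz", "test.npz", "adj_mx.npz")


def missing_files(dataset_dir):
    """Return the expected files that are absent from dataset_dir."""
    root = Path(dataset_dir)
    return [name for name in EXPECTED_FILES if not (root / name).is_file()]
```

For example, `missing_files("data/NYCBike1")` returns an empty list when all four files are present (the `data/` prefix is a guess at where the folder lives).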
train/val/test data is composed of 4 numpy.ndarray objects:
- x: a 4D tensor of shape (#timeslots, #lookback_window, #nodes, #flow_types).
- y: a 4D tensor of shape (#timeslots, #predict_horizon, #nodes, #flow_types). x and y are processed as a sliding window view.
- x_offset: a tensor indicating the offsets of x's lookback window. Note that the lookback window of x is not contiguous in time.
- y_offset: a tensor indicating the offsets of y's prediction horizon.
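To inspect one of the train/val/test files, the arrays can be loaded with NumPy and their shapes compared against the description above. This is a minimal sketch: the keys `x` and `y` come from the list above, while the helper names `describe_split` and `shapes_are_consistent` are hypothetical.

```python
import numpy as np


def describe_split(npz_path):
    """Load one split and return the shape of every array stored in it."""
    with np.load(npz_path) as data:
        return {key: data[key].shape for key in data.files}


def shapes_are_consistent(shapes):
    """Check that x and y are 4D and agree on #timeslots, #nodes, #flow_types."""
    x, y = shapes["x"], shapes["y"]
    return len(x) == 4 and len(y) == 4 and x[0] == y[0] and x[2:] == y[2:]
```

Calling `describe_split` on a train.npz file should list the four arrays, with x and y differing only in their second dimension (lookback window versus prediction horizon).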
For all datasets, the flows from the previous 2 hours, together with the flows around the predicted time on the previous 3 days, are used to predict the flows at the next time step.
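The idea of a lookback window built from relative offsets can be sketched as follows. This is an illustration only, not the preprocessing code of this repo: the function `windows_from_offsets` is hypothetical, and the offsets in the usage note are assumed values for hourly data. The offsets actually used are the ones stored in x_offset and y_offset.

```python
import numpy as np


def windows_from_offsets(flows, x_offsets, y_offsets):
    """Build (x, y) samples from a (T, N, C) series using relative time offsets.

    x_offsets are non-positive steps relative to a reference time t and
    y_offsets are positive steps; neither needs to be contiguous.
    """
    x_offsets = np.asarray(x_offsets)
    y_offsets = np.asarray(y_offsets)
    t_min = -int(x_offsets.min())                  # first t with a full history
    t_max = flows.shape[0] - int(y_offsets.max())  # one past the last valid t
    t = np.arange(t_min, t_max)[:, None]           # shape (#timeslots, 1)
    x = flows[t + x_offsets]  # (#timeslots, #lookback_window, N, C)
    y = flows[t + y_offsets]  # (#timeslots, #predict_horizon, N, C)
    return x, y
```

With hourly data, assumed offsets such as `x_offsets = [-71, -47, -23, -1, 0]` and `y_offsets = [1]` would pick the two most recent hours plus the hour being predicted on each of the three previous days.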
adj_mx.npz is a symmetric adjacency matrix whose entries are 0 or 1.
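The stated properties of the adjacency matrix are easy to verify once it is loaded. The array key inside adj_mx.npz is not given here, so this sketch takes the matrix itself as input; the function name `is_valid_adjacency` is hypothetical.

```python
import numpy as np


def is_valid_adjacency(adj):
    """True if adj is a square, symmetric matrix whose entries are all 0 or 1."""
    adj = np.asarray(adj)
    return bool(
        adj.ndim == 2
        and adj.shape[0] == adj.shape[1]
        and np.isin(adj, (0, 1)).all()
        and (adj == adj.T).all()
    )
```

The keys available in the file can be listed with `np.load("adj_mx.npz").files` before passing the matrix to this check.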
Raw data of NYCBike1 and BJTaxi are collected from STResNet. Raw data of NYCBike2 and NYCTaxi are collected from STDN.
Once the environment is ready, run the following commands to train the model on a dataset from {NYCBike1, NYCBike2, NYCTaxi, BJTaxi}.
>> cd ST-SSL
>> ./runme 0 NYCBike1 # 0 gives the GPU id
Note that this repo only contains the NYCBike1 data, because including all datasets would make the repo too large.
If you find the paper useful, please cite the following:
@article{ji2023modeling,
title={Spatio-Temporal Self-Supervised Learning for Traffic Flow Prediction},
author={Ji, Jiahao and Wang, Jingyuan and Huang, Chao and Wu, Junjie and Xu, Boren and Wu, Zhenhe and Zhang, Junbo and Zheng, Yu},
journal={Proceedings of the AAAI Conference on Artificial Intelligence},
volume={37},
number={4},
pages={4356-4364},
year={2023}
}