zayunsna/LSTM_by_JAX

LSTM_by_JAX

Time-series forecasting with an LSTM built in JAX.

# Single LSTM layer
# TODO: multi-layer structure / stateful LSTM.
#       Improve computing efficiency.
batch_size = 32
input_dim = 1
num_epochs = 10000
lr = 0.001

Epoch 9000, Loss: 0.0004633401695173234
 Training is done.
 Total execution time is 613.8410 sec.

The training loss is small enough, but the forecasts are still poor. The project needs more implementation work, hyperparameter tuning, and further study of JAX.
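
For reference, a single-layer LSTM forecaster with the settings above could be sketched roughly as follows. This is not the repository's actual code: `hidden_dim`, `seq_len`, the function names, the sine-wave toy data, and the plain gradient-descent update are all assumptions made for illustration. The sketch uses `jax.lax.scan` for the time loop and `jax.jit` for the update step, which is one common way to approach the computing-efficiency item in the TODO.

```python
import jax
import jax.numpy as jnp

# Values from the README.
batch_size = 32
input_dim = 1
lr = 0.001
# Assumptions: the README does not state these.
hidden_dim = 16
seq_len = 20


def init_params(key, input_dim, hidden_dim):
    """Create weights for one LSTM layer plus a linear read-out."""
    k1, k2, k3 = jax.random.split(key, 3)
    scale = 0.1
    return {
        # The four gates (input, forget, cell, output) are stacked on the last axis.
        "W_x": scale * jax.random.normal(k1, (input_dim, 4 * hidden_dim)),
        "W_h": scale * jax.random.normal(k2, (hidden_dim, 4 * hidden_dim)),
        "b": jnp.zeros(4 * hidden_dim),
        "W_out": scale * jax.random.normal(k3, (hidden_dim, 1)),
        "b_out": jnp.zeros(1),
    }


def lstm_step(params, carry, x_t):
    """One LSTM time step for a batch; x_t has shape (batch, input_dim)."""
    h, c = carry
    gates = x_t @ params["W_x"] + h @ params["W_h"] + params["b"]
    i, f, g, o = jnp.split(gates, 4, axis=-1)
    c = jax.nn.sigmoid(f) * c + jax.nn.sigmoid(i) * jnp.tanh(g)
    h = jax.nn.sigmoid(o) * jnp.tanh(c)
    return (h, c), h


def forward(params, x):
    """Predict the next value from x of shape (batch, seq_len, input_dim)."""
    batch = x.shape[0]
    hidden = params["W_h"].shape[0]
    init = (jnp.zeros((batch, hidden)), jnp.zeros((batch, hidden)))
    # lax.scan iterates over the leading axis, so move time to the front.
    xs = jnp.swapaxes(x, 0, 1)
    (h_last, _), _ = jax.lax.scan(
        lambda carry, x_t: lstm_step(params, carry, x_t), init, xs
    )
    return h_last @ params["W_out"] + params["b_out"]


def mse_loss(params, x, y):
    """Mean squared error between the one-step forecast and the target."""
    return jnp.mean((forward(params, x) - y) ** 2)


@jax.jit
def sgd_step(params, x, y):
    """One plain gradient-descent update (an assumed optimizer)."""
    loss, grads = jax.value_and_grad(mse_loss)(params, x, y)
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss


# Toy data (assumption): overlapping windows of a sine wave.
starts = jnp.arange(batch_size)[:, None]
steps = jnp.arange(seq_len + 1)[None, :]
series = jnp.sin(0.1 * (starts + steps))   # (batch, seq_len + 1)
x = series[:, :seq_len, None]              # (batch, seq_len, input_dim)
y = series[:, seq_len:]                    # (batch, 1)

params = init_params(jax.random.PRNGKey(0), input_dim, hidden_dim)
params, loss = sgd_step(params, x, y)
```

Calling `sgd_step` in a loop for `num_epochs` iterations would give a basic training run. Because the loss here is a one-step-ahead error, a small value does not by itself guarantee good multi-step forecasts.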
