Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/transformer refactorisation #1915

Open
wants to merge 18 commits into
base: master
Choose a base branch
from

Conversation

JanFidor
Copy link
Contributor

@JanFidor JanFidor commented Jul 23, 2023

Fixes #601 #672

Summary

  • Added teacher forcing to the transformer model
  • Added auto-regression for inference
  • Changed the normalization factor to math.sqrt(self.d_model) (I tested deleting it altogether, but SunspotsDataset backtest MAPE was almost 2 times higher then)

Other Information

One of the tricky parts was deciding on how to use encoders -> channel dimension for source consists of both (predicted_series and past covariates), while target only has the predicted_series. I decided to use 0 padding to substitute for the missing past covariates and use the same Linear layer for both. It seemed to give a little better results and is more in line with the original implementation.

When it comes to probabilistic forecasting on inference, I decided to just take an average over probabilistic dimension

@JanFidor JanFidor requested a review from dennisbader as a code owner July 23, 2023 13:39
@JanFidor
Copy link
Contributor Author

Hi @dennisbader , I've double checked the failing tests and everything seems to be running fine on my end locally. It also looks like the errors are cause by some weird pytorch interactions and I'm not 100% sure if the code is at fault or if CI/CD is flaky

@dennisbader
Copy link
Collaborator

Hi @JanFidor and thanks for this PR. We'll have time to review next week.
For the meantime, the other unit test workflows run fine, so there is probably an issue with the changes in this PR.

@review-notebook-app
Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@JanFidor JanFidor force-pushed the feature/transformer-refactorisation branch from f82ca1c to fd13830 Compare August 7, 2023 13:04
@JanFidor
Copy link
Contributor Author

JanFidor commented Aug 7, 2023

Quick update @dennisbader, I fixed the bug (transformer mask generation was creating tensor incorrectly) and tests are passing now, but it seems like codecov has an a connection problem. (Also, I had a tiny adventure with git history and some unpleasant rebases, but fortunately force-pushing and git reset saved the day 🚀 )

And for a neat conclusion, here's a change in the SunspotsDataset backtest performance (dataset from example notebook with a longer forecast horizon, which is where teacher-forcing should make a difference)

Old implementation performance:
Screenshot from 2023-08-07 17-02-04

New implementation performance:
Screenshot from 2023-08-07 16-52-52

@codecov-commenter
Copy link

codecov-commenter commented Aug 7, 2023

Codecov Report

Patch coverage: 100.00% and no project coverage change.

Comparison is base (a8a094a) 93.87% compared to head (fa8a152) 93.88%.

❗ Your organization is not using the GitHub App Integration. As a result you may experience degraded service beginning May 15th. Please install the Github App Integration for your organization. Read more.

Additional details and impacted files
@@           Coverage Diff           @@
##           master    #1915   +/-   ##
=======================================
  Coverage   93.87%   93.88%           
=======================================
  Files         132      132           
  Lines       12673    12692   +19     
=======================================
+ Hits        11897    11916   +19     
  Misses        776      776           
Files Changed Coverage Δ
darts/models/forecasting/transformer_model.py 99.22% <100.00%> (+0.26%) ⬆️

... and 6 files with indirect coverage changes

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Copy link
Collaborator

@madtoinou madtoinou left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice work @JanFidor 🚀 (and sorry for the delay)!

I wrote some high level comments before testing the model locally, let me know what you think.

I am not sure to understand your comment about past_covariates being unavailable for the start token, would you mind detailing this a bit?

darts/models/forecasting/transformer_model.py Outdated Show resolved Hide resolved
darts/models/forecasting/transformer_model.py Outdated Show resolved Hide resolved
darts/models/forecasting/transformer_model.py Outdated Show resolved Hide resolved
darts/models/forecasting/transformer_model.py Outdated Show resolved Hide resolved
darts/models/forecasting/transformer_model.py Show resolved Hide resolved
data = x_in[0]
pad_size = (0, self.input_size - self.target_size)

# start token consists only of target series, past covariates are substituted with 0 padding
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why would you not include the past covariates in the start token?

darts/models/forecasting/transformer_model.py Outdated Show resolved Hide resolved
@JanFidor
Copy link
Contributor Author

JanFidor commented Aug 27, 2023

Thanks for the review @madtoinou ! About the covariates, as I understand it, they're expected to cover the same range of timestamps as the past_series (timestamps 1:L where L is input_chunk_length). It's not a problem for the TransformerEncoder input, but the TransformerDecoder requires timestamps L:L+H-1 where H is output_chunk_length. I couldn't find a way to access past covariates for L+1:L+H-1. I decided to drop the past_covariates values for the start token to generalize endocing input for TransformerDecoder-> only target channels have non-zero values. I was worried that otherwise encoding for TransformerDecoder could be unstable. For horizon of 2, first token would have diverse covariates, while second one's would be zero. I thought about extending the covariates of first token to the rest of target_series, but I was worried that they would take away from the importance of target_series changes.

Copy link
Collaborator

@madtoinou madtoinou left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adjusting the code.

Before approving, I would like to run some checks to make sure that there is no information leakage and tweak the default parameters (especially for the failing unit-test, it would be great to reach the same accuracy as before).

@@ -8,6 +8,8 @@

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Transformer
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe import just generate_square_subsequent_mask

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After reading up a little and checking out the implementation, it turns out that generate_square_subsequent_mask is a static method of Transformer. While it is possible to import just it (https://stackoverflow.com/questions/48178011/import-static-method-of-a-class-without-importing-the-whole-class) I don't think it's worth it. That said, I definitely agree that this import is a little intuitive and I think that a nice middle ground would be adding an implementation of generate_square_subsequent_mask to darts/utils as it's a very small function. What do you think?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

WDYT @dennisbader ?

darts/models/forecasting/transformer_model.py Outdated Show resolved Hide resolved
@@ -349,7 +410,7 @@ def __init__(
The multi-head attention mechanism is highly parallelizable, which makes the transformer architecture
very suitable to be trained with GPUs.

The transformer architecture implemented here is based on [1]_.
The transformer architecture implemented here is based on [1]_ and uses teacher forcing [2]_.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that I removed the reference entry by accident, can you please put it back? Instead of the toward data science article, could you please link the torch tutorial : https://github.com/pytorch/examples/tree/main/word_language_model

# Conflicts:
#	darts/models/forecasting/transformer_model.py
…ion' into feature/transformer-refactorisation
@JanFidor
Copy link
Contributor Author

JanFidor commented Sep 28, 2023

Thanks for another review and more great suggestions for improvement @madtoinou ! I'll try to play around with hyperparameters, but after reading the paper on DeepAR (https://arxiv.org/abs/1704.04110) I found a different approach to probabilistic forecasting which might help with the decrease in accuracy. DeepAR uses ancestral sampling (MonteCarlo simulation to generate each of the samples) (Section 3.1 last paragraph + Figure 2). It would require some changes to the implementation but I think the end result would be more expressive if the probability distribution of the errors is known beforehand. Do you think it would be worth it considering the increase in the scope of the PR and time complexity?

@madtoinou
Copy link
Collaborator

Hi @JanFidor,

I would recommend keeping the implementation of the ancestral sampling in the TransformerModel for another PR so that we can more easily compare the change in performance and simplify the review process. Feel free to open an issue to track this improvement!

@JanFidor
Copy link
Contributor Author

Fixed by 2261

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Refactor Transformer model
4 participants