Feature/transformer refactorisation #1915
base: master
Conversation
Hi @dennisbader, I've double-checked the failing tests and everything seems to be running fine on my end locally. It also looks like the errors are caused by some weird PyTorch interactions, and I'm not 100% sure if the code is at fault or if CI/CD is flaky.
Hi @JanFidor and thanks for this PR. We'll have time to review next week.
Force-pushed from f82ca1c to fd13830
# Conflicts:
#	darts/models/forecasting/transformer_model.py
Quick update @dennisbader, I fixed the bug (the transformer mask generation was creating the tensor incorrectly) and tests are passing now, but it seems like Codecov has a connection problem. (Also, I had a tiny adventure with git history and some unpleasant rebases, but fortunately force-pushing and git reset saved the day 🚀) And for a neat conclusion, here's a change in the SunspotsDataset backtest performance (dataset from the example notebook with a longer forecast horizon, which is where teacher forcing should make a difference).
Codecov Report
Patch coverage:
@@ Coverage Diff @@
## master #1915 +/- ##
=======================================
Coverage 93.87% 93.88%
=======================================
Files 132 132
Lines 12673 12692 +19
=======================================
+ Hits 11897 11916 +19
Misses 776 776
Nice work @JanFidor 🚀 (and sorry for the delay)!
I wrote some high level comments before testing the model locally, let me know what you think.
I am not sure I understand your comment about past_covariates being unavailable for the start token; would you mind detailing this a bit?
data = x_in[0]
pad_size = (0, self.input_size - self.target_size)

# start token consists only of target series, past covariates are substituted with 0 padding
Why would you not include the past covariates in the start token?
Thanks for the review @madtoinou! About the covariates, as I understand it, they're expected to cover the same range of timestamps as the past_series (timestamps
Thanks for adjusting the code.
Before approving, I would like to run some checks to make sure that there is no information leakage and tweak the default parameters (especially for the failing unit-test, it would be great to reach the same accuracy as before).
@@ -8,6 +8,8 @@

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Transformer
maybe import just generate_square_subsequent_mask
After reading up a little and checking out the implementation, it turns out that generate_square_subsequent_mask is a static method of Transformer. While it is possible to import just it (https://stackoverflow.com/questions/48178011/import-static-method-of-a-class-without-importing-the-whole-class), I don't think it's worth it. That said, I definitely agree that this import is a little unintuitive, and I think a nice middle ground would be adding an implementation of generate_square_subsequent_mask to darts/utils, as it's a very small function. What do you think?
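For illustration, a minimal sketch of calling the mask generator directly on the class, assuming a recent PyTorch version where it is exposed as a static method (as noted above); none of this is darts code:

```python
from torch.nn import Transformer

# In recent PyTorch versions generate_square_subsequent_mask is a static
# method, so it can be called without instantiating a Transformer. It returns
# an upper-triangular mask with -inf above the diagonal and 0 elsewhere.
mask = Transformer.generate_square_subsequent_mask(4)
print(mask)
# tensor([[0., -inf, -inf, -inf],
#         [0.,   0., -inf, -inf],
#         [0.,   0.,   0., -inf],
#         [0.,   0.,   0.,   0.]])
```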
WDYT @dennisbader ?
@@ -349,7 +410,7 @@ def __init__(
The multi-head attention mechanism is highly parallelizable, which makes the transformer architecture
very suitable to be trained with GPUs.

- The transformer architecture implemented here is based on [1]_.
+ The transformer architecture implemented here is based on [1]_ and uses teacher forcing [2]_.
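As context for the teacher-forcing reference, here is a hedged, self-contained sketch of the idea with a plain nn.Transformer; shapes and variable names are illustrative assumptions, not the darts implementation:

```python
import torch
import torch.nn as nn
from torch.nn import Transformer

# Toy setup: batch of 8 series, 24 input steps, 6 output steps, 16 model dims.
model = nn.Transformer(d_model=16, nhead=4, batch_first=True)
src = torch.randn(8, 24, 16)   # encoder input (past window)
tgt = torch.randn(8, 6, 16)    # ground-truth future, shifted right by one step

# During training the decoder is fed the (causally masked) ground truth rather
# than its own previous predictions -- that is teacher forcing.
tgt_mask = Transformer.generate_square_subsequent_mask(tgt.size(1))
out = model(src, tgt, tgt_mask=tgt_mask)
print(out.shape)  # torch.Size([8, 6, 16])
```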
I think that I removed the reference entry by accident, can you please put it back? Also, instead of the Towards Data Science article, could you please link the torch tutorial: https://github.com/pytorch/examples/tree/main/word_language_model
# Conflicts:
#	darts/models/forecasting/transformer_model.py
…ion' into feature/transformer-refactorisation
Thanks for another review and more great suggestions for improvement @madtoinou! I'll try to play around with the hyperparameters, but after reading the DeepAR paper (https://arxiv.org/abs/1704.04110) I found a different approach to probabilistic forecasting which might help with the decrease in accuracy. DeepAR uses ancestral sampling (a Monte Carlo simulation to generate each of the samples; Section 3.1, last paragraph + Figure 2). It would require some changes to the implementation, but I think the end result would be more expressive if the probability distribution of the errors is known beforehand. Do you think it would be worth it, considering the increase in the scope of the PR and the time complexity?
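To make the proposal concrete, a rough, hedged sketch of ancestral sampling for an autoregressive probabilistic model in the spirit of DeepAR Section 3.1; the toy model, all names, and the tensor layout are assumptions for illustration, not the darts API:

```python
import torch
import torch.nn as nn


class TinyGaussianForecaster(nn.Module):
    """Toy autoregressive model that predicts a Gaussian for the next step."""

    def __init__(self, hidden=16):
        super().__init__()
        self.rnn = nn.GRU(1, hidden, batch_first=True)
        self.head = nn.Linear(hidden, 2)  # outputs mean and log-scale

    def step(self, y_prev, h):
        out, h = self.rnn(y_prev, h)
        mu, log_sigma = self.head(out[:, -1]).chunk(2, dim=-1)
        return mu, log_sigma.exp(), h


def ancestral_sample(model, y_last, horizon, n_samples):
    # For every sample path, draw y_t from the predicted distribution and feed
    # the drawn value (not the mean) back in as the next input.
    paths = []
    for _ in range(n_samples):
        h, y_prev, path = None, y_last.clone(), []
        for _ in range(horizon):
            mu, sigma, h = model.step(y_prev.unsqueeze(-1), h)
            y_t = torch.distributions.Normal(mu, sigma).sample()
            path.append(y_t)
            y_prev = y_t
        paths.append(torch.stack(path, dim=1))  # (batch, horizon, 1)
    return torch.stack(paths, dim=1)            # (batch, n_samples, horizon, 1)


samples = ancestral_sample(TinyGaussianForecaster(), torch.zeros(4, 1), horizon=6, n_samples=50)
print(samples.shape)  # torch.Size([4, 50, 6, 1])
```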
Hi @JanFidor, I would recommend keeping the implementation of the ancestral sampling in the
Fixed by #2261
Fixes #601 #672
Summary
math.sqrt(self.d_model) (I tested deleting it altogether, but the SunspotsDataset backtest MAPE was almost 2 times higher in that case)

Other Information
One of the tricky parts was deciding how to use the encoders: the channel dimension of the source consists of both the predicted series and the past covariates, while the target only has the predicted series. I decided to use 0-padding to substitute for the missing past covariates and to use the same Linear layer for both. It seemed to give slightly better results and is more in line with the original implementation.
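To illustrate the 0-padding idea (the names input_size and target_size and the shapes are assumptions for illustration, not necessarily the exact darts code):

```python
import torch
import torch.nn.functional as F

batch, length = 4, 12
input_size, target_size = 5, 2   # e.g. 2 target components + 3 past covariates
start_token = torch.randn(batch, length, target_size)

# Right-pad the channel dimension with zeros so the start token matches the
# source's channel count and can go through the same Linear layer.
pad_size = (0, input_size - target_size)
padded = F.pad(start_token, pad_size)
print(padded.shape)  # torch.Size([4, 12, 5])
```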
When it comes to probabilistic forecasting at inference, I decided to just take an average over the probabilistic (samples) dimension.
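For example (the tensor layout below is an assumption for illustration):

```python
import torch

# 8 series, 100 Monte Carlo samples each, 6 forecast steps, 1 component.
predictions = torch.randn(8, 100, 6, 1)

# Collapse the probabilistic (samples) dimension into a point forecast.
point_forecast = predictions.mean(dim=1)
print(point_forecast.shape)  # torch.Size([8, 6, 1])
```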