fix: torch_forecasting_model load_weights with float16 and float32 #2046
Conversation
Thanks @Hsinfu for contributing this fix 🚀
I had two minor suggestions: adding a unit test, and making the dtype mapping more robust against future changes.
Can you try adding a unit test for this?
The best place for this would be in `darts/tests/models/forecasting/test_torch_forecasting_model`.
Maybe something like `test_loading_weights_with_different_dtype`
`test_load_weights_with_float32_dtype`
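A rough sketch of what such a test could look like (the model choice, the `tmp_path` fixture, and the final dtype assertion are illustrative assumptions, not the PR's actual test):

```python
import numpy as np
import torch
from darts import TimeSeries
from darts.models import DLinearModel


def test_load_weights_with_float32_dtype(tmp_path):
    # train a small float32 model and save it to disk
    series = TimeSeries.from_values(np.random.rand(100).astype(np.float32))
    model = DLinearModel(input_chunk_length=4, output_chunk_length=1, n_epochs=1)
    model.fit(series)
    path = str(tmp_path / "model_f32.pt")
    model.save(path)

    # restore only the weights into a fresh instance; before this fix, the
    # mocked train sample was created as float64 and loading failed
    loaded = DLinearModel(input_chunk_length=4, output_chunk_length=1, n_epochs=1)
    loaded.load_weights(path)
    assert loaded.model.dtype == torch.float32
```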
```diff
@@ -1872,8 +1872,9 @@ def load_weights_from_checkpoint(
         )

         # pl_forecasting module saves the train_sample shape, must recreate one
+        np_dtype = np.float32 if ckpt["model_dtype"] == torch.float32 else np.float64
```
I'd suggest creating a mapping between the torch and numpy dtypes at the top of this file. With this we can avoid/catch some breaking changes (e.g. if we ever support other dtypes such as float16 and forget about this line).
```python
TORCH_NP_DTYPES = {
    torch.float32: np.float32,
    torch.float64: np.float64,
}

...

mock_train_sample = [
    np.zeros(sample_shape, dtype=TORCH_NP_DTYPES[ckpt["model_dtype"]])
    if sample_shape
    else None
    for sample_shape in ckpt["train_sample_shape"]
]
```
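Looking the dtype up with direct dict indexing also means an unmapped dtype fails loudly with a KeyError at load time instead of silently falling back to float64, and keeping the mapping at the top of the file collects the supported dtypes in one place.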
Also add float16:

```python
TORCH_NP_DTYPES = {
    torch.float16: np.float16,
    torch.float32: np.float32,
    torch.float64: np.float64,
}
```
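A standalone sketch (not part of the PR) contrasting the old ternary with the suggested mapping:

```python
import numpy as np
import torch

TORCH_NP_DTYPES = {
    torch.float16: np.float16,
    torch.float32: np.float32,
    torch.float64: np.float64,
}

# Old behavior: anything that was not float32 silently fell back to float64,
# so a float16 checkpoint produced a float64 mock train sample.
ckpt_dtype = torch.float16
old_np_dtype = np.float32 if ckpt_dtype == torch.float32 else np.float64
print(old_np_dtype)  # <class 'numpy.float64'> (wrong for float16)

# New behavior: the mapping returns the matching numpy dtype, and an
# unsupported dtype (e.g. torch.bfloat16) raises a KeyError instead of
# passing through silently.
print(TORCH_NP_DTYPES[ckpt_dtype])  # <class 'numpy.float16'>
```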
Just for the record, this PR fixed one of the points mentioned in #1987.
Force-pushed from d09ac98 to 0ad713e.
Hi @Hsinfu, it seems like your branch is protected, so we can't adapt it ourselves. Can you merge the current master head into your branch and add the following line to our
Force-pushed from 0ad713e to b372357.
Force-pushed from b372357 to 4bf0ef2.
I just cloned the current master head and rebased my branch on it. Also added the information in
Looks great, thanks a lot for this @Hsinfu 🚀 💯
Fix bug for TorchForecastingModel calling load_weights with float16 and float32. Fixes #1987.

Bug details
1. `.save` a model trained with float16 or float32 to disk.
2. `.load_weights` to restore it.
3. The `mock_train_sample` in `.load_weights_from_checkpoint` is built from `np.zeros` without specifying the `dtype`, so it sticks to the default float64, which causes `self.train_sample` to stay float64, too.
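The root cause is just numpy's default dtype; a quick check (illustrative, not from the PR):

```python
import numpy as np

# np.zeros defaults to float64 when no dtype is passed, which is how the
# mocked train sample ended up float64 regardless of the checkpoint dtype
print(np.zeros(3).dtype)                    # float64
print(np.zeros(3, dtype=np.float32).dtype)  # float32
```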