
fix: torch_forecasting_model load_weights with float16 and float32 #2046

Conversation

@Hsinfu (Contributor) commented Nov 1, 2023

Fixes a bug in TorchForecastingModel when calling load_weights on a model trained with float16 or float32. Fixes #1987.

Bug details

  • Train the TorchForecastingModel with float32 following the instructions and .save it to disk.
  • Create a new model instance and call .load_weights to restore it.
  • The mock_train_sample in .load_weights_from_checkpoint is built from np.zeros without specifying the dtype, so it defaults to float64, which in turn leaves self.train_sample as float64 (see the reproduction sketch below).
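
A minimal reproduction sketch of the bug (the dataset, model choice, and hyperparameters here are illustrative, not taken from the PR):

import numpy as np
from darts.datasets import AirPassengersDataset
from darts.models import DLinearModel

# train on float32 data; the model adopts the dtype of the training series
series = AirPassengersDataset().load().astype(np.float32)
model = DLinearModel(input_chunk_length=12, output_chunk_length=6, n_epochs=1)
model.fit(series)
model.save("dlinear_f32.pt")

# restoring the weights into a fresh instance triggered the bug: the mock
# train sample was created as float64, mismatching the float32 checkpoint
loaded = DLinearModel(input_chunk_length=12, output_chunk_length=6)
loaded.load_weights("dlinear_f32.pt")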

@Hsinfu Hsinfu requested a review from dennisbader as a code owner November 1, 2023 09:37
@dennisbader (Collaborator) left a comment

Thanks @Hsinfu for contributing and for this fix 🚀

I had 2 minor suggestions regarding adding a unit test and making the dtype mapping more robust against future changes.

@dennisbader (Collaborator) commented:

Can you try adding a unit test for this?

The best place for this would be in darts/tests/models/forecasting/test_torch_forecasting_model.
Maybe something like test_loading_weights_with_different_dtype

@Hsinfu (Contributor, Author) replied:

test_load_weights_with_float32_dtype
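
A sketch of what such a test might look like (this is not the exact test added in the PR; linear_timeseries and the model choice are assumptions):

import numpy as np
from darts.models import DLinearModel
from darts.utils.timeseries_generation import linear_timeseries

def test_load_weights_with_float32_dtype(tmp_path):
    # train a small model on float32 data and save it to disk
    series = linear_timeseries(length=30).astype(np.float32)
    model = DLinearModel(input_chunk_length=4, output_chunk_length=2, n_epochs=1)
    model.fit(series)
    path = str(tmp_path / "model_f32.pt")
    model.save(path)

    # restoring the weights must recreate the train sample as float32
    loaded = DLinearModel(input_chunk_length=4, output_chunk_length=2)
    loaded.load_weights(path)
    assert all(
        s.dtype == np.float32 for s in loaded.train_sample if s is not None
    )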

@@ -1872,8 +1872,9 @@ def load_weights_from_checkpoint(
     )

     # pl_forecasting module saves the train_sample shape, must recreate one
+    np_dtype = np.float32 if ckpt["model_dtype"] == torch.float32 else np.float64
@dennisbader (Collaborator) commented on the diff above:

I'd suggest creating a mapping between the torch and numpy dtypes at the top of this file. With this we can avoid/catch some breaking changes (e.g. if we ever support other dtypes such as float16 and forget about this line).

TORCH_NP_DTYPES = {
    torch.float32: np.float32,
    torch.float64: np.float64,
}

...

mock_train_sample = [
    np.zeros(sample_shape, dtype=TORCH_NP_DTYPES[ckpt["model_dtype"]]) if sample_shape else None
    for sample_shape in ckpt["train_sample_shape"]
]

@Hsinfu (Contributor, Author) replied:

Also added float16:

TORCH_NP_DTYPES = {
    torch.float16: np.float16,
    torch.float32: np.float32,
    torch.float64: np.float64,
}
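
A quick illustration of the robustness this mapping buys (torch.bfloat16 is just an example of a dtype that is not in the mapping): the dict lookup fails loudly, whereas the old one-line conditional silently mapped every non-float32 dtype to float64.

import numpy as np
import torch

TORCH_NP_DTYPES = {
    torch.float16: np.float16,
    torch.float32: np.float32,
    torch.float64: np.float64,
}

# old behavior: np.float32 if dtype == torch.float32 else np.float64
# -> torch.bfloat16 would silently become np.float64
# new behavior: an unmapped dtype raises KeyError at load time
try:
    TORCH_NP_DTYPES[torch.bfloat16]
except KeyError:
    print("unsupported model_dtype: torch.bfloat16")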

@madtoinou (Collaborator) commented:

Just for the record, this PR fixes one of the points mentioned in #1987.

@codecov-commenter commented Nov 2, 2023

Codecov Report

Attention: 2 lines in your changes are missing coverage. Please review.


Files | Coverage Δ
--- | ---
darts/dataprocessing/encoders/encoders.py | 97.68% <100.00%> (+0.07%) ⬆️
darts/models/forecasting/arima.py | 93.75% <ø> (ø)
darts/models/forecasting/auto_arima.py | 96.77% <ø> (ø)
darts/models/forecasting/block_rnn_model.py | 98.27% <ø> (-0.03%) ⬇️
darts/models/forecasting/catboost_model.py | 98.41% <ø> (ø)
darts/models/forecasting/croston.py | 90.69% <ø> (ø)
darts/models/forecasting/dlinear.py | 99.12% <ø> (ø)
darts/models/forecasting/kalman_forecaster.py | 97.36% <ø> (ø)
darts/models/forecasting/lgbm.py | 100.00% <ø> (ø)
...arts/models/forecasting/linear_regression_model.py | 94.44% <ø> (ø)
... and 18 more

... and 2 files with indirect coverage changes


@Hsinfu Hsinfu force-pushed the fix/torch_forecasting_model_load_weights_with_float32 branch from d09ac98 to 0ad713e Compare November 6, 2023 06:12
@Hsinfu Hsinfu changed the title from "fix: torch_forecasting_model load_weights with float32" to "fix: torch_forecasting_model load_weights with float16 and float32" Nov 6, 2023
@dennisbader (Collaborator) commented:

Hi @Hsinfu, it seems your branch is protected, so we can't push changes to it ourselves.

Can you merge the current master head into your branch and add the following line to our CHANGELOG.md Fixed section?

- Fixed a bug when loading a `TorchForecastingModel` that was trained with a precision other than `float64`. [#2046](https://github.com/unit8co/darts/pull/2046) by [Freddie Hsin-Fu Huang](https://github.com/Hsinfu).

@Hsinfu Hsinfu force-pushed the fix/torch_forecasting_model_load_weights_with_float32 branch from 0ad713e to b372357 Compare November 6, 2023 10:18
@Hsinfu Hsinfu force-pushed the fix/torch_forecasting_model_load_weights_with_float32 branch from b372357 to 4bf0ef2 Compare November 6, 2023 10:23
@Hsinfu (Contributor, Author) commented Nov 6, 2023

  • Fixed a bug when loading a TorchForecastingModel that was trained with a precision other than float64. #2046 by Freddie Hsin-Fu Huang.

I just pulled the current master head and rebased my branch on it. I also added the line to CHANGELOG.md.
@dennisbader, please see if it works now. Thank you~

@dennisbader (Collaborator) left a comment

Looks great, thanks a lot for this @Hsinfu 🚀 💯

@dennisbader dennisbader merged commit af5b141 into unit8co:master Nov 6, 2023
9 checks passed
Linked issue: TorchForecastingModel model precision (#1987)