From 0ad713ec2f57d813f5d8e0baf3ec76278b3ec2e5 Mon Sep 17 00:00:00 2001
From: hsinfu
Date: Mon, 6 Nov 2023 14:11:52 +0800
Subject: [PATCH] fix: torch_forecasting_model load_weights with float16 and float32

load_weights_from_checkpoint() rebuilds a mock train sample from the
shapes stored in the checkpoint. The arrays were created with
np.zeros(shape), i.e. with NumPy's default dtype (float64), independently
of the dtype stored in the checkpoint. Create them with the NumPy dtype
that corresponds to ckpt["model_dtype"], looked up in the new
TORCH_NP_DTYPES mapping (float16, float32, float64).

Add a test that fits and saves a DLinearModel on a float32 series, loads
the weights into a fresh model, fits it again and checks that the model
dtype is torch.float32.

---
NOTE(review): this patch was recovered from a copy in which all newlines
and runs of whitespace had been collapsed. Line structure and hunk line
counts were re-derived from the hunk headers; the leading indentation of
context and added lines was reconstructed and should be confirmed against
the target tree (or the patch applied with --ignore-whitespace).

 .../forecasting/torch_forecasting_model.py    |  9 +++++++-
 .../test_torch_forecasting_model.py           | 23 +++++++++++++++++++
 2 files changed, 31 insertions(+), 1 deletion(-)

diff --git a/darts/models/forecasting/torch_forecasting_model.py b/darts/models/forecasting/torch_forecasting_model.py
index b6643c1db9..6f6dca83ac 100644
--- a/darts/models/forecasting/torch_forecasting_model.py
+++ b/darts/models/forecasting/torch_forecasting_model.py
@@ -108,6 +108,12 @@
 RUNS_FOLDER = "runs"
 INIT_MODEL_NAME = "_model.pth.tar"
 
+TORCH_NP_DTYPES = {
+    torch.float16: np.float16,
+    torch.float32: np.float32,
+    torch.float64: np.float64,
+}
+
 # pickling a TorchForecastingModel will not save below attributes: the keys specify the
 # attributes to be ignored, and the values are the default values getting assigned upon loading
 TFM_ATTRS_NO_PICKLE = {"model": None, "trainer": None}
@@ -1872,8 +1878,9 @@ def load_weights_from_checkpoint(
         )
 
         # pl_forecasting module saves the train_sample shape, must recreate one
+        np_dtype = TORCH_NP_DTYPES[ckpt["model_dtype"]]
         mock_train_sample = [
-            np.zeros(sample_shape) if sample_shape else None
+            np.zeros(sample_shape, dtype=np_dtype) if sample_shape else None
             for sample_shape in ckpt["train_sample_shape"]
         ]
         self.train_sample = tuple(mock_train_sample)
diff --git a/darts/tests/models/forecasting/test_torch_forecasting_model.py b/darts/tests/models/forecasting/test_torch_forecasting_model.py
index 400899b76a..6a150bbfa3 100644
--- a/darts/tests/models/forecasting/test_torch_forecasting_model.py
+++ b/darts/tests/models/forecasting/test_torch_forecasting_model.py
@@ -1045,6 +1045,29 @@ def test_load_weights(self, tmpdir_fn):
                 f"respectively {retrained_mape} and {original_mape}"
             )
 
+        def test_load_weights_with_float32_dtype(self, tmpdir_fn):
+            ts_float32 = self.series.astype("float32")
+            model_name = "test_model"
+            ckpt_path = os.path.join(tmpdir_fn, f"{model_name}.pt")
+            # barebone model
+            model = DLinearModel(
+                input_chunk_length=4,
+                output_chunk_length=1,
+                n_epochs=1,
+            )
+            model.fit(ts_float32)
+            model.save(ckpt_path)
+            assert model.model._dtype == torch.float32  # type: ignore
+
+            # identical model
+            loading_model = DLinearModel(
+                input_chunk_length=4,
+                output_chunk_length=1,
+            )
+            loading_model.load_weights(ckpt_path)
+            loading_model.fit(ts_float32)
+            assert loading_model.model._dtype == torch.float32  # type: ignore
+
         def test_multi_steps_pipeline(self, tmpdir_fn):
             ts_training, ts_val = self.series.split_before(75)
             pretrain_model_name = "pre-train"