Skip to content

Commit

Permalink
fix: torch_forecasting_model load_weights with float16 and float32
Browse files Browse the repository at this point in the history
  • Loading branch information
Hsinfu committed Nov 6, 2023
1 parent ea37dc9 commit 0ad713e
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 1 deletion.
9 changes: 8 additions & 1 deletion darts/models/forecasting/torch_forecasting_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,12 @@
RUNS_FOLDER = "runs"
INIT_MODEL_NAME = "_model.pth.tar"

TORCH_NP_DTYPES = {
torch.float16: np.float16,
torch.float32: np.float32,
torch.float64: np.float64,
}

# pickling a TorchForecastingModel will not save below attributes: the keys specify the
# attributes to be ignored, and the values are the default values getting assigned upon loading
TFM_ATTRS_NO_PICKLE = {"model": None, "trainer": None}
Expand Down Expand Up @@ -1872,8 +1878,9 @@ def load_weights_from_checkpoint(
)

# pl_forecasting module saves the train_sample shape, must recreate one
np_dtype = TORCH_NP_DTYPES[ckpt["model_dtype"]]
mock_train_sample = [
np.zeros(sample_shape) if sample_shape else None
np.zeros(sample_shape, dtype=np_dtype) if sample_shape else None
for sample_shape in ckpt["train_sample_shape"]
]
self.train_sample = tuple(mock_train_sample)
Expand Down
23 changes: 23 additions & 0 deletions darts/tests/models/forecasting/test_torch_forecasting_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1045,6 +1045,29 @@ def test_load_weights(self, tmpdir_fn):
f"respectively {retrained_mape} and {original_mape}"
)

def test_load_weights_with_float32_dtype(self, tmpdir_fn):
ts_float32 = self.series.astype("float32")
model_name = "test_model"
ckpt_path = os.path.join(tmpdir_fn, f"{model_name}.pt")
# barebone model
model = DLinearModel(
input_chunk_length=4,
output_chunk_length=1,
n_epochs=1,
)
model.fit(ts_float32)
model.save(ckpt_path)
assert model.model._dtype == torch.float32 # type: ignore

# identical model
loading_model = DLinearModel(
input_chunk_length=4,
output_chunk_length=1,
)
loading_model.load_weights(ckpt_path)
loading_model.fit(ts_float32)
assert loading_model.model._dtype == torch.float32 # type: ignore

def test_multi_steps_pipeline(self, tmpdir_fn):
ts_training, ts_val = self.series.split_before(75)
pretrain_model_name = "pre-train"
Expand Down

0 comments on commit 0ad713e

Please sign in to comment.