Skip to content

Commit

Permalink
fix: torch_forecasting_model load_weights with float32
Browse files Browse the repository at this point in the history
  • Loading branch information
Hsinfu committed Nov 1, 2023
1 parent ea37dc9 commit d09ac98
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion darts/models/forecasting/torch_forecasting_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1872,8 +1872,9 @@ def load_weights_from_checkpoint(
)

# pl_forecasting module saves the train_sample shape, must recreate one
np_dtype = np.float32 if ckpt["model_dtype"] == torch.float32 else np.float64
mock_train_sample = [
np.zeros(sample_shape) if sample_shape else None
np.zeros(sample_shape, dtype=np_dtype) if sample_shape else None
for sample_shape in ckpt["train_sample_shape"]
]
self.train_sample = tuple(mock_train_sample)
Expand Down

0 comments on commit d09ac98

Please sign in to comment.