diff --git a/spotiflow/model/spotiflow.py b/spotiflow/model/spotiflow.py index 7a58623..ddb69ec 100644 --- a/spotiflow/model/spotiflow.py +++ b/spotiflow/model/spotiflow.py @@ -605,7 +605,7 @@ def load( states_path = Path(path) / f"{which}.pt" - checkpoint = torch.load(states_path, map_location=map_location) + checkpoint = torch.load(states_path, map_location=map_location, weights_only=True) model_state = self.cleanup_state_dict_keys(checkpoint["state_dict"]) self.load_state_dict(model_state)