Skip to content

Commit

Permalink
Fix hang
Browse files Browse the repository at this point in the history
  • Loading branch information
muellerzr committed Oct 10, 2024
1 parent 4328e5f commit ab37917
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions src/accelerate/utils/other.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,14 +250,15 @@ def load(f, map_location=None, **kwargs):
old_safe_globals = torch.serialization.get_safe_globals()
if "weights_only" not in kwargs:
kwargs["weights_only"] = True
torch.serialization.add_safe_globals(TORCH_SAFE_GLOBALS + old_safe_globals)
torch.serialization.add_safe_globals(TORCH_SAFE_GLOBALS)
else:
kwargs.pop("weights_only", None)
loaded_obj = torch.load(f, map_location=map_location, **kwargs)
finally:
if is_weights_only_available():
torch.serialization.clear_safe_globals()
torch.serialization.add_safe_globals(old_safe_globals)
if old_safe_globals:
torch.serialization.add_safe_globals(old_safe_globals)
return loaded_obj


Expand Down

0 comments on commit ab37917

Please sign in to comment.