From 4d7c276fd5a6438bc4f43501715420ab11a01680 Mon Sep 17 00:00:00 2001 From: Zach Mueller Date: Wed, 13 Sep 2023 11:28:48 -0400 Subject: [PATCH] Clean --- src/accelerate/accelerator.py | 7 ++++++- src/accelerate/utils/other.py | 12 +++++++----- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py index 369a4786160..1222b6ac40d 100755 --- a/src/accelerate/accelerator.py +++ b/src/accelerate/accelerator.py @@ -2447,7 +2447,12 @@ def save(self, obj, f, safe_serialization=False): >>> accelerator.save(arr, "array.pkl") ``` """ - save(obj, f, save_on_each_node=self.project_configuration.save_on_each_node, safe_serialization=safe_serialization) + save( + obj, + f, + save_on_each_node=self.project_configuration.save_on_each_node, + safe_serialization=safe_serialization, + ) def save_model( self, diff --git a/src/accelerate/utils/other.py b/src/accelerate/utils/other.py index d87b688bdf3..42f9a6cc759 100644 --- a/src/accelerate/utils/other.py +++ b/src/accelerate/utils/other.py @@ -109,18 +109,19 @@ def wait_for_everyone(): """ PartialState().wait_for_everyone() -def save(obj, f, save_on_each_node: bool = False, safe_serialization:bool = False): + +def save(obj, f, save_on_each_node: bool = False, safe_serialization: bool = False): """ Save the data to disk. Use in place of `torch.save()`. Args: - obj: + obj: The data to save - f: + f: The file (or file-like object) to use to save the data - save_on_each_node (`bool`, *optional*, defaults to `False`): + save_on_each_node (`bool`, *optional*, defaults to `False`): Whether to only save on the global main process - safe_serialization (`bool`, *optional*, defaults to `False`): + safe_serialization (`bool`, *optional*, defaults to `False`): Whether to save `obj` using `safetensors` """ save_func = torch.save if not safe_serialization else partial(safe_save_file, metadata={"format": "pt"}) @@ -131,6 +132,7 @@ def save(obj, f, save_on_each_node: bool = False, safe_serialization:bool = Fals elif PartialState().is_local_main_process and not save_on_each_node: save_func(obj, f) + @contextmanager def clear_environment(): """