diff --git a/src/accelerate/utils/other.py b/src/accelerate/utils/other.py index 42f9a6cc759..6c273d23b79 100644 --- a/src/accelerate/utils/other.py +++ b/src/accelerate/utils/other.py @@ -127,9 +127,9 @@ def save(obj, f, save_on_each_node: bool = False, safe_serialization: bool = Fal save_func = torch.save if not safe_serialization else partial(safe_save_file, metadata={"format": "pt"}) if PartialState().distributed_type == DistributedType.TPU: xm.save(obj, f) - elif PartialState().is_main_process and save_on_each_node: + elif PartialState().is_main_process and not save_on_each_node: save_func(obj, f) - elif PartialState().is_local_main_process and not save_on_each_node: + elif PartialState().is_local_main_process and save_on_each_node: save_func(obj, f)