diff --git a/src/accelerate/utils/other.py b/src/accelerate/utils/other.py index 026e16166e1..b55b83565d6 100644 --- a/src/accelerate/utils/other.py +++ b/src/accelerate/utils/other.py @@ -152,6 +152,7 @@ def clean_state_dict_for_safetensors(state_dict: dict): state_dict = {k: v.contiguous() for k, v in state_dict.items()} return state_dict + def save(obj, f, save_on_each_node: bool = False, safe_serialization: bool = False): """ Save the data to disk. Use in place of `torch.save()`.