diff --git a/src/accelerate/utils/operations.py b/src/accelerate/utils/operations.py index 4b96df7f1d0..7e4d371d262 100644 --- a/src/accelerate/utils/operations.py +++ b/src/accelerate/utils/operations.py @@ -25,7 +25,7 @@ from ..state import PartialState from .constants import TORCH_DISTRIBUTED_OPERATION_TYPES from .dataclasses import DistributedType, TensorInformation -from .imports import is_torch_distributed_available, is_tpu_available +from .imports import is_torch_distributed_available, is_torch_version, is_tpu_available if is_tpu_available(check_device=False): @@ -296,7 +296,10 @@ def _gpu_gather_one(tensor): dtype=tensor.dtype, device=state.device, ) - torch.distributed.all_gather_into_tensor(output_tensors, tensor) + if is_torch_version(">=", "1.13"): + torch.distributed.all_gather_into_tensor(output_tensors, tensor) + else: + torch.distributed._all_gather_base(output_tensors, tensor) return output_tensors.view(-1, *tensor.size()[1:]) else: # gloo does not support `all_gather_into_tensor`, which will result