diff --git a/src/accelerate/utils/operations.py b/src/accelerate/utils/operations.py
index 80ebe51d07c..ee91b60c4eb 100644
--- a/src/accelerate/utils/operations.py
+++ b/src/accelerate/utils/operations.py
@@ -25,7 +25,7 @@
 from ..state import PartialState
 from .constants import TORCH_DISTRIBUTED_OPERATION_TYPES
 from .dataclasses import DistributedType, TensorInformation
-from .imports import is_torch_distributed_available, is_tpu_available
+from .imports import is_torch_distributed_available, is_torch_version, is_tpu_available
 
 
 if is_tpu_available(check_device=False):
@@ -280,6 +280,12 @@ def _tpu_gather_one(tensor):
 
 
 def _gpu_gather(tensor):
+    state = PartialState()
+    if is_torch_version(">=", "1.13"):
+        gather_op = torch.distributed.all_gather_into_tensor
+    else:
+        gather_op = torch.distributed._all_gather_base
+
     def _gpu_gather_one(tensor):
         if tensor.ndim == 0:
             tensor = tensor.clone()[None]
@@ -287,9 +293,26 @@ def _gpu_gather_one(tensor):
         # Can only gather contiguous tensors
         if not tensor.is_contiguous():
            tensor = tensor.contiguous()
-        output_tensors = [torch.empty_like(tensor) for _ in range(torch.distributed.get_world_size())]
-        torch.distributed.all_gather(output_tensors, tensor)
-        return torch.cat(output_tensors, dim=0)
+
+        if state.backend is not None and state.backend != "gloo":
+            # We use `empty` as `all_gather_into_tensor` slightly
+            # differs from `all_gather` for better efficiency,
+            # and we rely on the number of items in the tensor
+            # rather than its direct shape
+            output_tensors = torch.empty(
+                state.num_processes * tensor.numel(),
+                dtype=tensor.dtype,
+                device=state.device,
+            )
+            gather_op(output_tensors, tensor)
+            return output_tensors.view(-1, *tensor.size()[1:])
+        else:
+            # a backend of `None` is always CPU
+            # also gloo does not support `all_gather_into_tensor`,
+            # which will result in a larger memory overhead for the op
+            output_tensors = [torch.empty_like(tensor) for _ in range(state.num_processes)]
+            torch.distributed.all_gather(output_tensors, tensor)
+            return torch.cat(output_tensors, dim=0)
 
     return recursively_apply(_gpu_gather_one, tensor, error_on_other_type=True)
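
For reference, below is a minimal standalone sketch (not part of the patch) of the same gather strategy: a single pre-allocated flat buffer filled in place by `torch.distributed.all_gather_into_tensor` when the backend supports it, versus the list-based `torch.distributed.all_gather` fallback that gloo/CPU still requires. The file name, the `torchrun` invocation, and the hard-coded backend choice are illustrative assumptions, and the fast path assumes PyTorch >= 1.13 (older versions expose the same op as `torch.distributed._all_gather_base`).

# gather_sketch.py -- illustrative sketch only; run with e.g.:
#   torchrun --nproc_per_node=2 gather_sketch.py
import torch
import torch.distributed as dist


def gather(tensor: torch.Tensor) -> torch.Tensor:
    world_size = dist.get_world_size()
    # The collective can only gather contiguous tensors.
    if not tensor.is_contiguous():
        tensor = tensor.contiguous()

    if dist.get_backend() != "gloo":
        # Fast path: one flat buffer of world_size * numel elements,
        # filled in place, then viewed back to (-1, *trailing dims).
        output = torch.empty(
            world_size * tensor.numel(), dtype=tensor.dtype, device=tensor.device
        )
        dist.all_gather_into_tensor(output, tensor)
        return output.view(-1, *tensor.size()[1:])

    # Fallback: gloo does not support all_gather_into_tensor, so gather
    # into a list of per-rank tensors and concatenate (extra copies and
    # a larger memory overhead for the op).
    outputs = [torch.empty_like(tensor) for _ in range(world_size)]
    dist.all_gather(outputs, tensor)
    return torch.cat(outputs, dim=0)


if __name__ == "__main__":
    # "gloo" keeps the example runnable on CPU; with "nccl", tensors
    # would need to live on the current CUDA device.
    dist.init_process_group(backend="gloo")
    rank = dist.get_rank()
    local = torch.full((2, 3), float(rank))
    gathered = gather(local)  # shape: (2 * world_size, 3)
    if rank == 0:
        print(gathered.shape)
        print(gathered)
    dist.destroy_process_group()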