diff --git a/src/accelerate/utils/operations.py b/src/accelerate/utils/operations.py index 9bb67212298..06f23e3848a 100644 --- a/src/accelerate/utils/operations.py +++ b/src/accelerate/utils/operations.py @@ -289,7 +289,6 @@ def _gpu_gather_one(tensor): tensor = tensor.contiguous() state = PartialState() - if state.backend is not None and state.backend != "gloo": output_tensors = torch.zeros( state.num_processes * tensor.numel(), @@ -302,8 +301,9 @@ def _gpu_gather_one(tensor): torch.distributed._all_gather_base(output_tensors, tensor) return output_tensors.view(-1, *tensor.size()[1:]) else: - # gloo does not support `all_gather_into_tensor`, which will result - # in a larger memory overhead for the op + # a backend of `None` is always CPU + # also gloo does not support `all_gather_into_tensor`, + # which will result in a larger memory overhead for the op output_tensors = [torch.zeros_like(tensor) for _ in range(state.num_processes)] torch.distributed.all_gather(output_tensors, tensor) return torch.cat(output_tensors, dim=0)