From d484988a1ce308cceb654d59275a908e16580c39 Mon Sep 17 00:00:00 2001
From: Zach Mueller
Date: Wed, 13 Sep 2023 18:09:28 +0000
Subject: [PATCH] Reduce memory on non-gloo

---
 src/accelerate/utils/operations.py | 22 +++++++++++++++-------
 1 file changed, 15 insertions(+), 7 deletions(-)

diff --git a/src/accelerate/utils/operations.py b/src/accelerate/utils/operations.py
index 642dcbe7cca..4b96df7f1d0 100644
--- a/src/accelerate/utils/operations.py
+++ b/src/accelerate/utils/operations.py
@@ -289,13 +289,21 @@ def _gpu_gather_one(tensor):
             tensor = tensor.contiguous()
 
         state = PartialState()
-        output_tensors = torch.zeros(
-            state.num_processes * tensor.numel(),
-            dtype=tensor.dtype,
-            device=state.device,
-        )
-        torch.distributed.all_gather_into_tensor(output_tensors, tensor)
-        return output_tensors.view(-1, *tensor.size()[1:])
+
+        if state.backend != "gloo":
+            output_tensors = torch.zeros(
+                state.num_processes * tensor.numel(),
+                dtype=tensor.dtype,
+                device=state.device,
+            )
+            torch.distributed.all_gather_into_tensor(output_tensors, tensor)
+            return output_tensors.view(-1, *tensor.size()[1:])
+        else:
+            # gloo does not support `all_gather_into_tensor`, which will result
+            # in a larger memory overhead for the op
+            output_tensors = [torch.zeros_like(tensor) for _ in range(state.num_processes)]
+            torch.distributed.all_gather(output_tensors, tensor)
+            return torch.cat(output_tensors, dim=0)
 
     return recursively_apply(_gpu_gather_one, tensor, error_on_other_type=True)
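
Note (not part of the patch): a minimal standalone sketch of the two gather paths the hunk introduces, runnable with `torchrun --nproc_per_node=2 demo.py`. The function name `gather_tensor`, the tensor shapes, and the gloo backend choice in `__main__` are illustrative; the patched code reads the backend and process count from `PartialState` rather than from `torch.distributed` directly. The non-gloo branch fills one preallocated flat buffer in place, while the gloo branch must allocate a per-rank buffer list and pay an extra concatenation copy, which is the memory overhead the patch comment refers to.

import torch
import torch.distributed as dist


def gather_tensor(tensor: torch.Tensor) -> torch.Tensor:
    """Gather `tensor` from all ranks, mirroring the patched logic."""
    world_size = dist.get_world_size()
    if dist.get_backend() != "gloo":
        # Fast path: one flat buffer of world_size * numel elements,
        # filled in place with no intermediate per-rank copies.
        output = torch.zeros(
            world_size * tensor.numel(), dtype=tensor.dtype, device=tensor.device
        )
        dist.all_gather_into_tensor(output, tensor)
        return output.view(-1, *tensor.size()[1:])
    # gloo path: `all_gather_into_tensor` is unsupported, so gather into a
    # list of per-rank buffers and concatenate, costing an extra copy.
    outputs = [torch.zeros_like(tensor) for _ in range(world_size)]
    dist.all_gather(outputs, tensor)
    return torch.cat(outputs, dim=0)


if __name__ == "__main__":
    dist.init_process_group(backend="gloo")  # illustrative backend choice
    local = torch.full((2, 3), float(dist.get_rank()))
    gathered = gather_tensor(local)  # shape: (world_size * 2, 3)
    print(dist.get_rank(), gathered.shape)
    dist.destroy_process_group()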