Commit d484988

Reduce memory on non-gloo
muellerzr committed Sep 13, 2023
1 parent 82a34ea commit d484988
Showing 1 changed file with 15 additions and 7 deletions.
src/accelerate/utils/operations.py — 22 changes: 15 additions & 7 deletions

@@ -289,13 +289,21 @@ def _gpu_gather_one(tensor):
             tensor = tensor.contiguous()
 
         state = PartialState()
-        output_tensors = torch.zeros(
-            state.num_processes * tensor.numel(),
-            dtype=tensor.dtype,
-            device=state.device,
-        )
-        torch.distributed.all_gather_into_tensor(output_tensors, tensor)
-        return output_tensors.view(-1, *tensor.size()[1:])
+
+        if state.backend != "gloo":
+            output_tensors = torch.zeros(
+                state.num_processes * tensor.numel(),
+                dtype=tensor.dtype,
+                device=state.device,
+            )
+            torch.distributed.all_gather_into_tensor(output_tensors, tensor)
+            return output_tensors.view(-1, *tensor.size()[1:])
+        else:
+            # gloo does not support `all_gather_into_tensor`, which will result
+            # in a larger memory overhead for the op
+            output_tensors = [torch.zeros_like(tensor) for _ in range(state.num_processes)]
+            torch.distributed.all_gather(output_tensors, tensor)
+            return torch.cat(output_tensors, dim=0)
 
     return recursively_apply(_gpu_gather_one, tensor, error_on_other_type=True)
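
Why this reduces memory: `all_gather_into_tensor` writes every rank's contribution directly into one preallocated flat buffer, whereas the gloo fallback first materializes a Python list of per-rank tensors and then `torch.cat` copies them into a new contiguous tensor, briefly holding roughly twice the gathered data. Below is a minimal standalone sketch of the two paths (not part of the commit; the script name, the world size of 2, and the NCCL backend are assumptions for illustration):

# gather_sketch.py — hypothetical illustration, run with:
#   torchrun --nproc_per_node=2 gather_sketch.py
import torch
import torch.distributed as dist

dist.init_process_group(backend="nccl")  # assumes CUDA GPUs; gloo would reject the first path
rank = dist.get_rank()
world_size = dist.get_world_size()
device = torch.device(f"cuda:{rank}")
tensor = torch.full((4, 8), float(rank), device=device)

# Non-gloo path: gather into one preallocated flat buffer, no extra copy.
flat = torch.zeros(world_size * tensor.numel(), dtype=tensor.dtype, device=device)
dist.all_gather_into_tensor(flat, tensor)
gathered = flat.view(-1, *tensor.size()[1:])  # shape (world_size * 4, 8)

# Gloo-style fallback: a list of per-rank buffers, then a concatenating copy.
buffers = [torch.zeros_like(tensor) for _ in range(world_size)]
dist.all_gather(buffers, tensor)
gathered_too = torch.cat(buffers, dim=0)  # allocates a second full-size tensor

assert torch.equal(gathered, gathered_too)
dist.destroy_process_group()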

