Reduce memory by using all_gather_into_tensor (#1968)
* all_gather_into_tensor

* Cleanup

* Reduce memory on non-gloo

* Fin

* Check for backend too on cpu

* CPU comment

* Change scope for performance

* Bring back zeros after remembering why

* Add comment

* Add comment

* Use empty

* Comment
muellerzr authored Oct 10, 2023
1 parent 7a11591 · commit 73640d0
Showing 1 changed file with 27 additions and 4 deletions.
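
For context on where the memory saving comes from: `torch.distributed.all_gather` needs a list of `world_size` per-rank output tensors that is then concatenated with `torch.cat`, while `all_gather_into_tensor` (or `_all_gather_base` before PyTorch 1.13) writes every rank's contribution directly into a single pre-allocated flat buffer, skipping the intermediate list and the extra copy from the concatenation. A minimal sketch of the two patterns, assuming a process group has already been initialized with the NCCL backend (illustrative shapes, not part of the diff below):

import torch
import torch.distributed as dist

# Assumes dist.init_process_group("nccl") has run and each rank holds
# a same-shaped tensor on its own GPU.
tensor = torch.randn(4, 8, device="cuda")
world_size = dist.get_world_size()

# Old pattern: world_size temporary tensors plus an extra copy from torch.cat.
gathered_list = [torch.empty_like(tensor) for _ in range(world_size)]
dist.all_gather(gathered_list, tensor)
old_result = torch.cat(gathered_list, dim=0)  # shape (world_size * 4, 8)

# New pattern (PyTorch >= 1.13): one flat buffer, filled in place.
flat = torch.empty(world_size * tensor.numel(), dtype=tensor.dtype, device=tensor.device)
dist.all_gather_into_tensor(flat, tensor)
new_result = flat.view(-1, *tensor.shape[1:])  # same result, no list, no cat copy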
src/accelerate/utils/operations.py: 27 additions & 4 deletions
@@ -25,7 +25,7 @@
 from ..state import PartialState
 from .constants import TORCH_DISTRIBUTED_OPERATION_TYPES
 from .dataclasses import DistributedType, TensorInformation
-from .imports import is_torch_distributed_available, is_tpu_available
+from .imports import is_torch_distributed_available, is_torch_version, is_tpu_available
 
 
 if is_tpu_available(check_device=False):
@@ -280,16 +280,39 @@ def _tpu_gather_one(tensor):
 
 
 def _gpu_gather(tensor):
+    state = PartialState()
+    if is_torch_version(">=", "1.13"):
+        gather_op = torch.distributed.all_gather_into_tensor
+    else:
+        gather_op = torch.distributed._all_gather_base
+
     def _gpu_gather_one(tensor):
         if tensor.ndim == 0:
             tensor = tensor.clone()[None]
 
         # Can only gather contiguous tensors
         if not tensor.is_contiguous():
             tensor = tensor.contiguous()
-        output_tensors = [torch.empty_like(tensor) for _ in range(torch.distributed.get_world_size())]
-        torch.distributed.all_gather(output_tensors, tensor)
-        return torch.cat(output_tensors, dim=0)
+
+        if state.backend is not None and state.backend != "gloo":
+            # We use `empty` as `all_gather_into_tensor` slightly
+            # differs from `all_gather` for better efficiency,
+            # and we rely on the number of items in the tensor
+            # rather than its direct shape
+            output_tensors = torch.empty(
+                state.num_processes * tensor.numel(),
+                dtype=tensor.dtype,
+                device=state.device,
+            )
+            gather_op(output_tensors, tensor)
+            return output_tensors.view(-1, *tensor.size()[1:])
+        else:
+            # a backend of `None` is always CPU
+            # also gloo does not support `all_gather_into_tensor`,
+            # which will result in a larger memory overhead for the op
+            output_tensors = [torch.empty_like(tensor) for _ in range(state.num_processes)]
+            torch.distributed.all_gather(output_tensors, tensor)
+            return torch.cat(output_tensors, dim=0)
 
     return recursively_apply(_gpu_gather_one, tensor, error_on_other_type=True)
 

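On the GPU path this helper is what `accelerate.utils.gather` (and therefore `Accelerator.gather`) ends up calling, so the reshape at the end is the user-visible contract: the flat buffer holds `num_processes * tensor.numel()` elements, and `.view(-1, *tensor.size()[1:])` turns it back into a tensor whose first dimension is the per-rank first dimension multiplied by the number of processes. A rough usage sketch under that assumption (illustrative shapes; launch with `accelerate launch`):

import torch
from accelerate import Accelerator

accelerator = Accelerator()

# Each process holds a (local_batch, hidden) tensor; shapes must match across ranks.
local_logits = torch.randn(4, 8, device=accelerator.device)

# On a CUDA/NCCL setup this now routes through all_gather_into_tensor, so a single
# (num_processes * 4, 8) buffer is allocated instead of a list of per-rank copies.
all_logits = accelerator.gather(local_logits)
print(all_logits.shape)  # (num_processes * 4, 8)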