Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduce memory by using all_gather_into_tensor #1968

Merged
merged 12 commits into from
Oct 10, 2023
31 changes: 27 additions & 4 deletions src/accelerate/utils/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from ..state import PartialState
from .constants import TORCH_DISTRIBUTED_OPERATION_TYPES
from .dataclasses import DistributedType, TensorInformation
from .imports import is_torch_distributed_available, is_tpu_available
from .imports import is_torch_distributed_available, is_torch_version, is_tpu_available


if is_tpu_available(check_device=False):
Expand Down Expand Up @@ -280,16 +280,39 @@ def _tpu_gather_one(tensor):


def _gpu_gather(tensor):
state = PartialState()
if is_torch_version(">=", "1.13"):
gather_op = torch.distributed.all_gather_into_tensor
else:
gather_op = torch.distributed._all_gather_base

def _gpu_gather_one(tensor):
if tensor.ndim == 0:
tensor = tensor.clone()[None]

# Can only gather contiguous tensors
if not tensor.is_contiguous():
tensor = tensor.contiguous()
output_tensors = [torch.empty_like(tensor) for _ in range(torch.distributed.get_world_size())]
torch.distributed.all_gather(output_tensors, tensor)
return torch.cat(output_tensors, dim=0)

if state.backend is not None and state.backend != "gloo":
# We use `zeros` as `all_gather_into_tensor` slightly
# differs from `all_gather` for better efficiency,
# and we rely on the number of items in the tensor
# rather than its direct shape
output_tensors = torch.zeros(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why use torch.zeros instead of torch.empty_like as previously?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're doing something slightly different here with the new API, where this gather works using a different tensor dim than before which is more efficient.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a comment

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure, but won't this allocate more memory than previously?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's actually the opposite, hence what we're doing here.

state.num_processes * tensor.numel(),
dtype=tensor.dtype,
device=state.device,
)
gather_op(output_tensors, tensor)
return output_tensors.view(-1, *tensor.size()[1:])
else:
muellerzr marked this conversation as resolved.
Show resolved Hide resolved
# a backend of `None` is always CPU
# also gloo does not support `all_gather_into_tensor`,
# which will result in a larger memory overhead for the op
output_tensors = [torch.empty_like(tensor) for _ in range(state.num_processes)]
torch.distributed.all_gather(output_tensors, tensor)
return torch.cat(output_tensors, dim=0)

return recursively_apply(_gpu_gather_one, tensor, error_on_other_type=True)

Expand Down
Loading