From f1d4539d8028429374f043eaa7859cd9fa14a5a3 Mon Sep 17 00:00:00 2001 From: Zach Mueller Date: Wed, 13 Sep 2023 17:53:43 +0000 Subject: [PATCH 01/12] all_gather_into_tensor --- src/accelerate/utils/operations.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/src/accelerate/utils/operations.py b/src/accelerate/utils/operations.py index 80ebe51d07c..3e12fdba038 100644 --- a/src/accelerate/utils/operations.py +++ b/src/accelerate/utils/operations.py @@ -287,9 +287,18 @@ def _gpu_gather_one(tensor): # Can only gather contiguous tensors if not tensor.is_contiguous(): tensor = tensor.contiguous() - output_tensors = [torch.empty_like(tensor) for _ in range(torch.distributed.get_world_size())] - torch.distributed.all_gather(output_tensors, tensor) - return torch.cat(output_tensors, dim=0) + + state = PartialState() + output_tensors = torch.zeros( + state.num_processes * tensor.numel(), + dtype=tensor.dtype, + device=state.device, + ) + torch.distributed.all_gather_into_tensor(output_tensors, tensor) + return output_tensors.view(-1, *tensor.size()[1:]) + # output_tensors = [torch.empty_like(tensor) for _ in range(torch.distributed.get_world_size())] + # torch.distributed.all_gather(output_tensors, tensor) + # return torch.cat(output_tensors, dim=0) return recursively_apply(_gpu_gather_one, tensor, error_on_other_type=True) From 82a34eae4e5ae847ef04d76c2d601c3cfc81365c Mon Sep 17 00:00:00 2001 From: Zach Mueller Date: Wed, 13 Sep 2023 17:55:53 +0000 Subject: [PATCH 02/12] Cleanup --- src/accelerate/utils/operations.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/accelerate/utils/operations.py b/src/accelerate/utils/operations.py index 3e12fdba038..642dcbe7cca 100644 --- a/src/accelerate/utils/operations.py +++ b/src/accelerate/utils/operations.py @@ -296,9 +296,6 @@ def _gpu_gather_one(tensor): ) torch.distributed.all_gather_into_tensor(output_tensors, tensor) return output_tensors.view(-1, *tensor.size()[1:]) - # output_tensors = [torch.empty_like(tensor) for _ in range(torch.distributed.get_world_size())] - # torch.distributed.all_gather(output_tensors, tensor) - # return torch.cat(output_tensors, dim=0) return recursively_apply(_gpu_gather_one, tensor, error_on_other_type=True) From d484988a1ce308cceb654d59275a908e16580c39 Mon Sep 17 00:00:00 2001 From: Zach Mueller Date: Wed, 13 Sep 2023 18:09:28 +0000 Subject: [PATCH 03/12] Reduce memory on non-gloo --- src/accelerate/utils/operations.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/src/accelerate/utils/operations.py b/src/accelerate/utils/operations.py index 642dcbe7cca..4b96df7f1d0 100644 --- a/src/accelerate/utils/operations.py +++ b/src/accelerate/utils/operations.py @@ -289,13 +289,21 @@ def _gpu_gather_one(tensor): tensor = tensor.contiguous() state = PartialState() - output_tensors = torch.zeros( - state.num_processes * tensor.numel(), - dtype=tensor.dtype, - device=state.device, - ) - torch.distributed.all_gather_into_tensor(output_tensors, tensor) - return output_tensors.view(-1, *tensor.size()[1:]) + + if state.backend != "gloo": + output_tensors = torch.zeros( + state.num_processes * tensor.numel(), + dtype=tensor.dtype, + device=state.device, + ) + torch.distributed.all_gather_into_tensor(output_tensors, tensor) + return output_tensors.view(-1, *tensor.size()[1:]) + else: + # gloo does not support `all_gather_into_tensor`, which will result + # in a larger memory overhead for the op + output_tensors = [torch.zeros_like(tensor) for _ in range(state.num_processes)] + torch.distributed.all_gather(output_tensors, tensor) + return torch.cat(output_tensors, dim=0) return recursively_apply(_gpu_gather_one, tensor, error_on_other_type=True) From fd35a7a266b556ed01a292b71a1f224cdd751e7f Mon Sep 17 00:00:00 2001 From: Zach Mueller Date: Wed, 13 Sep 2023 18:17:19 +0000 Subject: [PATCH 04/12] Fin --- src/accelerate/utils/operations.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/accelerate/utils/operations.py b/src/accelerate/utils/operations.py index 4b96df7f1d0..7e4d371d262 100644 --- a/src/accelerate/utils/operations.py +++ b/src/accelerate/utils/operations.py @@ -25,7 +25,7 @@ from ..state import PartialState from .constants import TORCH_DISTRIBUTED_OPERATION_TYPES from .dataclasses import DistributedType, TensorInformation -from .imports import is_torch_distributed_available, is_tpu_available +from .imports import is_torch_distributed_available, is_torch_version, is_tpu_available if is_tpu_available(check_device=False): @@ -296,7 +296,10 @@ def _gpu_gather_one(tensor): dtype=tensor.dtype, device=state.device, ) - torch.distributed.all_gather_into_tensor(output_tensors, tensor) + if is_torch_version(">=", "1.13"): + torch.distributed.all_gather_into_tensor(output_tensors, tensor) + else: + torch.distributed._all_gather_base(output_tensors, tensor) return output_tensors.view(-1, *tensor.size()[1:]) else: # gloo does not support `all_gather_into_tensor`, which will result From 0a3bd1c847d7cb640f4ccd5bf31cf0d72955ce71 Mon Sep 17 00:00:00 2001 From: Zach Mueller Date: Wed, 13 Sep 2023 15:52:57 -0400 Subject: [PATCH 05/12] Check for backend too on cpu --- src/accelerate/utils/operations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/accelerate/utils/operations.py b/src/accelerate/utils/operations.py index 7e4d371d262..9bb67212298 100644 --- a/src/accelerate/utils/operations.py +++ b/src/accelerate/utils/operations.py @@ -290,7 +290,7 @@ def _gpu_gather_one(tensor): state = PartialState() - if state.backend != "gloo": + if state.backend is not None and state.backend != "gloo": output_tensors = torch.zeros( state.num_processes * tensor.numel(), dtype=tensor.dtype, From 7a36532f03dc5a0dc6f0b3a158bdcf760d42c129 Mon Sep 17 00:00:00 2001 From: Zach Mueller Date: Mon, 2 Oct 2023 11:43:29 -0400 Subject: [PATCH 06/12] CPU comment --- src/accelerate/utils/operations.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/accelerate/utils/operations.py b/src/accelerate/utils/operations.py index 9bb67212298..06f23e3848a 100644 --- a/src/accelerate/utils/operations.py +++ b/src/accelerate/utils/operations.py @@ -289,7 +289,6 @@ def _gpu_gather_one(tensor): tensor = tensor.contiguous() state = PartialState() - if state.backend is not None and state.backend != "gloo": output_tensors = torch.zeros( state.num_processes * tensor.numel(), @@ -302,8 +301,9 @@ def _gpu_gather_one(tensor): torch.distributed._all_gather_base(output_tensors, tensor) return output_tensors.view(-1, *tensor.size()[1:]) else: - # gloo does not support `all_gather_into_tensor`, which will result - # in a larger memory overhead for the op + # a backend of `None` is always CPU + # also gloo does not support `all_gather_into_tensor`, + # which will result in a larger memory overhead for the op output_tensors = [torch.zeros_like(tensor) for _ in range(state.num_processes)] torch.distributed.all_gather(output_tensors, tensor) return torch.cat(output_tensors, dim=0) From 1977e5c59744f6f0897b5eb6eeb3d019309d9570 Mon Sep 17 00:00:00 2001 From: Zach Mueller Date: Mon, 2 Oct 2023 11:44:35 -0400 Subject: [PATCH 07/12] Change scope for performance --- src/accelerate/utils/operations.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/accelerate/utils/operations.py b/src/accelerate/utils/operations.py index 06f23e3848a..bf8ac126b24 100644 --- a/src/accelerate/utils/operations.py +++ b/src/accelerate/utils/operations.py @@ -280,6 +280,8 @@ def _tpu_gather_one(tensor): def _gpu_gather(tensor): + state = PartialState() + def _gpu_gather_one(tensor): if tensor.ndim == 0: tensor = tensor.clone()[None] @@ -288,7 +290,6 @@ def _gpu_gather_one(tensor): if not tensor.is_contiguous(): tensor = tensor.contiguous() - state = PartialState() if state.backend is not None and state.backend != "gloo": output_tensors = torch.zeros( state.num_processes * tensor.numel(), From a425a8933528887d2f5c16966efed92d4ceb7c1c Mon Sep 17 00:00:00 2001 From: Zach Mueller Date: Mon, 2 Oct 2023 12:12:26 -0400 Subject: [PATCH 08/12] Bring back zeros after remembering why --- src/accelerate/utils/operations.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/accelerate/utils/operations.py b/src/accelerate/utils/operations.py index bf8ac126b24..63101c65cf0 100644 --- a/src/accelerate/utils/operations.py +++ b/src/accelerate/utils/operations.py @@ -281,6 +281,10 @@ def _tpu_gather_one(tensor): def _gpu_gather(tensor): state = PartialState() + if is_torch_version(">=", "1.13"): + gather_op = torch.distributed.all_gather_into_tensor + else: + gather_op = torch.distributed._all_gather_base def _gpu_gather_one(tensor): if tensor.ndim == 0: @@ -296,16 +300,13 @@ def _gpu_gather_one(tensor): dtype=tensor.dtype, device=state.device, ) - if is_torch_version(">=", "1.13"): - torch.distributed.all_gather_into_tensor(output_tensors, tensor) - else: - torch.distributed._all_gather_base(output_tensors, tensor) + gather_op(output_tensors, tensor) return output_tensors.view(-1, *tensor.size()[1:]) else: # a backend of `None` is always CPU # also gloo does not support `all_gather_into_tensor`, # which will result in a larger memory overhead for the op - output_tensors = [torch.zeros_like(tensor) for _ in range(state.num_processes)] + output_tensors = [torch.empty_like(tensor) for _ in range(state.num_processes)] torch.distributed.all_gather(output_tensors, tensor) return torch.cat(output_tensors, dim=0) From cd067f5fb2a3b89c90c950db225f4179d16f056b Mon Sep 17 00:00:00 2001 From: Zach Mueller Date: Mon, 2 Oct 2023 12:13:14 -0400 Subject: [PATCH 09/12] Add comment --- src/accelerate/utils/operations.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/accelerate/utils/operations.py b/src/accelerate/utils/operations.py index 63101c65cf0..fd5753968f8 100644 --- a/src/accelerate/utils/operations.py +++ b/src/accelerate/utils/operations.py @@ -295,6 +295,8 @@ def _gpu_gather_one(tensor): tensor = tensor.contiguous() if state.backend is not None and state.backend != "gloo": + # We use `zeros` as `all_gather_into_tensor` slightly + # differs from `all_gather` for better efficiency output_tensors = torch.zeros( state.num_processes * tensor.numel(), dtype=tensor.dtype, From 5449149885383efe1db3b20348ae49ee8f0a4047 Mon Sep 17 00:00:00 2001 From: Zach Mueller Date: Mon, 2 Oct 2023 12:13:48 -0400 Subject: [PATCH 10/12] Add comment --- src/accelerate/utils/operations.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/accelerate/utils/operations.py b/src/accelerate/utils/operations.py index fd5753968f8..4e527ced2a7 100644 --- a/src/accelerate/utils/operations.py +++ b/src/accelerate/utils/operations.py @@ -296,7 +296,9 @@ def _gpu_gather_one(tensor): if state.backend is not None and state.backend != "gloo": # We use `zeros` as `all_gather_into_tensor` slightly - # differs from `all_gather` for better efficiency + # differs from `all_gather` for better efficiency, + # and we rely on the number of items in the tensor + # rather than its direct shape output_tensors = torch.zeros( state.num_processes * tensor.numel(), dtype=tensor.dtype, From 5306fc149768186d6f91254938c11b7f82b87e02 Mon Sep 17 00:00:00 2001 From: Zach Mueller Date: Mon, 2 Oct 2023 18:23:55 +0000 Subject: [PATCH 11/12] Use empty --- src/accelerate/utils/operations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/accelerate/utils/operations.py b/src/accelerate/utils/operations.py index 4e527ced2a7..991f2a81200 100644 --- a/src/accelerate/utils/operations.py +++ b/src/accelerate/utils/operations.py @@ -299,7 +299,7 @@ def _gpu_gather_one(tensor): # differs from `all_gather` for better efficiency, # and we rely on the number of items in the tensor # rather than its direct shape - output_tensors = torch.zeros( + output_tensors = torch.empty( state.num_processes * tensor.numel(), dtype=tensor.dtype, device=state.device, From cd70333abca97bf9d8fc84a957111e673b02036e Mon Sep 17 00:00:00 2001 From: Zach Mueller Date: Mon, 2 Oct 2023 18:25:11 +0000 Subject: [PATCH 12/12] Comment --- src/accelerate/utils/operations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/accelerate/utils/operations.py b/src/accelerate/utils/operations.py index 991f2a81200..ee91b60c4eb 100644 --- a/src/accelerate/utils/operations.py +++ b/src/accelerate/utils/operations.py @@ -295,7 +295,7 @@ def _gpu_gather_one(tensor): tensor = tensor.contiguous() if state.backend is not None and state.backend != "gloo": - # We use `zeros` as `all_gather_into_tensor` slightly + # We use `empty` as `all_gather_into_tensor` slightly # differs from `all_gather` for better efficiency, # and we rely on the number of items in the tensor # rather than its direct shape