From 2129036d2c1c330fbbbee37e02d11a0a09121e05 Mon Sep 17 00:00:00 2001
From: Maxime <672982+maximegmd@users.noreply.github.com>
Date: Wed, 20 Sep 2023 18:55:11 +0200
Subject: [PATCH 1/3] fix distributed devices

---
 src/axolotl/utils/distributed.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/axolotl/utils/distributed.py b/src/axolotl/utils/distributed.py
index 10f06538ab..bc9f5d7caf 100644
--- a/src/axolotl/utils/distributed.py
+++ b/src/axolotl/utils/distributed.py
@@ -77,7 +77,7 @@ def gather_scalar_from_all_ranks(fn, world_size=1):  # pylint: disable=invalid-n
     value_scalar = fn()
     if not is_distributed():
         return [value_scalar]
-    value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float()
+    value_tensor = torch.tensor(value_scalar, device=dist.get_rank() % torch.cuda.device_count()).float()
 
     if not is_main_process():
         dist.gather(value_tensor, dst=0)
@@ -137,9 +137,9 @@ def compute_and_broadcast(fn):  # pylint: disable=invalid-name
     """
     if is_main_process():
         value_scalar = fn()
-        value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float()
+        value_tensor = torch.tensor(value_scalar, device=dist.get_rank() % torch.cuda.device_count()).float()
     else:
-        value_tensor = torch.tensor(0.0, device=dist.get_rank())  # Placeholder tensor
+        value_tensor = torch.tensor(0.0, device=dist.get_rank() % torch.cuda.device_count())  # Placeholder tensor
 
     # Broadcast the tensor to all processes.
     barrier()
@@ -164,7 +164,7 @@ def gather_from_all_ranks(fn, world_size=1):  # pylint: disable=invalid-name
     - A list of computed values from all ranks if on the gathering rank, otherwise None.
     """
     value_scalar = fn()
-    value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float()
+    value_tensor = torch.tensor(value_scalar, device=dist.get_rank() % torch.cuda.device_count()).float()
 
     # Placeholder tensor for gathering results
     if is_main_process():

From b6706943f016b4ceaaff30ecd7190277ad4598fa Mon Sep 17 00:00:00 2001
From: Maxime <672982+maximegmd@users.noreply.github.com>
Date: Wed, 20 Sep 2023 19:26:50 +0200
Subject: [PATCH 2/3] Update distributed.py

---
 src/axolotl/utils/distributed.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/axolotl/utils/distributed.py b/src/axolotl/utils/distributed.py
index bc9f5d7caf..66e916ae86 100644
--- a/src/axolotl/utils/distributed.py
+++ b/src/axolotl/utils/distributed.py
@@ -77,7 +77,7 @@ def gather_scalar_from_all_ranks(fn, world_size=1):  # pylint: disable=invalid-n
     value_scalar = fn()
     if not is_distributed():
         return [value_scalar]
-    value_tensor = torch.tensor(value_scalar, device=dist.get_rank() % torch.cuda.device_count()).float()
+    value_tensor = torch.tensor(value_scalar, device=torch.cuda.current_device()).float()
 
     if not is_main_process():
         dist.gather(value_tensor, dst=0)
@@ -137,9 +137,9 @@ def compute_and_broadcast(fn):  # pylint: disable=invalid-name
     """
     if is_main_process():
         value_scalar = fn()
-        value_tensor = torch.tensor(value_scalar, device=dist.get_rank() % torch.cuda.device_count()).float()
+        value_tensor = torch.tensor(value_scalar, device=torch.cuda.current_device()).float()
     else:
-        value_tensor = torch.tensor(0.0, device=dist.get_rank() % torch.cuda.device_count())  # Placeholder tensor
+        value_tensor = torch.tensor(0.0, device=torch.cuda.current_device())  # Placeholder tensor
 
     # Broadcast the tensor to all processes.
     barrier()
@@ -164,7 +164,7 @@ def gather_from_all_ranks(fn, world_size=1):  # pylint: disable=invalid-name
     - A list of computed values from all ranks if on the gathering rank, otherwise None.
     """
     value_scalar = fn()
-    value_tensor = torch.tensor(value_scalar, device=dist.get_rank() % torch.cuda.device_count()).float()
+    value_tensor = torch.tensor(value_scalar, device=torch.cuda.current_device()).float()
 
     # Placeholder tensor for gathering results
     if is_main_process():

From a75a3bc8c972e25ae9dd0590653a120de1ff0a58 Mon Sep 17 00:00:00 2001
From: Maxime <672982+maximegmd@users.noreply.github.com>
Date: Wed, 20 Sep 2023 20:02:44 +0200
Subject: [PATCH 3/3] Update distributed.py

---
 src/axolotl/utils/distributed.py | 16 ++++++++++++----
 1 file changed, 12 insertions(+), 4 deletions(-)

diff --git a/src/axolotl/utils/distributed.py b/src/axolotl/utils/distributed.py
index 66e916ae86..9a1c689fb7 100644
--- a/src/axolotl/utils/distributed.py
+++ b/src/axolotl/utils/distributed.py
@@ -77,7 +77,9 @@ def gather_scalar_from_all_ranks(fn, world_size=1):  # pylint: disable=invalid-n
     value_scalar = fn()
     if not is_distributed():
         return [value_scalar]
-    value_tensor = torch.tensor(value_scalar, device=torch.cuda.current_device()).float()
+    value_tensor = torch.tensor(
+        value_scalar, device=torch.cuda.current_device()
+    ).float()
 
     if not is_main_process():
         dist.gather(value_tensor, dst=0)
@@ -137,9 +139,13 @@ def compute_and_broadcast(fn):  # pylint: disable=invalid-name
     """
     if is_main_process():
         value_scalar = fn()
-        value_tensor = torch.tensor(value_scalar, device=torch.cuda.current_device()).float()
+        value_tensor = torch.tensor(
+            value_scalar, device=torch.cuda.current_device()
+        ).float()
     else:
-        value_tensor = torch.tensor(0.0, device=torch.cuda.current_device())  # Placeholder tensor
+        value_tensor = torch.tensor(
+            0.0, device=torch.cuda.current_device()
+        )  # Placeholder tensor
 
     # Broadcast the tensor to all processes.
     barrier()
@@ -164,7 +170,9 @@ def gather_from_all_ranks(fn, world_size=1):  # pylint: disable=invalid-name
     - A list of computed values from all ranks if on the gathering rank, otherwise None.
     """
     value_scalar = fn()
-    value_tensor = torch.tensor(value_scalar, device=torch.cuda.current_device()).float()
+    value_tensor = torch.tensor(
+        value_scalar, device=torch.cuda.current_device()
+    ).float()
 
     # Placeholder tensor for gathering results
    if is_main_process():
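Note on the change (not part of the patches above): dist.get_rank() returns the global rank of the process, which on a multi-node run can exceed the number of GPUs visible on the local machine and is therefore not a valid CUDA device index, whereas torch.cuda.current_device() returns the GPU already bound to the current process (typically from LOCAL_RANK). The following is a minimal sketch of that distinction; it assumes the job is launched with torchrun (which exports RANK and LOCAL_RANK) and that the process group is already initialized, and the helper name example_device_selection is hypothetical rather than part of the axolotl codebase.

import os

import torch
import torch.distributed as dist


def example_device_selection():
    """Hypothetical helper contrasting local and global rank device choice."""
    # Assumes torchrun (or an equivalent launcher) exported LOCAL_RANK and the
    # default process group was initialized before this function runs.
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    torch.cuda.set_device(local_rank)

    global_rank = dist.get_rank()               # e.g. 0..15 on 2 nodes x 8 GPUs
    local_device = torch.cuda.current_device()  # always 0..7 on each node

    # Using global_rank directly as a CUDA device index fails once
    # global_rank >= torch.cuda.device_count(); the locally selected
    # device index above is always valid for this node.
    value_tensor = torch.tensor(0.0, device=local_device)
    return value_tensor, global_rank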