diff --git a/src/axolotl/utils/distributed.py b/src/axolotl/utils/distributed.py
index 10f06538ab..9a1c689fb7 100644
--- a/src/axolotl/utils/distributed.py
+++ b/src/axolotl/utils/distributed.py
@@ -77,7 +77,9 @@ def gather_scalar_from_all_ranks(fn, world_size=1):  # pylint: disable=invalid-n
     value_scalar = fn()
     if not is_distributed():
         return [value_scalar]
-    value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float()
+    value_tensor = torch.tensor(
+        value_scalar, device=torch.cuda.current_device()
+    ).float()
 
     if not is_main_process():
         dist.gather(value_tensor, dst=0)
@@ -137,9 +139,13 @@ def compute_and_broadcast(fn):  # pylint: disable=invalid-name
     """
     if is_main_process():
         value_scalar = fn()
-        value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float()
+        value_tensor = torch.tensor(
+            value_scalar, device=torch.cuda.current_device()
+        ).float()
     else:
-        value_tensor = torch.tensor(0.0, device=dist.get_rank())  # Placeholder tensor
+        value_tensor = torch.tensor(
+            0.0, device=torch.cuda.current_device()
+        )  # Placeholder tensor
 
     # Broadcast the tensor to all processes.
     barrier()
@@ -164,7 +170,9 @@ def gather_from_all_ranks(fn, world_size=1):  # pylint: disable=invalid-name
     - A list of computed values from all ranks if on the gathering rank, otherwise None.
     """
    value_scalar = fn()
-    value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float()
+    value_tensor = torch.tensor(
+        value_scalar, device=torch.cuda.current_device()
+    ).float()
 
     # Placeholder tensor for gathering results
     if is_main_process():
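
Note (illustrative, not part of the patch): the change swaps `device=dist.get_rank()` for `device=torch.cuda.current_device()` because the global rank is only a valid CUDA device ordinal when all ranks live on one node; with multiple nodes (or a remapped `CUDA_VISIBLE_DEVICES`), rank numbers exceed or mismatch the local GPU indices. A minimal sketch of the pattern, assuming a torchrun-style launch that exposes `LOCAL_RANK` (the `main` function and the all-reduce are purely for illustration):

```python
import os

import torch
import torch.distributed as dist


def main():
    # Assumes torchrun sets LOCAL_RANK; each process pins itself to its local GPU.
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    torch.cuda.set_device(local_rank)

    # On a 2-node x 8-GPU job, global ranks 8..15 exist but local device ordinals
    # only go 0..7, so `device=dist.get_rank()` would point at a nonexistent GPU
    # there, while `torch.cuda.current_device()` always resolves to this process's GPU.
    value = torch.tensor(float(dist.get_rank()), device=torch.cuda.current_device())

    dist.all_reduce(value, op=dist.ReduceOp.SUM)
    if dist.get_rank() == 0:
        print(f"sum of ranks: {value.item()}")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```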