Skip to content

Commit

Permalink
test fix
Browse files Browse the repository at this point in the history
  • Loading branch information
parthraut committed Dec 13, 2024
1 parent ea0f866 commit 08fd4d3
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions zeus/utils/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,8 @@ def all_reduce(
):
return object

# wrap object in a tensor if it is not already
if not isinstance(object, torch.Tensor):
object = torch.Tensor(object)
# wrap object in a tensor
tensor = torch.Tensor(object)

# determine operation
if operation == "sum":
Expand All @@ -133,8 +132,8 @@ def all_reduce(
else:
raise ValueError(f"all_reduce unsupported operation: {operation}")

torch.distributed.all_reduce(object.cuda(), op=torch_op)
return object.cpu().tolist()
torch.distributed.all_reduce(tensor, op=torch_op)
return tensor.cpu().tolist()

if jax_is_available():
# Check if not distributed
Expand All @@ -155,3 +154,4 @@ def is_distributed() -> bool:
if jax_is_available():
jax = MODULE_CACHE["jax"]
return jax.process_count() > 1
raise RuntimeError("No framework is available.")

0 comments on commit 08fd4d3

Please sign in to comment.