diff --git a/zeus/utils/framework.py b/zeus/utils/framework.py index b321e0e0..c745b9d1 100644 --- a/zeus/utils/framework.py +++ b/zeus/utils/framework.py @@ -121,9 +121,8 @@ def all_reduce( ): return object - # wrap object in a tensor if it is not already - if not isinstance(object, torch.Tensor): - object = torch.Tensor(object) + # wrap object in a tensor + tensor = torch.Tensor(object) # determine operation if operation == "sum": @@ -133,8 +132,8 @@ def all_reduce( else: raise ValueError(f"all_reduce unsupported operation: {operation}") - torch.distributed.all_reduce(object.cuda(), op=torch_op) - return object.cpu().tolist() + torch.distributed.all_reduce(tensor, op=torch_op) + return tensor.cpu().tolist() if jax_is_available(): # Check if not distributed @@ -155,3 +154,4 @@ def is_distributed() -> bool: if jax_is_available(): jax = MODULE_CACHE["jax"] return jax.process_count() > 1 + raise RuntimeError("No framework is available.")