From 2f415967fb36d923280e0ef9040ce203a4ae8377 Mon Sep 17 00:00:00 2001
From: Parth Raut
Date: Thu, 12 Dec 2024 21:33:07 -0500
Subject: [PATCH] fixes

---
 examples/power_limit_optimizer/train_dp.py | 28 +++++++++----------
 zeus/monitor/energy.py                     |  1 -
 zeus/optimizer/power_limit.py              |  6 ++--
 zeus/utils/framework.py                    | 32 ++++++++++++----------
 4 files changed, 33 insertions(+), 34 deletions(-)

diff --git a/examples/power_limit_optimizer/train_dp.py b/examples/power_limit_optimizer/train_dp.py
index 18904a05..1fb0c797 100644
--- a/examples/power_limit_optimizer/train_dp.py
+++ b/examples/power_limit_optimizer/train_dp.py
@@ -197,21 +197,19 @@ def main():
         sampler=val_sampler,
     )
 
-    # The rank 0 process will monitor and optimize the power limit of all GPUs.
-    if args.gpu == 0:
-        callback_set: list[Callback] = [
-            GlobalPowerLimitOptimizer(
-                monitor=ZeusMonitor(gpu_indices=args.gpu),  # Since there is only one GPU per process, monitor it (give it local rank).
-                optimum_selector=MaxSlowdownConstraint(
-                    factor=get_env("ZEUS_MAX_SLOWDOWN", float, 1.1),
-                ),
-                warmup_steps=10,
-                profile_steps=40,
-                pl_step=25,
-            )
-        ]
-    else:
-        callback_set = []
+    # All processes will monitor and optimize the power limit of all GPUs (one process per GPU).
+    callback_set: list[Callback] = [
+        GlobalPowerLimitOptimizer(
+            monitor=ZeusMonitor(gpu_indices=args.gpu),  # Since there is only one GPU per process, monitor it (pass it the local rank).
+            optimum_selector=MaxSlowdownConstraint(
+                factor=get_env("ZEUS_MAX_SLOWDOWN", float, 1.1),
+            ),
+            warmup_steps=10,
+            profile_steps=40,
+            pl_step=25,
+        )
+    ]
+
     callbacks = CallbackSet(callback_set)
 
     for epoch in range(args.epochs):
diff --git a/zeus/monitor/energy.py b/zeus/monitor/energy.py
index dc639510..d485021a 100644
--- a/zeus/monitor/energy.py
+++ b/zeus/monitor/energy.py
@@ -187,7 +187,6 @@ def __init__(
         except ZeusCPUInitError:
             self.cpus = EmptyCPUs()
         except ZeusCPUNoPermissionError as err:
-            self.cpus = EmptyCPUs()
             if cpu_indices:
                 raise RuntimeError(
                     "Root privilege is required to read RAPL metrics. See "
diff --git a/zeus/optimizer/power_limit.py b/zeus/optimizer/power_limit.py
index 73234e02..ad0ea24b 100644
--- a/zeus/optimizer/power_limit.py
+++ b/zeus/optimizer/power_limit.py
@@ -419,10 +419,10 @@ def on_step_begin(self) -> None:
             self.measurements.append(
                 PowerLimitMeasurement(
                     power_limit=self.state.current_power_limit // 1000,
-                    energy=all_reduce(
+                    energy=sum(all_reduce(
                         list(measurement.gpu_energy.values()), operation="sum"
-                    ),
-                    time=all_reduce([measurement.time], operation="max"),
+                    )),
+                    time=max(all_reduce([measurement.time], operation="max")),
                 )
             )
             # If we're done profiling all power limits, compute the optimal
diff --git a/zeus/utils/framework.py b/zeus/utils/framework.py
index b56b5c56..9610f9f7 100644
--- a/zeus/utils/framework.py
+++ b/zeus/utils/framework.py
@@ -106,15 +106,17 @@ def sync_execution(
 
 def all_reduce(
     object: list[int] | list[float], operation: Literal["sum", "max"]
-) -> int | float:
+) -> list[int] | list[float]:
     """Reduce objects from all replicas through the specified operation.
-
-    If running in a distributed setting, the objects are reduced across all replicas.
-    If running in a non-distributed setting, the operation is just done on the single object. 
-    """
+
+    If the current execution is not distributed, the object is returned as is."""
     if torch_is_available(ensure_cuda=False):
         torch = MODULE_CACHE["torch"]
 
+        # if torch.distributed is not available or not initialized, return the object as is
+        if not torch.distributed.is_available() or not torch.distributed.is_initialized():
+            return object
+
         # wrap object in a tensor if it is not already
         if not isinstance(object, torch.Tensor):
             object = torch.Tensor(object)
@@ -129,17 +131,16 @@ def all_reduce(
         else:
             raise ValueError(f"all_reduce unsupported operation: {operation}")
 
-        # compute local operation
-        result = torch_func(object)
-
-        # all-reduce only if torch.distributed is available and initialized
-        if torch.distributed.is_available() and torch.distributed.is_initialized():
-            torch.distributed.all_reduce(result.cuda(), op=torch_op)
-        return result.item()
+        torch.distributed.all_reduce(object.cuda(), op=torch_op)
+        return object.cpu().tolist()
 
     if jax_is_available():
-        # JAX cross-device all-reduce not yet implemente
-        return sum(object) if operation == "sum" else max(object)
+        # Check if not distributed
+        jax = MODULE_CACHE["jax"]
+        if jax.process_count() == 1:
+            return object
+
+        raise NotImplementedError("JAX all-reduce not yet implemented")
 
     raise RuntimeError("No framework is available.")
 
@@ -150,5 +151,6 @@ def is_distributed() -> bool:
         torch = MODULE_CACHE["torch"]
         return torch.distributed.is_available() and torch.distributed.is_initialized()
     if jax_is_available():
-        return False  # JAX not yet implemented
+        jax = MODULE_CACHE["jax"]
+        return jax.process_count() > 1
     return False