fixes
parthraut committed Dec 13, 2024
1 parent 66573ec commit 2f41596
Showing 4 changed files with 33 additions and 34 deletions.
examples/power_limit_optimizer/train_dp.py (28 changes: 13 additions & 15 deletions)
@@ -197,21 +197,19 @@ def main():
         sampler=val_sampler,
     )

-    # The rank 0 process will monitor and optimize the power limit of all GPUs.
-    if args.gpu == 0:
-        callback_set: list[Callback] = [
-            GlobalPowerLimitOptimizer(
-                monitor=ZeusMonitor(gpu_indices=args.gpu),  # Since there is only one GPU per process, monitor it (give it local rank).
-                optimum_selector=MaxSlowdownConstraint(
-                    factor=get_env("ZEUS_MAX_SLOWDOWN", float, 1.1),
-                ),
-                warmup_steps=10,
-                profile_steps=40,
-                pl_step=25,
-            )
-        ]
-    else:
-        callback_set = []
+    # All processes will monitor and optimize the power limit of all GPUs (one process per GPU).
+    callback_set: list[Callback] = [
+        GlobalPowerLimitOptimizer(
+            monitor=ZeusMonitor(gpu_indices=args.gpu),  # Since there is only one GPU per process, monitor it (give it local rank).
+            optimum_selector=MaxSlowdownConstraint(
+                factor=get_env("ZEUS_MAX_SLOWDOWN", float, 1.1),
+            ),
+            warmup_steps=10,
+            profile_steps=40,
+            pl_step=25,
+        )
+    ]

     callbacks = CallbackSet(callback_set)

     for epoch in range(args.epochs):
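Since every rank now builds its own GlobalPowerLimitOptimizer, each process has to drive the callback hooks itself, and because the optimizer aggregates its profiling measurements with all_reduce (see the zeus/optimizer/power_limit.py hunk below), all ranks generally need to reach those hooks in the same order. A minimal sketch of that dispatch pattern, assuming Callback and CallbackSet are importable from zeus.callback and expose the usual on_epoch_begin/on_step_begin/on_step_end/on_epoch_end hooks; the toy callback and do_training_step stand in for the real optimizer and training step:

from zeus.callback import Callback, CallbackSet


class PrintingCallback(Callback):
    """Toy stand-in for GlobalPowerLimitOptimizer."""

    def on_step_begin(self) -> None:
        # A power limit optimizer could adjust the GPU power limit here.
        print("step begin")


def do_training_step(step: int) -> None:
    """Hypothetical stand-in for the real forward/backward/optimizer.step()."""


callbacks = CallbackSet([PrintingCallback()])
for epoch in range(2):
    callbacks.on_epoch_begin()
    for step in range(3):
        callbacks.on_step_begin()
        do_training_step(step)
        callbacks.on_step_end()
    callbacks.on_epoch_end()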
zeus/monitor/energy.py (1 change: 0 additions & 1 deletion)
@@ -187,7 +187,6 @@ def __init__(
         except ZeusCPUInitError:
             self.cpus = EmptyCPUs()
         except ZeusCPUNoPermissionError as err:
-            self.cpus = EmptyCPUs()
             if cpu_indices:
                 raise RuntimeError(
                     "Root privilege is required to read RAPL metrics. See "
zeus/optimizer/power_limit.py (6 changes: 3 additions & 3 deletions)
@@ -419,10 +419,10 @@ def on_step_begin(self) -> None:
             self.measurements.append(
                 PowerLimitMeasurement(
                     power_limit=self.state.current_power_limit // 1000,
-                    energy=all_reduce(
+                    energy=sum(all_reduce(
                         list(measurement.gpu_energy.values()), operation="sum"
-                    ),
-                    time=all_reduce([measurement.time], operation="max"),
+                    )),
+                    time=max(all_reduce([measurement.time], operation="max")),
                 )
             )
             # If we're done profiling all power limits, compute the optimal
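These call sites change because all_reduce (reworked in zeus/utils/framework.py below) now returns a list that has been reduced element-wise across replicas, or the input unchanged in a non-distributed run, instead of a scalar; the caller then collapses it locally with sum() or max(). A toy, self-contained illustration of that contract; fake_all_reduce and the numbers are made up for the example and are not Zeus code:

def fake_all_reduce(per_rank_values: list[list[float]], operation: str) -> list[float]:
    """Element-wise reduction across ranks, mimicking the new all_reduce contract."""
    if operation == "sum":
        return [sum(column) for column in zip(*per_rank_values)]
    if operation == "max":
        return [max(column) for column in zip(*per_rank_values)]
    raise ValueError(f"unsupported operation: {operation}")


# Two ranks with one GPU each: per-rank GPU energy (J) and step time (s).
energies = fake_all_reduce([[120.0], [115.0]], "sum")  # every rank sees [235.0]
times = fake_all_reduce([[10.2], [10.5]], "max")       # every rank sees [10.5]

total_energy = sum(energies)  # 235.0 J across all GPUs of all ranks
step_time = max(times)        # 10.5 s, bounded by the slowest rank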
zeus/utils/framework.py (32 changes: 17 additions & 15 deletions)
@@ -106,15 +106,17 @@ def sync_execution(

 def all_reduce(
     object: list[int] | list[float], operation: Literal["sum", "max"]
-) -> int | float:
+) -> list[int] | list[float]:
     """Reduce objects from all replicas through the specified operation.
-    If running in a distributed setting, the objects are reduced across all replicas.
-    If running in a non-distributed setting, the operation is just done on the single object.
-    """
+    If the current execution is not distributed, the object is returned as is.
+    """
     if torch_is_available(ensure_cuda=False):
         torch = MODULE_CACHE["torch"]

+        # if torch.distributed is not available or not initialized, return the object as is
+        if not torch.distributed.is_available() or not torch.distributed.is_initialized():
+            return object
+
         # wrap object in a tensor if it is not already
         if not isinstance(object, torch.Tensor):
             object = torch.Tensor(object)
@@ -129,17 +131,16 @@ def all_reduce(
         else:
             raise ValueError(f"all_reduce unsupported operation: {operation}")

-        # compute local operation
-        result = torch_func(object)
-
-        # all-reduce only if torch.distributed is available and initialized
-        if torch.distributed.is_available() and torch.distributed.is_initialized():
-            torch.distributed.all_reduce(result.cuda(), op=torch_op)
-        return result.item()
+        torch.distributed.all_reduce(object.cuda(), op=torch_op)
+        return object.cpu().tolist()

     if jax_is_available():
-        # JAX cross-device all-reduce not yet implemented
-        return sum(object) if operation == "sum" else max(object)
+        # Check if not distributed
+        jax = MODULE_CACHE["jax"]
+        if jax.process_count() == 1:
+            return object
+
+        raise NotImplementedError("JAX all-reduce not yet implemented")

     raise RuntimeError("No framework is available.")

@@ -150,5 +151,6 @@ def is_distributed() -> bool:
         torch = MODULE_CACHE["torch"]
         return torch.distributed.is_available() and torch.distributed.is_initialized()
     if jax_is_available():
-        return False  # JAX not yet implemented
+        jax = MODULE_CACHE["jax"]
+        return jax.process_count() > 1
     return False
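A hedged usage sketch of the reworked helpers, assuming only what the diff shows plus an importable PyTorch: with no initialized torch.distributed process group, all_reduce returns its input unchanged and is_distributed() reports False, while under an initialized process group the same calls would perform cross-rank reductions. The energy and time values are hypothetical:

from zeus.utils.framework import all_reduce, is_distributed

local_gpu_energies = [118.7, 121.3]  # hypothetical per-GPU joules measured by this rank
local_step_time = [10.4]             # hypothetical seconds for this rank's training step

total_energy = sum(all_reduce(local_gpu_energies, operation="sum"))
slowest_step = max(all_reduce(local_step_time, operation="max"))

print(f"distributed={is_distributed()} energy={total_energy:.1f} J time={slowest_step:.1f} s")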
