feat: scale_lr()
sehoffmann committed Dec 17, 2024 · 1 parent 4af774b · commit 1001a49
Showing 1 changed file with 35 additions and 5 deletions: dmlcloud/core/model.py
@@ -1,12 +1,13 @@
 import torch
 from torch import nn
 
-from . import logging as dml_logging
+from . import distributed as dml_distributed, logging as dml_logging
 
 
 __all__ = [
     'count_parameters',
     'wrap_ddp',
+    'scale_lr',
 ]
 
 
@@ -23,14 +24,21 @@ def count_parameters(module: nn.Module) -> int:
     return sum(p.numel() for p in module.parameters() if p.requires_grad)
 
 
-def wrap_ddp(module: nn.Module, device: torch.device, sync_bn: bool = False, verbose: bool = True) -> nn.Module:
+def wrap_ddp(
+    module: nn.Module,
+    device: torch.device,
+    sync_bn: bool = False,
+    find_unused_parameters: bool = False,
+    verbose: bool = True,
+) -> nn.Module:
     """
     Wraps a module with DistributedDataParallel.
 
     Args:
         module (nn.Module): The module to wrap.
         device (torch.device): The device to use.
         sync_bn (bool, optional): If True, uses SyncBatchNorm. Default is False.
+        find_unused_parameters (bool, optional): If True, finds unused parameters. Default is False.
         verbose (bool, optional): If True, prints information about the model. Default is True.
 
     Returns:
@@ -39,17 +47,39 @@ def wrap_ddp(module: nn.Module, device: torch.device, sync_bn: bool = False, ver
 
     if not torch.distributed.is_available() or not torch.distributed.is_initialized():
         raise RuntimeError('DistributedDataParallel requires torch.distributed to be initialized.')
 
     module = module.to(device)  # Doing it in this order is important for SyncBN
     if sync_bn:
         module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module)
 
     device_ids = [device] if device.type == 'cuda' else None  # Must be None for cpu devices
-    ddp = nn.parallel.DistributedDataParallel(module, broadcast_buffers=False, device_ids=device_ids)
+    ddp = nn.parallel.DistributedDataParallel(
+        module, broadcast_buffers=False, device_ids=device_ids, find_unused_parameters=find_unused_parameters
+    )
     if verbose:
         msg = f'* MODEL:\n'
         msg += f'  - Parameters: {count_parameters(module) / 1e6:.1f} kk\n'
         msg += f'  - {module}'
         dml_logging.info(msg)
 
-    return ddp
+    return ddp
+
+
+def scale_lr(base_lr: float, world_size: int = None) -> float:
+    """
+    Scales the learning rate based on the world size.
+
+    Args:
+        base_lr (float): The base learning rate.
+        world_size (int, optional): The number of processes. Default is the global world size.
+
+    Returns:
+        float: The scaled learning rate.
+
+    See Also:
+        - :func:`dmlcloud.`
+    """
+    if world_size is None:
+        world_size = dml_distributed.world_size()
+
+    return base_lr * world_size
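
For context on how the two helpers fit together, here is a minimal usage sketch. It is not part of the commit and rests on a few assumptions: torch.distributed has already been initialized (e.g. via torchrun), both functions are importable from dmlcloud.core.model as in the diff above, and the model, optimizer, and base learning rate of 1e-3 are placeholders.

# Illustrative sketch, not from the commit. Assumes torch.distributed is already
# initialized (wrap_ddp raises a RuntimeError otherwise) and that these helpers
# live in dmlcloud.core.model as shown in the diff above.
import torch
from torch import nn

from dmlcloud.core.model import scale_lr, wrap_ddp


def build_training_objects(local_rank: int):
    device = torch.device('cuda', local_rank) if torch.cuda.is_available() else torch.device('cpu')

    # Placeholder model; BatchNorm1d is included so that sync_bn=True has an effect.
    model = nn.Sequential(nn.Linear(32, 64), nn.BatchNorm1d(64), nn.ReLU(), nn.Linear(64, 10))

    # Moves the model to `device`, optionally converts BatchNorm to SyncBatchNorm,
    # and wraps it in DistributedDataParallel.
    ddp_model = wrap_ddp(model, device, sync_bn=True, find_unused_parameters=False)

    # Linear LR scaling: with the default world_size=None, the global world size is used.
    lr = scale_lr(1e-3)
    optimizer = torch.optim.SGD(ddp_model.parameters(), lr=lr)
    return ddp_model, optimizer

Scaling the base learning rate linearly with the number of processes is the usual heuristic for keeping the effective step size roughly consistent as the global batch size grows with the world size.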
