chore: register_model() -> dml.wrap_ddp()
sehoffmann committed Dec 17, 2024
1 parent 8d4027b commit 4af774b
Showing 5 changed files with 61 additions and 31 deletions.
4 changes: 4 additions & 0 deletions dmlcloud/core/__init__.py
@@ -3,6 +3,7 @@
from .distributed import *
from .metrics import *
from .logging import *
from .model import *

__all__ = []

@@ -18,3 +19,6 @@

# Logging
__all__ += logging.__all__

# Model helpers
__all__ += model.__all__
55 changes: 55 additions & 0 deletions dmlcloud/core/model.py
@@ -0,0 +1,55 @@
import torch
from torch import nn

from . import logging as dml_logging


__all__ = [
    'count_parameters',
    'wrap_ddp',
]


def count_parameters(module: nn.Module) -> int:
    """
    Returns the number of trainable parameters in a module.

    Args:
        module (nn.Module): The module to count the parameters of.

    Returns:
        int: The number of trainable parameters.
    """
    return sum(p.numel() for p in module.parameters() if p.requires_grad)


def wrap_ddp(module: nn.Module, device: torch.device, sync_bn: bool = False, verbose: bool = True) -> nn.Module:
    """
    Wraps a module with DistributedDataParallel.

    Args:
        module (nn.Module): The module to wrap.
        device (torch.device): The device to use.
        sync_bn (bool, optional): If True, uses SyncBatchNorm. Default is False.
        verbose (bool, optional): If True, prints information about the model. Default is True.

    Returns:
        nn.Module: The wrapped module.
    """
    if not torch.distributed.is_available() or not torch.distributed.is_initialized():
        raise RuntimeError('DistributedDataParallel requires torch.distributed to be initialized.')

    module = module.to(device)  # Doing it in this order is important for SyncBN
    if sync_bn:
        module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module)

    device_ids = [device] if device.type == 'cuda' else None  # Must be None for cpu devices
    ddp = nn.parallel.DistributedDataParallel(module, broadcast_buffers=False, device_ids=device_ids)
    if verbose:
        msg = '* MODEL:\n'
        msg += f'  - Parameters: {count_parameters(module) / 1e6:.1f} M\n'
        msg += f'  - {module}'
        dml_logging.info(msg)

    return ddp
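
For reference, a minimal usage sketch of the new helper (not part of this commit): it assumes dmlcloud re-exports wrap_ddp under the dml alias, as examples/barebone_mnist.py below suggests, and that the process group is set up by the launcher (e.g. torchrun).

# Hypothetical standalone script; illustrates the new API, not taken from the repository.
import os

import torch
import torch.distributed as dist
from torch import nn

import dmlcloud as dml  # assumed alias; the example script calls dml.wrap_ddp

# torchrun provides RANK / WORLD_SIZE / LOCAL_RANK in the environment
dist.init_process_group(backend='nccl' if torch.cuda.is_available() else 'gloo')
if torch.cuda.is_available():
    device = torch.device('cuda', int(os.environ.get('LOCAL_RANK', 0)))
    torch.cuda.set_device(device)
else:
    device = torch.device('cpu')

model = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10))
model = dml.wrap_ddp(model, device)  # moves to device and wraps in DistributedDataParallel
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)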
29 changes: 0 additions & 29 deletions dmlcloud/core/pipeline.py
Original file line number Diff line number Diff line change
@@ -57,35 +57,6 @@ def __init__(self, config: Optional[Union[OmegaConf, Dict]] = None, name: Option
    def checkpointing_enabled(self):
        return self.checkpoint_dir is not None

    def register_model(
        self,
        name: str,
        model: torch.nn.Module,
        use_ddp: bool = True,
        sync_bn: bool = False,
        save_latest: bool = True,
        save_interval: Optional[int] = None,
        save_best: bool = False,
        best_metric: str = 'val/loss',
        verbose: bool = True,
    ):
        if name in self.models:
            raise ValueError(f'Model with name {name} already exists')
        model = model.to(self.device)  # Doing it in this order is important for SyncBN
        if sync_bn:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        if use_ddp:
            device_ids = [self.device] if self.device.type == 'cuda' else None  # Must be None for cpu devices
            model = DistributedDataParallel(model, broadcast_buffers=False, device_ids=device_ids)
        self.models[name] = model

        if verbose:
            msg = f'Model "{name}":\n'
            msg += f'  - Parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.1f} kk\n'
            msg += f'  - DDP: {use_ddp}\n'
            msg += f'  - {model}'
            dml_logging.info(msg)

    def register_optimizer(self, name: str, optimizer, scheduler=None):
        if name in self.optimizers:
            raise ValueError(f'Optimizer with name {name} already exists')
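
With register_model() removed, a hedged migration sketch for old call sites follows (illustrative only; the keyword arguments mirror the deleted signature above, pipeline.device is the same attribute the removed method used internally, and the checkpointing-related arguments have no counterpart in wrap_ddp).

# Illustrative migration; not part of this commit.

# Before: the pipeline wrapped the model and stored it in pipeline.models.
# pipeline.register_model('cnn', model, use_ddp=True, sync_bn=True)
# model = pipeline.models['cnn']

# After: wrap explicitly with the standalone helper and keep the reference yourself.
model = dml.wrap_ddp(model, pipeline.device, sync_bn=True)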
1 change: 0 additions & 1 deletion dmlcloud/util/logging.py
@@ -1,5 +1,4 @@
import io
import logging
import os
import subprocess
import sys
3 changes: 2 additions & 1 deletion examples/barebone_mnist.py
@@ -31,7 +31,8 @@ def pre_stage(self):
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(784, 10),
        ).to(self.pipeline.device)
        )
        self.model = dml.wrap_ddp(self.model, self.device)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)
        self.loss = nn.CrossEntropyLoss()

