From 4af774ba405363e932eb03e8ba2d8e43d9b65d5a Mon Sep 17 00:00:00 2001
From: Sebastian Hoffmann
Date: Tue, 17 Dec 2024 17:53:19 +0100
Subject: [PATCH] chore: register_model() -> dml.wrap_ddp()

---
 dmlcloud/core/__init__.py  |  4 +++
 dmlcloud/core/model.py     | 55 ++++++++++++++++++++++++++++++++++++++
 dmlcloud/core/pipeline.py  | 29 --------------------
 dmlcloud/util/logging.py   |  1 -
 examples/barebone_mnist.py |  3 ++-
 5 files changed, 61 insertions(+), 31 deletions(-)
 create mode 100644 dmlcloud/core/model.py

diff --git a/dmlcloud/core/__init__.py b/dmlcloud/core/__init__.py
index b5ed1bf..d16c8e0 100644
--- a/dmlcloud/core/__init__.py
+++ b/dmlcloud/core/__init__.py
@@ -3,6 +3,7 @@
 from .distributed import *
 from .metrics import *
 from .logging import *
+from .model import *
 
 __all__ = []
 
@@ -18,3 +19,6 @@
 
 # Logging
 __all__ += logging.__all__
+
+# Model helpers
+__all__ += model.__all__
diff --git a/dmlcloud/core/model.py b/dmlcloud/core/model.py
new file mode 100644
index 0000000..8a9f9eb
--- /dev/null
+++ b/dmlcloud/core/model.py
@@ -0,0 +1,55 @@
+import torch
+from torch import nn
+
+from . import logging as dml_logging
+
+
+__all__ = [
+    'count_parameters',
+    'wrap_ddp',
+]
+
+
+def count_parameters(module: nn.Module) -> int:
+    """
+    Returns the number of trainable parameters in a module.
+
+    Args:
+        module (nn.Module): The module to count the parameters of.
+
+    Returns:
+        int: The number of trainable parameters.
+    """
+    return sum(p.numel() for p in module.parameters() if p.requires_grad)
+
+
+def wrap_ddp(module: nn.Module, device: torch.device, sync_bn: bool = False, verbose: bool = True) -> nn.Module:
+    """
+    Wraps a module with DistributedDataParallel.
+
+    Args:
+        module (nn.Module): The module to wrap.
+        device (torch.device): The device to use.
+        sync_bn (bool, optional): If True, uses SyncBatchNorm. Default is False.
+        verbose (bool, optional): If True, prints information about the model. Default is True.
+
+    Returns:
+        nn.Module: The wrapped module.
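+
+    Example:
+        # Illustrative usage only; assumes torch.distributed is already initialized
+        # and that ``Net`` and ``device`` are placeholders defined by the caller.
+        model = wrap_ddp(Net(), device, sync_bn=True)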
+    """
+
+    if not torch.distributed.is_available() or not torch.distributed.is_initialized():
+        raise RuntimeError('DistributedDataParallel requires torch.distributed to be initialized.')
+
+    module = module.to(device)  # Doing it in this order is important for SyncBN
+    if sync_bn:
+        module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module)
+
+    device_ids = [device] if device.type == 'cuda' else None  # Must be None for cpu devices
+    ddp = nn.parallel.DistributedDataParallel(module, broadcast_buffers=False, device_ids=device_ids)
+    if verbose:
+        msg = '* MODEL:\n'
+        msg += f'  - Parameters: {count_parameters(module) / 1e6:.1f} M\n'
+        msg += f'  - {module}'
+        dml_logging.info(msg)
+
+    return ddp
\ No newline at end of file
diff --git a/dmlcloud/core/pipeline.py b/dmlcloud/core/pipeline.py
index 57e5815..c09ec05 100644
--- a/dmlcloud/core/pipeline.py
+++ b/dmlcloud/core/pipeline.py
@@ -57,35 +57,6 @@ def __init__(self, config: Optional[Union[OmegaConf, Dict]] = None, name: Option
     def checkpointing_enabled(self):
         return self.checkpoint_dir is not None
 
-    def register_model(
-        self,
-        name: str,
-        model: torch.nn.Module,
-        use_ddp: bool = True,
-        sync_bn: bool = False,
-        save_latest: bool = True,
-        save_interval: Optional[int] = None,
-        save_best: bool = False,
-        best_metric: str = 'val/loss',
-        verbose: bool = True,
-    ):
-        if name in self.models:
-            raise ValueError(f'Model with name {name} already exists')
-        model = model.to(self.device)  # Doing it in this order is important for SyncBN
-        if sync_bn:
-            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
-        if use_ddp:
-            device_ids = [self.device] if self.device.type == 'cuda' else None  # Must be None for cpu devices
-            model = DistributedDataParallel(model, broadcast_buffers=False, device_ids=device_ids)
-        self.models[name] = model
-
-        if verbose:
-            msg = f'Model "{name}":\n'
-            msg += f'  - Parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.1f} kk\n'
-            msg += f'  - DDP: {use_ddp}\n'
-            msg += f'  - {model}'
-            dml_logging.info(msg)
-
     def register_optimizer(self, name: str, optimizer, scheduler=None):
         if name in self.optimizers:
             raise ValueError(f'Optimizer with name {name} already exists')
diff --git a/dmlcloud/util/logging.py b/dmlcloud/util/logging.py
index 73cf5a7..48b1897 100644
--- a/dmlcloud/util/logging.py
+++ b/dmlcloud/util/logging.py
@@ -1,5 +1,4 @@
 import io
-import logging
 import os
 import subprocess
 import sys
diff --git a/examples/barebone_mnist.py b/examples/barebone_mnist.py
index d11a62d..295d2b4 100644
--- a/examples/barebone_mnist.py
+++ b/examples/barebone_mnist.py
@@ -31,7 +31,8 @@ def pre_stage(self):
             nn.MaxPool2d(2),
             nn.Flatten(),
             nn.Linear(784, 10),
-        ).to(self.pipeline.device)
+        )
+        self.model = dml.wrap_ddp(self.model, self.device)
 
         self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)
         self.loss = nn.CrossEntropyLoss()