From 8e0864f2c5f0a552cbf11c91439ffd56654ffd79 Mon Sep 17 00:00:00 2001 From: Zhiyuan Chen Date: Fri, 9 Aug 2024 16:24:01 +0800 Subject: [PATCH] make AcceleratRunner a subclass of Accelerator add TorchRunner add DeepSpeedRunner --- .github/workflows/push.yaml | 2 +- danling/__init__.py | 9 +- danling/{runner => }/defaults.py | 15 +- danling/metrics/preprocesses.py | 2 + danling/modules/mlp/dense.py | 6 +- danling/runner/README.md | 14 +- danling/runner/__init__.py | 9 +- danling/runner/accelerate_runner.py | 608 +++------------- danling/runner/base_runner.py | 550 +++++++-------- danling/runner/{state.py => config.py} | 56 +- danling/runner/deepspeed_runner.py | 213 ++++++ danling/runner/torch_runner.py | 664 +++++++++++++++++- danling/runner/utils.py | 41 ++ danling/tensors/nested_tensor.py | 3 + danling/utils/__init__.py | 2 - demo/accelerate_imdb.py | 120 ++++ demo/{vision => }/torch_mnist.py | 15 +- docs/docs/runner/{state.md => config.md} | 4 +- docs/mkdocs.yml | 3 +- pyproject.toml | 13 +- requirements.txt | 12 + tests/optim/test_lr_scheduler.py | 4 +- tests/runner/test_base_runner.py | 7 +- .../defaults.py => tests/runner/test_imdb.py | 17 +- tests/runner/test_mnist.py | 2 +- 25 files changed, 1535 insertions(+), 856 deletions(-) rename danling/{runner => }/defaults.py (81%) rename danling/runner/{state.py => config.py} (71%) create mode 100644 danling/runner/deepspeed_runner.py create mode 100644 demo/accelerate_imdb.py rename demo/{vision => }/torch_mnist.py (88%) rename docs/docs/runner/{state.md => config.md} (57%) rename danling/utils/defaults.py => tests/runner/test_imdb.py (66%) diff --git a/.github/workflows/push.yaml b/.github/workflows/push.yaml index e4135ecf..cf132846 100644 --- a/.github/workflows/push.yaml +++ b/.github/workflows/push.yaml @@ -24,7 +24,7 @@ jobs: - name: Install dependencies run: pip install -r requirements.txt && pip install -e . - name: Install dependencies for testing - run: pip install pytest pytest-cov torch torcheval torchmetrics torchvision accelerate + run: pip install pytest pytest-cov - name: pytest run: pytest --cov=materialx --cov-report=xml --cov-report=html . - name: Upload coverage report for documentation diff --git a/danling/__init__.py b/danling/__init__.py index 455a4ca3..511a739a 100644 --- a/danling/__init__.py +++ b/danling/__init__.py @@ -17,7 +17,7 @@ from lazy_imports import try_import -from danling import metrics, modules, optim, registry, runner, tensors, typing, utils +from danling import defaults, metrics, modules, optim, registry, runner, tensors, typing, utils from .metrics import ( AverageMeter, @@ -29,7 +29,7 @@ ) from .optim import LRScheduler from .registry import GlobalRegistry, Registry -from .runner import AccelerateRunner, BaseRunner, TorchRunner +from .runner import AccelerateRunner, BaseRunner, Config, DeepSpeedRunner, TorchRunner from .tensors import NestedTensor, PNTensor, tensor from .utils import ( catch, @@ -47,6 +47,7 @@ from .metrics import Metrics, MultiTaskMetrics __all__ = [ + "defaults", "metrics", "modules", "optim", @@ -55,9 +56,11 @@ "tensors", "utils", "typing", + "Config", "BaseRunner", - "AccelerateRunner", "TorchRunner", + "AccelerateRunner", + "DeepSpeedRunner", "LRScheduler", "Registry", "GlobalRegistry", diff --git a/danling/runner/defaults.py b/danling/defaults.py similarity index 81% rename from danling/runner/defaults.py rename to danling/defaults.py index 59433429..0e507b27 100644 --- a/danling/runner/defaults.py +++ b/danling/defaults.py @@ -15,14 +15,15 @@ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. # See the LICENSE file for more details. -DEFAULT_RUN_NAME = "Run" -DEFAULT_EXPERIMENT_NAME = "DanLing" -DEFAULT_EXPERIMENT_ID = "xxxxxxxxxxxxxxxx" -DEFAULT_IGNORED_KEYS_IN_HASH = { +RUN_NAME = "Run" +EXPERIMENT_NAME = "DanLing" +EXPERIMENT_ID = "xxxxxxxxxxxxxxxx" +SEED = 1016 +IGNORED_CONFIG_IN_HASH = { "timestamp", - "iters", - "steps", - "epochs", + "iter", + "step", + "epoch", "results", "score_split", "score", diff --git a/danling/metrics/preprocesses.py b/danling/metrics/preprocesses.py index 6467ad81..f214a2eb 100644 --- a/danling/metrics/preprocesses.py +++ b/danling/metrics/preprocesses.py @@ -58,6 +58,8 @@ def preprocess( input, target = input.concat, target.concat if ignored_index is not None: input, target = input[target != ignored_index], target[target != ignored_index] + if input.numel() == target.numel(): + return input.squeeze(), target.squeeze() return input, target diff --git a/danling/modules/mlp/dense.py b/danling/modules/mlp/dense.py index bddf31a1..31c0bbe4 100644 --- a/danling/modules/mlp/dense.py +++ b/danling/modules/mlp/dense.py @@ -33,10 +33,10 @@ def __init__( super().__init__() self.residual = residual self.linear = nn.Linear(in_features, out_features, bias=bias) - self.norm = getattr(nn, norm)(out_features) if norm else nn.Identity() - self.activation = getattr(nn, activation)() if activation else nn.Identity() + self.norm = getattr(nn, norm)(out_features) if norm else None + self.activation = getattr(nn, activation)() if activation else None self.dropout = nn.Dropout(dropout) - self.pool = getattr(nn, pool)(out_features) if pool else nn.Identity() if self.residual else None + self.pool = getattr(nn, pool)(out_features) if self.residual else None def forward(self, x): out = self.linear(x) diff --git a/danling/runner/README.md b/danling/runner/README.md index ff5e151e..d68b0173 100644 --- a/danling/runner/README.md +++ b/danling/runner/README.md @@ -4,26 +4,26 @@ The Runner of DanLing sets up the basic environment for running neural networks. ## Components -For cross-platform compatibilities, DanLing features a two-level Runner + RunnerState system. +For cross-platform compatibilities, DanLing features a two-level Runner + Config system. ### PlatformRunner PlatformRunner implements platform-specific features like `step` and `prepare`. -The Runner contains all runtime information that is irrelevant to the checkpoint (e.g. `world_size`, `rank`, etc.). All other information should be saved in `RunnerState`. +The Runner contains all runtime information that is irrelevant to the checkpoint (e.g. `world_size`, `rank`, etc.). All other information should be saved in `Config`. Currently, only [`AccelerateRunner`][danling.runner.AccelerateRunner] is supported. ### [`BaseRunner`][danling.runner.BaseRunner] -[`BaseRunner`](danling.runner.BaseRunner) defines shared attributes and implements platform-agnostic features, including `init_logging`, `results` and `scores`. +[`BaseRunner`][danling.runner.BaseRunner] defines shared attributes and implements platform-agnostic features, including `init_logging`, `results` and `scores`. -### [`RunnerState`][danling.runner.RunnerState] +### [`Config`][danling.runner.Config] -[`RunnerState`][danling.runner.RunnerState] stores the state of a run (e.g. `epochs`, `run_id`, `network`, etc.). +[`Config`][danling.runner.Config] stores the state of a run (e.g. `epoch`, `run_id`, `network`, etc.). -With `RunnerState` and corresponding weights, you can resume a run from any point. -Therefore, all members in `RunnerState` will be saved in the checkpoint, and thus should be json serialisable. +With `Config` and corresponding weights, you can resume a run from any point. +Therefore, all members in `Config` will be saved in the checkpoint, and thus should be json serialisable. ## Experiments Management diff --git a/danling/runner/__init__.py b/danling/runner/__init__.py index 7aff1d60..b27788e2 100644 --- a/danling/runner/__init__.py +++ b/danling/runner/__init__.py @@ -15,19 +15,20 @@ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. # See the LICENSE file for more details. -from . import defaults from .accelerate_runner import AccelerateRunner from .base_runner import BaseRunner -from .state import RunnerState +from .config import Config +from .deepspeed_runner import DeepSpeedRunner from .torch_runner import TorchRunner from .utils import on_local_main_process, on_main_process __all__ = [ - "RunnerState", + "Config", "BaseRunner", + "TorchRunner", "AccelerateRunner", + "DeepSpeedRunner", "TorchRunner", "on_main_process", "on_local_main_process", - "defaults", ] diff --git a/danling/runner/accelerate_runner.py b/danling/runner/accelerate_runner.py index cae5b0e2..ae30f6fd 100644 --- a/danling/runner/accelerate_runner.py +++ b/danling/runner/accelerate_runner.py @@ -18,35 +18,30 @@ from __future__ import annotations import os -import random -from collections.abc import Callable, Mapping -from contextlib import suppress -from time import time -from typing import Any -from warnings import warn - -# pylint: disable=redefined-builtin +from collections.abc import Mapping +from contextlib import contextmanager +from math import ceil + import torch -from accelerate import Accelerator -from accelerate.utils import DeepSpeedPlugin -from chanfig import NestedDict +from chanfig import FlatDict, NestedDict +from lazy_imports import try_import from torch import distributed as dist -from torch import nn, optim, utils -from torch.backends import cudnn -from tqdm import tqdm +from torch import nn, utils try: - from numpy import random as np_random + from functools import cached_property except ImportError: - np_random = None + from cached_property import cached_property # type: ignore -from danling.utils import catch +from .config import Config +from .torch_runner import BaseRunner, TorchRunner -from .base_runner import BaseRunner -from .utils import RunnerMode, on_main_process +with try_import() as ac: + from accelerate import Accelerator + from accelerate.utils import DeepSpeedPlugin -class AccelerateRunner(BaseRunner): # pylint: disable=too-many-public-methods +class AccelerateRunner(TorchRunner, Accelerator): # pylint: disable=too-many-public-methods r""" Set up everything for running a job. @@ -60,518 +55,145 @@ class AccelerateRunner(BaseRunner): # pylint: disable=too-many-public-methods In fact, you don't even need to create `dataloaders`, just define `datasets` and `AccelerateRunner` will create `dataloaders` for you. `AccelerateRunner` will inspect the `train` flag in corresponding dataset to - automatically set `shuffle`. - - Attributes: - accelerator (Accelerator): - accelerate: Arguments to pass when building accelerator. Defaults to `{}`. + set `shuffle` and `drop_last` automatically. """ - accelerator: Accelerator - accelerate: dict - - model: nn.Module - criterion: nn.Module - optimizer: optim.Optimizer - scheduler: optim.lr_scheduler._LRScheduler - - def __init__(self, *args, **kwargs) -> None: - if len(args) != 1 or kwargs: - message = ( - "Passing multiple args & kwargs to build Runner is deprecated and will be removed in DanLing v0.3.\n" - "Please only pass a config dict instead." - ) - warn(message, DeprecationWarning, stacklevel=2) - config = NestedDict(*args, **kwargs) - else: - config = args[0] - if "accelerate" not in self: # class attributes - self.accelerate = {} - self.accelerate.update(config.get("accelerate", {})) - super().__init__(config) + _accelerate: FlatDict | None = None + + def __init__(self, config: Config) -> None: + ac.check() + TorchRunner.__init__(self, config) + Accelerator.__init__(self, **self.accelerate) + if self.distributed: + object_list = [self.id, self.timestamp] + dist.broadcast_object_list(object_list) + self.id, self.timestamp = object_list def __post_init__(self) -> None: - self.model, self.criterion, self.optimizer = self.prepare(self.model, self.criterion, self.optimizer) - self.scheduler = self.prepare(self.scheduler) + BaseRunner.__post_init__(self) + self.project_configuration.set_directories(self.dir) if self.datasets: - datasets = {k: d for k, d in self.datasets.items() if k not in self.dataloaders} - default_kwargs = self.state.setdefault("dataloader", NestedDict()) - dataloader_kwargs = NestedDict({k: default_kwargs.pop(k) for k in self.datasets if k in default_kwargs}) - for k, d in datasets.items(): - dataloader_kwargs.setdefault(k, NestedDict()) - dataloader_kwargs[k].merge(default_kwargs, overwrite=False) - dataloader_kwargs[k].setdefault("shuffle", getattr(d, "train", True)) - dataloader_kwargs[k].setdefault("drop_last", not getattr(d, "train", True)) - self.dataloaders[k] = utils.data.DataLoader(d, **dataloader_kwargs[k]) - default_kwargs.update(dataloader_kwargs) - for k, d in self.dataloaders.items(): - self.dataloaders[k] = self.prepare(d) - if self.state.get("log_interval") is None: - self.state.log_interval = max(max(len(d) for d in self.dataloaders.values()) // 10, 1) - - @property - def deepspeed(self) -> dict | None: - if "accelerator" not in self: - raise ValueError("accelerator is not used") - if self.accelerator.state.deepspeed_plugin is not None: - return self.accelerator.state.deepspeed_plugin.deepspeed_config - return None - - def train(self, train_splits: list[str] | None = None, eval_splits: list[str] | None = None) -> NestedDict: - r""" - Perform training on `split`. - - Args: - train_splits (list[str]): list of split to run train. - Defaults to `["train"]`. - eval_splits (list[str]): list of split to run evaluate. - Defaults to `self.dataloaders` except for those in `train_splits`. - - Return: - NestedDict: train results - """ - - early_stop_counter = 0 - if train_splits is None: - train_splits = ["train"] - if eval_splits is None: - eval_splits = [s for s in self.dataloaders if s not in train_splits] - self.state.epoch_begin = self.state.epochs - print(f"Begin training from {self.state.epoch_begin} to {self.state.epoch_end}") - print(f"Training splits: {train_splits}") - print(f"Evaluation splits: {eval_splits}") - patience = self.state.get("patience", float("inf")) - for epochs in range(self.state.epoch_begin, self.state.epoch_end): # type: ignore - self.state.epochs = epochs - result = NestedDict() - result.setattr("convert_mapping", True) - for split in train_splits: - result[split] = self.train_epoch(split) - for split in eval_splits: - result[split] = self.evaluate_epoch(split) - self.append_result(result) - print(self.format_epoch_result(result)) - self.save_result() - if self.state.save_interval is not None: - self.save_checkpoint(epochs) - """@nni.report_intermediate_result(self.latest_score)""" - early_stop_counter = 0 if self.is_best else early_stop_counter + 1 - if early_stop_counter > patience: - print("early stop") - break - """@nni.report_final_result(self.latest_score)""" - return self.results - - def train_epoch(self, split: str = "train") -> NestedDict: - r""" - Train one epoch on `split`. - - Args: - split (str): split to run train - - Return: - NestedDict: train result - """ - - self.mode = "train" # type: ignore - self.split = split - loader = self.dataloaders[split] - length = len(loader) - 1 - last_print_iteration = -1 - log_interval = self.state.get("log_interval", -1) - self.meters.reset() - if self.metrics is not None: - self.metrics.reset() - batch_time = time() - if hasattr(loader.batch_sampler, "set_epoch"): - loader.batch_sampler.set_epoch(self.epochs) - if hasattr(loader.sampler, "set_epoch"): - loader.sampler.set_epoch(self.epochs) - - for iteration, data in enumerate(loader): - with self.autocast(), self.accumulate(): - input = data["input"] if isinstance(data, Mapping) else data[0] - target = data["target"] if isinstance(data, Mapping) else data[1] - pred = self.model(**input) if isinstance(input, Mapping) else self.model(input) - loss = self.criterion(pred, target) - if self.metrics is not None: - self.metrics.update(pred.squeeze(-1), target) - self.step(loss) - - if log_interval > 0 and (iteration > 0 and iteration % log_interval == 0 or iteration == length): - interval = iteration - last_print_iteration - if self.device == torch.device("cuda"): - torch.cuda.synchronize() - if self.scheduler is not None: - self.meters.lr.update(self.scheduler.get_last_lr()[0]) - self.meters.time.update((time() - batch_time) / interval) - batch_time = time() - reduced_loss = self.reduce(loss).item() - self.meters.loss.update(reduced_loss) - self.step_log(split, iteration, length) - last_print_iteration = iteration - - result = self.meters.average() - if self.metrics is not None: - result.merge(self.metrics.average()) - return result - - def evaluate(self, eval_splits: list[str] | None = None) -> NestedDict: - r""" - Perform evaluation on `eval_splits`. - - Args: - eval_splits (list[str]): list of split to run evaluate. - Defaults to `["eval"]`. - - Return: - NestedDict: evaluation result - """ - - if eval_splits is None: - eval_splits = ["eval"] - - print("Begin evaluation") - print(f"Evaluation splits: {eval_splits}") - result = NestedDict() - result.setattr("convert_mapping", True) - for split in eval_splits: - result[split] = self.evaluate_epoch(split=split) - print(self.format_epoch_result(result)) - return result - - @torch.inference_mode() - def evaluate_epoch(self, split: str = "val") -> NestedDict: - r""" - Evaluate one epoch on `split`. - - Args: - split (str): split to run evaluate - - Return: - NestedDict: evaluation result - """ + self.build_dataloaders() + if self.config.get("log_interval") is None: + self.config.log_interval = max(ceil(max(len(d) for d in self.dataloaders.values()) / 10), 1) + self.model, self.criterion, self.optimizer, self.scheduler = self.prepare( + self.model, self.criterion, self.optimizer, self.scheduler + ) - self.mode = "eval" # type: ignore - self.split = split - loader = self.dataloaders[split] - length = len(loader) - 1 - last_print_iteration = -1 - log_interval = self.state.get("log_interval", -1) - self.meters.reset() - if self.metrics is not None: - self.metrics.reset() - batch_time = time() - - for iteration, data in enumerate(loader): + def train_step(self, data) -> torch.Tensor: + with self.autocast(), self.accumulate(): input = data["input"] if isinstance(data, Mapping) else data[0] target = data["target"] if isinstance(data, Mapping) else data[1] pred = self.model(**input) if isinstance(input, Mapping) else self.model(input) loss = self.criterion(pred, target) if self.metrics is not None: self.metrics.update(pred.squeeze(-1), target) + self.advance(loss) + return loss - if log_interval > 0 and (iteration > 0 and iteration % log_interval == 0 or iteration == length): - interval = iteration - last_print_iteration - if self.device == torch.device("cuda"): - torch.cuda.synchronize() - self.meters.time.update((time() - batch_time) / interval) - batch_time = time() - reduced_loss = self.reduce(loss).item() - self.meters.loss.update(reduced_loss) - self.step_log(split, iteration, length) - last_print_iteration = iteration - - result = self.meters.average() - if self.metrics is not None: - result.merge(self.metrics.average()) - self.write_result(result, split, self.state.epochs) - return result - - @torch.inference_mode() - def inference(self, split: str = "inf") -> list: - r""" - Perform inference on `split`. - - Args: - split (str): split to run inference - - Return: - Tensor: inference outputs - """ - - # pylint: disable=E1102, W0622 - self.mode = "inf" # type: ignore - loader = self.dataloaders[split] - self.meters.reset() - output = [] - for _, data in tqdm(enumerate(loader), total=len(loader)): - input = data["input"] if isinstance(data, Mapping) else data[0] - pred = self.model(**input) if isinstance(input, Mapping) else self.model(input) - output.extend(pred.squeeze(-1).tolist()) - - if self.distributed: - torch.cuda.synchronize() - output = self.gather_for_metrics(output) - return output - - def init_distributed(self) -> None: - r""" - Set up distributed training. - - Initialise process group and set up DDP variables. - """ - - if os.environ.get("ACCELERATE_USE_DEEPSPEED", "false").lower() == "true": - deepspeed_config = self.state.get("deepspeed", os.environ.get("ACCELERATE_DEEPSPEED_CONFIG_FILE")) - self.accelerate["deepspeed_plugin"] = DeepSpeedPlugin(hf_ds_config=self.init_deepspeed(deepspeed_config)) - self.accelerator = Accelerator(**self.accelerate) - if self.distributed: - object_list = [self.id, self.timestamp] - dist.broadcast_object_list(object_list) - self.id, self.timestamp = object_list - - @on_main_process - def init_tensorboard(self, *args, **kwargs) -> None: - r""" - Set up Tensoraoard SummaryWriter. - """ - from torch.utils.tensorboard.writer import SummaryWriter # pylint: disable=C0415 - - if "log_dir" not in kwargs: - kwargs["log_dir"] = self.dir - - self.writer = SummaryWriter(*args, **kwargs) - self.writer.add_scalar = catch(OSError, verbose=False)(self.writer.add_scalar) - - def set_seed(self, seed: int | None = None, bias: int | None = None) -> None: - r""" - Set up random seed. - - Args: - seed: Random seed to set. - Defaults to `self.state.seed` (`config.seed`). - - bias: Make the seed different for each processes. - This is used to ensure the data augmentation are applied differently on every processes. - Defaults to `self.rank`. - Set to `False` to disable this feature. - """ - - seed = seed or self.state.seed - if self.distributed: - object_list = [seed] - dist.broadcast_object_list(object_list) - seed = object_list[0] - bias = bias or self.rank - if bias: - seed += bias - self.state.seed = seed - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - if np_random is not None: - np_random.seed(seed) - random.seed(seed) - - def set_deterministic(self) -> None: - r""" - Set up deterministic. - """ - - cudnn.benchmark = False - cudnn.deterministic = True - if torch.__version__ >= "1.8.0": - torch.use_deterministic_algorithms(True) - - def step(self, loss, batch_size: int | None = None, zero_grad: bool = True) -> None: + def advance(self, loss) -> None: r""" Backward loss and step optimizer & scheduler. - This method increment `self.state.steps`. - - This method also increment `self.state.iters` when `batch_size` is specified. - Args: - zero_grad: Whether to zero the gradients. + loss: The loss tensor from which to backpropagate. """ - self.accelerator.backward(loss) + self.backward(loss) if self.sync_gradients: - if self.state.get("max_grad_value") is not None: - self.clip_grad_value_(self.model.parameters(), self.state.get("max_grad_value")) - if self.state.get("max_grad_norm") is not None: - self.clip_grad_norm_(self.model.parameters(), self.state.get("max_grad_norm")) - if self.optimizer is not None: - self.optimizer.step() - if zero_grad: - self.optimizer.zero_grad() + if self.config.get("max_grad_value") is not None: + self.clip_grad_value_(self.model.parameters(), self.config["max_grad_value"]) + if self.config.get("max_grad_norm") is not None: + self.clip_grad_norm_(self.model.parameters(), self.config["max_grad_norm"]) + self.optimizer.step() if self.scheduler is not None: self.scheduler.step() - self.state.steps += 1 - if batch_size is None: - batch_size = self.batch_size_equivalent - self.state.iters += batch_size - # TODO: Support `drop_last = False` - # self.state.iters += self.batch_size_equivalent - - def state_dict(self, cls: Callable = dict) -> Mapping: - r""" - Return dict of all attributes for checkpoint. - """ - - if self.model is None: - raise ValueError("Model must be defined when calling state_dict") - model = self.accelerator.unwrap_model(self.model) - return cls( - runner=self.state.dict(), - model=model.state_dict(), - optimizer=self.optimizer.state_dict() if self.optimizer else None, - scheduler=self.scheduler.state_dict() if self.scheduler else None, - ) - - def prepare(self, *args: list[Any], device_placement: list[bool] | None = None) -> list[Any]: - r""" - Prepare all objects passed in `args` for distributed training and mixed precision, - then return them in the same order. - """ - - return self.accelerator.prepare(*args, device_placement=device_placement) - - def accumulate(self, model: nn.Module | None = None): - r""" - Context manager that enables gradient accumulate. - """ - - model = model or self.model - return self.accelerator.accumulate(model) - - def autocast(self): - r""" - Context manager that enables auto-casting for the forward pass (and maybe backward pass). - """ - - return self.accelerator.autocast() - - def backward(self, loss) -> None: - r""" - Backward loss to compute gradients. - """ - - return self.accelerator.backward(loss) - - def unwrap_model(self, model: nn.Module | None = None) -> nn.Module: - r""" - Unwrap DDP model. - - Args: - model (Optional[nn.Module]): - Defaults to `self.model`. - """ - - if model is not None: - model = self.model - if self.accelerator is not None: - return self.accelerator.unwrap_model(model) - if self.distributed: - return model.module - return model + if self.ema is not None: + self.ema.update() + self.optimizer.zero_grad() + self.config.steps = self.step + self.config.iters += 1 - @property - def mode(self) -> RunnerMode: - return self._mode - - @mode.setter - def mode(self, mode: str | RunnerMode) -> None: - if isinstance(mode, str): - mode = RunnerMode(mode) - self._mode = mode - if self.model is not None: - self.model.train(mode == RunnerMode.train) + def unwrap(self, model: nn.Module) -> nn.Module: + return self.unwrap_model(model) @property - def batch_size(self) -> int: - r""" - Batch size. + def accelerate(self) -> FlatDict: + if self._accelerate is None: + self._accelerate = self.get_accelerate_config(self.config) + return self._accelerate - Notes: - If `train` is in `dataloaders`, then `batch_size` is the batch size of `train`. - Otherwise, `batch_size` is the batch size of the first dataloader. - - Returns: - (int): - """ - - batch_size = self.state.get("dataloader.batch_size") - if batch_size: - return batch_size - if self.dataloaders: - loader = self.dataloaders.get("train", next(iter(self.dataloaders.values()))) - if loader.batch_size: - return loader.batch_size - batch_sampler = loader.batch_sampler if loader.batch_sampler is not None else loader.sampler - return batch_sampler.batch_size - raise AttributeError("batch_size could not be inferred, since no dataloader found.") + @accelerate.setter + def accelerate(self, config: FlatDict) -> None: + self._accelerate = config @property - def accum_steps(self) -> int: - r""" - Gradient accumulation steps. - - Returns: - (int): - """ + def deepspeed(self) -> dict | None: + if self.state.deepspeed_plugin is not None: + return self.state.deepspeed_plugin.deepspeed_config + return None - return self.accelerator.gradient_accumulation_steps + @contextmanager + def accumulate(self, *models: nn.Module): + if not models: + models = (self.model,) + yield Accelerator.accumulate(self, *models) @property def device(self) -> torch.device: - r""" - Device of runner. - """ - - return self.accelerator.device + return self.state.device @property def world_size(self) -> int: - r""" - Number of Processes. - """ - - return self.accelerator.num_processes + if "state" in self.__dict__: + return self.state.num_processes + return 1 @property def rank(self) -> int: - r""" - Process index in all processes. - """ - - return self.accelerator.process_index + if "state" in self.__dict__: + return self.state.process_index + return 0 @property def local_rank(self) -> int: - r""" - Process index in local processes. - """ - - return self.accelerator.local_process_index - - def gather(self, tensor) -> torch.Tensor: - r""" - Gather tensor. - """ + if "state" in self.__dict__: + return self.state.local_process_index + return 0 - return self.accelerator.gather(tensor) - - def reduce(self, tensor, reduction: str = "sum") -> torch.Tensor: - r""" - Reduce tensor. - """ - - return self.accelerator.reduce(tensor, reduction=reduction) - - def __getattr__(self, name: str) -> Any: - with suppress(AttributeError): - return super().__getattr__(name) - if "accelerator" in self.__dict__ and hasattr(self.accelerator, name): - return getattr(self.accelerator, name) - raise super().__getattribute__(name) + @cached_property + def accum_steps(self) -> int: + return self.gradient_accumulation_steps + + def get_accelerate_config(self, config) -> FlatDict: + accelerate = FlatDict() + if "accelerate" in config: + accelerate.update(config.accelerate) + if "precision" in config: + accelerate.mixed_precision = config.precision + if "dynamo" in config: + accelerate.dynamo_backend = config.dynamo.upper() + if "accum_steps" in config: + accelerate.gradient_accumulation_steps = config.accum_steps + if "kwargs_handlers" not in accelerate: + accelerate.kwargs_handlers = [] + # Must NOT set project_dir here as timestamp is not synced yet + # config.project_dir = self.dir + if os.getenv("ACCELERATE_USE_DEEPSPEED", "false").lower() == "true": + deepspeed_config = config.get("deepspeed", os.getenv("ACCELERATE_DEEPSPEED_CONFIG_FILE")) + accelerate.deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=self.get_deepspeed_config(deepspeed_config)) + return accelerate + + def build_dataloaders(self): + datasets = {k: d for k, d in self.datasets.items() if k not in self.dataloaders} + default_kwargs = self.config.setdefault("dataloader", NestedDict()) + dataloader_kwargs = NestedDict({k: default_kwargs.pop(k) for k in self.datasets if k in default_kwargs}) + for k, d in datasets.items(): + dataloader_kwargs.setdefault(k, NestedDict()) + dataloader_kwargs[k].merge(default_kwargs, overwrite=False) + dataloader_kwargs[k].setdefault("shuffle", getattr(d, "train", True)) + dataloader_kwargs[k].setdefault("drop_last", not getattr(d, "train", True)) + self.dataloaders[k] = utils.data.DataLoader(d, collate_fn=self.collate_fn, **dataloader_kwargs[k]) + default_kwargs.update(dataloader_kwargs) + for k, d in self.dataloaders.items(): + self.dataloaders[k] = self.prepare(d) diff --git a/danling/runner/base_runner.py b/danling/runner/base_runner.py index 38878256..a8ad8e3d 100644 --- a/danling/runner/base_runner.py +++ b/danling/runner/base_runner.py @@ -26,11 +26,11 @@ from collections.abc import Callable, Mapping, Sequence from math import ceil from sys import version_info -from typing import TYPE_CHECKING, Any, Dict +from typing import TYPE_CHECKING, Any from uuid import UUID, uuid5 from warnings import warn -from chanfig import Config, FlatDict, NestedDict, Variable +from chanfig import FlatDict, NestedDict, Variable from danling.metrics import AverageMeter, AverageMeters, MetricMeters from danling.typing import File, PathStr @@ -49,11 +49,11 @@ except ImportError: np_random = None -from .state import RunnerState -from .utils import RunnerMeta, RunnerMode, get_time_str, on_main_process +from .config import Config +from .utils import RunnerMeta, RunnerMode, format_result, get_time_str, on_main_process PY38_PLUS = version_info >= (3, 8) -IGNORED_SET_NAMES = ("index", "epoch", "step", "iter") +IGNORED_SET_NAMES = ("index", "epochs", "steps", "iters") __APPEND_RESULT_COUNTER__ = 0 @@ -63,19 +63,19 @@ class BaseRunner(metaclass=RunnerMeta): # pylint: disable=too-many-public-metho `BaseRunner` sets up basic running environment, including `seed`, `deterministic`, and `logging`. - `BaseRunner` also provides some basic methods, such as, `step`, `state_dict`, `save_checkpoint`, `load_checkpoint`. + `BaseRunner` also provides some basic methods, such as, `steps`, `state_dict`, `save_checkpoint`, `load_checkpoint`. `BaseRunner` defines all basic attributes and relevant properties such as `scores`, `progress`, etc. Attributes: ID: timestamp (str): A time string representing the creation time of run. - name (str): `f"{self.state.experiment_name}-{self.state.run_name}"`. - id (str): `f"{self.state.experiment_id:.8}{self.state.run_id:.8}"`. - uuid (UUID, property): `uuid5(self.state.run_id, self.id)`. + name (str): `f"{self.config.experiment_name}-{self.config.run_name}"`. + id (str): `f"{self.config.experiment_id:.8}{self.config.run_id:.8}"`. + uuid (UUID, property): `uuid5(self.config.run_id, self.id)`. Attributes: Core: mode (RunnerMode, property): Running mode. - state (RunnerState): Running state. See `RunnerState` for details. + config (Config): Running config. See [`Config`] for details. Attributes: Model: model (Callable): @@ -162,7 +162,7 @@ class BaseRunner(metaclass=RunnerMeta): # pylint: disable=too-many-public-metho writer: See Also: - [`RunnerState`][danling.runner.runner_state.RunnerState]: The runeer base that stores runtime information. + [`Config`][danling.runner.Config]: The runeer base that stores runtime information. [`BaseRunner`][danling.runner.BaseRunner]: The base runner class. """ @@ -171,10 +171,11 @@ class BaseRunner(metaclass=RunnerMeta): # pylint: disable=too-many-public-metho timestamp: str _mode: RunnerMode - _state: RunnerState + _config: Config inited: bool = False model: Callable | None = None + ema: Callable | None = None criterion: Callable | None = None optimizer: Any | None = None scheduler: Any | None = None @@ -182,15 +183,17 @@ class BaseRunner(metaclass=RunnerMeta): # pylint: disable=too-many-public-metho datasets: FlatDict datasamplers: FlatDict dataloaders: FlatDict - split: str + split: str | None = None results: NestedDict meters: AverageMeters metrics: Metrics | MetricMeters | None = None + train_metrics: Metrics | MetricMeters | None = None + evaluate_metrics: Metrics | MetricMeters | None = None logger: logging.Logger | None = None writer: Any | None = None - def __init__(self, config: NestedDict) -> None: + def __init__(self, config: Config) -> None: self.timestamp = get_time_str() if "datasets" not in self.__dict__: self.datasets = FlatDict() @@ -202,18 +205,24 @@ def __init__(self, config: NestedDict) -> None: self.results = NestedDict() self.meters = AverageMeters() self._mode = RunnerMode.train # type: ignore[assignment] - # must init state at last to avoid name conflicts - self._state = RunnerState(config) - self.inited = True + # must init config at last to avoid name conflicts + if not isinstance(config, Config): + config = Config(config) + self._config = config self.init_distributed() - if self.state.seed is not None: + self.inited = True + if "checkpoint" in config: + self.load_config(config["checkpoint"]) + + def __post_init__(self): + if self.config.seed is not None: self.set_seed() - if self.state.deterministic: + if self.config.deterministic: self.set_deterministic() - if self.state.log: + if self.config.log: self.init_logging() self.init_print() - if self.state.tensorboard: + if self.config.tensorboard: self.init_tensorboard() def init_distributed(self) -> None: @@ -221,88 +230,6 @@ def init_distributed(self) -> None: Initialise distributed running environment. """ - raise NotImplementedError - - def init_deepspeed( # pylint: disable=too-many-branches, too-many-statements - self, config: Dict = None # type: ignore - ) -> Dict: - r""" - Preprocess DeepSpeed config. - """ - - if config is None: - config = self.state.get("deepspeed") - if config is None: - return {} - if isinstance(config, str): - config = NestedDict.load(config) - if config.get("steps_per_print", "auto") == "auto": - config["steps_per_print"] = self.state.log_interval - if config.get("train_micro_batch_size_per_gpu", "auto") == "auto": - config["train_micro_batch_size_per_gpu"] = self.batch_size - if "amp" in config: - amp = config["amp"] - if amp.get("enabled", "auto") == "auto": - amp["enabled"] = "true" - if amp.get("opt_level", "auto") == "auto": - amp["opt_level"] = "O1" - if "zero_optimization" in config: - zero = config["zero_optimization"] - if zero.get("allgather_bucket_size") == "auto": - zero["allgather_bucket_size"] = 1e6 - if zero.get("reduce_bucket_size") == "auto": - zero["reduce_bucket_size"] = 1e6 - if zero.get("stage3_max_live_parameters") == "auto": - zero["stage3_max_live_parameters"] = 1e8 - if zero.get("stage3_max_live_gradients") == "auto": - zero["stage3_max_live_gradients"] = 1e8 - if zero.get("stage3_max_reuse_distance") == "auto": - zero["stage3_max_reuse_distance"] = 1e8 - if zero.get("stage3_prefetch_bucket_size") == "auto": - zero["stage3_prefetch_bucket_size"] = 1e6 - if zero.get("stage3_param_persistence_threshold") == "auto": - zero["stage3_param_persistence_threshold"] = 1e8 - if "amp" in config: - if "fp16" not in config: - config["fp16"] = {} - if config["fp16"].get("enabled", "auto"): - config["fp16"]["enabled"] = config["amp"]["enabled"] - warn( - f"AMP is not compatible with ZeRO. Automatically set 'fp16' to {config['amp']['enabled']}", - stacklevel=2, - ) - del config["amp"] - if "optimizer" in config: - if "params" not in config["optimizer"]: - config["optimizer"]["params"] = {} - optimizer = config["optimizer"]["params"] - if optimizer.get("lr", "auto") == "auto": - optimizer["lr"] = self.state.get("optim.lr", 1e-3) - if optimizer.get("weight_decay", "auto") == "auto": - optimizer["weight_decay"] = self.state.get("optim.weight_decay", 1e-2) - if optimizer.get("betas") == "auto": - optimizer["betas"] = (0.9, 0.999) - if optimizer.get("eps") == "auto": - optimizer["eps"] = 1e-8 - if "scheduler" in config: - if "params" not in config["scheduler"]: - config["scheduler"]["params"] = {} - scheduler = config["scheduler"]["params"] - if scheduler.get("total_num_steps", "auto") == "auto": - scheduler["total_num_steps"] = self.total_steps - if scheduler.get("warmup_num_steps", "auto") == "auto": - scheduler["warmup_num_steps"] = scheduler["total_num_steps"] // 20 - if scheduler.get("warmup_max_lr", "auto") == "auto": - if self.optimizer: - scheduler["warmup_max_lr"] = self.optimizer.param_groups[0]["lr"] - elif "optimizer" in config: - scheduler["warmup_max_lr"] = config["optimizer"]["params"]["lr"] - else: - raise ValueError("warmup_max_lr is not defined and cannot be inferred") - if scheduler.get("warmup_min_lr", "auto") == "auto": - scheduler["warmup_min_lr"] = 1e-7 - return config - @on_main_process def init_logging(self) -> None: r""" @@ -357,7 +284,7 @@ def init_print(self, process: int = 0) -> None: Notes ----- - If `self.state.log = True`, the default `print` function will be override by `logging.info`. + If `self.config.log = True`, the default `print` function will be override by `logging.info`. """ logger = logging.getLogger("print") @@ -369,7 +296,7 @@ def init_print(self, process: int = 0) -> None: @catch def print(*args, force=False, end="\n", file=None, flush=False, **kwargs): # pylint: disable=redefined-builtin if self.rank == process or force: - if self.state.log: + if self.config.log: if not args: args = [""] logger.info(*args, **kwargs) @@ -385,13 +312,13 @@ def init_tensorboard(self, *args, **kwargs) -> None: """ raise NotImplementedError - def set_seed(self, seed: int | None = None, bias: int | None = None) -> None: + def set_seed(self, seed: int = None, bias: int = None) -> int: # type: ignore[assignment] r""" Set up random seed. Args: seed: Random seed to set. - Defaults to `self.state.seed` (`config.seed`). + Defaults to `self.config.seed` (`config.seed`). bias: Make the seed different for each processes. @@ -400,15 +327,18 @@ def set_seed(self, seed: int | None = None, bias: int | None = None) -> None: Defaults to `self.rank`. Set to `False` to disable this feature. + Returns: + Random seed set. """ - seed = seed or self.state.seed + seed = seed or self.config.seed # type: ignore[assignment] bias = bias or self.rank if bias: seed += bias if np_random is not None: np_random.seed(seed) random.seed(seed) + return seed def set_deterministic(self) -> None: r""" @@ -427,8 +357,8 @@ def scale_lr( Scale learning rate according to [linear scaling rule](https://arxiv.org/abs/1706.02677). """ - if lr_scale_factor in self.state: - lr_scale_factor = self.state.lr_scale_factor + if lr_scale_factor in self.config: + lr_scale_factor = self.config.lr_scale_factor if lr_scale_factor is None: if batch_size_base is None: @@ -441,19 +371,15 @@ def scale_lr( "batch_size_base will be ignored if lr_scale_factor is specified", category=RuntimeWarning, stacklevel=2 ) lr = lr * lr_scale_factor - self.state.lr_scale_factor = lr_scale_factor + self.config.lr_scale_factor = lr_scale_factor return lr - def step(self, loss, batch_size: int | None = None, zero_grad: bool = True) -> None: + def advance(self, loss, *args, **kwargs) -> None: r""" Backward loss and step optimizer & scheduler. - This method increment `self.state.steps`. - - This method also increment `self.state.iters` when `batch_size` is specified. - Args: - zero_grad: Whether to zero the gradients. + loss: loss. """ raise NotImplementedError @@ -463,17 +389,17 @@ def state_dict(self, cls: Callable = dict) -> Mapping: Return dict of all attributes for checkpoint. """ - return cls(self.state) + return cls(self.config) def dict(self, cls: Callable = dict) -> Mapping: r""" - Convert state to Mapping. + Convert config to Mapping. Args: cls: Target `clc to convert to. """ - return self.state.dict(cls) + return self.config.dict(cls) @catch def save(self, obj: Any, file: PathStr, main_process_only: bool = True, *args, **kwargs) -> File: @@ -502,11 +428,11 @@ def load(file: PathStr, *args, **kwargs) -> Any: @catch def json(self, file: File, main_process_only: bool = True, *args, **kwargs) -> None: # pylint: disable=R1710 r""" - Dump Runner State to json file. + Dump Runner config to json file. """ if main_process_only and self.is_main_process or not main_process_only: - return self.state.json(file, *args, **kwargs) + return self.config.json(file, *args, **kwargs) @classmethod def from_json(cls, file: File, *args, **kwargs) -> BaseRunner: @@ -522,10 +448,10 @@ def from_json(cls, file: File, *args, **kwargs) -> BaseRunner: def jsons(self, *args, **kwargs) -> str: r""" - Dump Runner State to json string. + Dump Runner config to json string. """ - return self.state.jsons(*args, **kwargs) + return self.config.jsons(*args, **kwargs) @classmethod def from_jsons(cls, string: str, *args, **kwargs) -> BaseRunner: @@ -538,11 +464,11 @@ def from_jsons(cls, string: str, *args, **kwargs) -> BaseRunner: @catch def yaml(self, file: File, main_process_only: bool = True, *args, **kwargs) -> None: # pylint: disable=R1710 r""" - Dump Runner State to yaml file. + Dump Runner config to yaml file. """ if main_process_only and self.is_main_process or not main_process_only: - return self.state.yaml(file, *args, **kwargs) + return self.config.yaml(file, *args, **kwargs) @classmethod def from_yaml(cls, file: File, *args, **kwargs) -> BaseRunner: @@ -558,10 +484,10 @@ def from_yaml(cls, file: File, *args, **kwargs) -> BaseRunner: def yamls(self, *args, **kwargs) -> str: r""" - Dump Runner State to yaml string. + Dump Runner config to yaml string. """ - return self.state.yamls(*args, **kwargs) + return self.config.yamls(*args, **kwargs) @classmethod def from_yamls(cls, string: str, *args, **kwargs) -> BaseRunner: @@ -596,102 +522,109 @@ def check_dir(self, action: str = "warn") -> bool: @catch @on_main_process - def save_checkpoint(self, epochs: int | None = None) -> None: + def save_checkpoint(self, name: str = "latest", epochs: int | None = None, save_best: bool = True) -> None: r""" Save checkpoint to `self.checkpoint_dir`. - The checkpoint will be saved to `self.checkpoint_dir/latest.pth`. - - If `self.state.save_interval` is positive and `self.state.epochs + 1` is a multiple of `save_interval`, - the checkpoint will also be copied to `self.checkpoint_dir/epoch-{self.state.epochs}.pth`. + Args: + name: Name of the checkpoint. Defaults to `"latest"`. + epoch: Epoch to save. Defaults to `self.config.epochs`. + save_best: If `True`, when `self.is_best` is `True`, the checkpoint will also be copied to + `self.checkpoint_dir/best`. - If `self.is_best` is `True`, the checkpoint will also be copied to `self.checkpoint_dir/best.pth`. + If `self.config.save_interval` is positive and `epochs + 1` is a multiple of `save_interval`, + the checkpoint will also be copied to `self.checkpoint_dir/epoch-{epochs}.pth`. """ - epochs = epochs or self.state.epochs - save_interval = self.state.get("save_interval", -1) - latest_path = os.path.join(self.checkpoint_dir, "latest.pth") + epochs = epochs or self.config.epochs + save_interval = self.config.get("save_interval", -1) + latest_path = os.path.join(self.checkpoint_dir, f"{name}.pth") self.save(self.state_dict(), latest_path) if save_interval > 0 and (epochs + 1) % save_interval == 0: save_path = os.path.join(self.checkpoint_dir, f"epoch-{epochs}.pth") shutil.copy(latest_path, save_path) - if self.is_best: + if save_best and self.is_best: best_path = os.path.join(self.checkpoint_dir, "best.pth") shutil.copy(latest_path, best_path) - def load_checkpoint( - self, - checkpoint: Mapping | bytes | str | os.PathLike | None = None, - auto_resume: bool | None = None, - override_state: bool = False, - *args, - **kwargs, + def load_config( + self, checkpoint: Mapping | bytes | str | os.PathLike, overwrite: bool = False, *args, **kwargs ) -> None: - """ - Load info from checkpoint. + r""" + Load config from checkpoint. Args: checkpoint: Checkpoint (or its path) to load. - Defaults to `self.state.checkpoint`. - auto_resume: Automatically resume from latest checkpoint if exists. - Defaults to `False`. - If is `True` and `checkpoint` is None, will set it to `self.checkpoint_dir/latest.pth`. - override_state: If True, override runner state with checkpoint state. + overwrite: If `True`, overwrite the current config with the loaded config. Defaults to `False`. *args: Additional arguments to pass to `self.load`. **kwargs: Additional keyword arguments to pass to `self.load`. Raises: FileNotFoundError: If `checkpoint` does not exists. + """ + + if isinstance(checkpoint, (bytes, str, os.PathLike)): + if not os.path.exists(checkpoint): + raise FileNotFoundError(f"checkpoint is set to {checkpoint!r} but does not exist") + config = self.load(checkpoint, *args, **kwargs) + elif isinstance(checkpoint, Mapping): + config = checkpoint + else: + raise ValueError(f"checkpoint is set to {checkpoint!r} but is not a valid checkpoint") + + config = config.get("runner", config) + self.config.merge(config, overwrite=overwrite) + self.config.iter_begin = config["iters"] + 1 + self.config.step_begin = config["steps"] + 1 + self.config.epoch_begin = config["epochs"] + 1 + + def load_checkpoint(self, checkpoint: Mapping | bytes | str | os.PathLike, *args, **kwargs) -> None: + """ + Load model, optimizer, and scheduler from checkpoint. + + Args: + checkpoint: Checkpoint (or its path) to load. + *args: Additional arguments to pass to `self.load`. + **kwargs: Additional keyword arguments to pass to `self.load`. + + Raises: + ValueError: If `model` is not defined. + ValueError: If `checkpoint` is not a valid checkpoint. + FileNotFoundError: If `checkpoint` does not exists. See Also: [`from_checkpoint`][danling.BaseRunner.from_checkpoint]: Build runner from checkpoint. - [`load_pretrained`][danling.BaseRunner.load_pretrained]: Load parameters from pretrained checkpoint. + [`load_pretrained`][danling.BaseRunner.load_pretrained]: Load model parameters from pretrained checkpoint. """ - checkpoint = checkpoint if checkpoint is not None else self.state.get("checkpoint") - auto_resume = auto_resume if auto_resume is not None else self.state.get("auto_resume", False) - - # TODO: Support loading checkpoints in other format - if checkpoint is not None: - if auto_resume: - warn( - "latest checkpoint is preempted by value specified in checkpoint", - RuntimeWarning, - stacklevel=2, - ) - if isinstance(checkpoint, (bytes, str, os.PathLike)): - if not os.path.exists(checkpoint): - raise FileNotFoundError(f"checkpoint is set to {checkpoint!r} but does not exist") - self.state.checkpoint = checkpoint - ckpt = self.load(checkpoint, *args, **kwargs) - elif isinstance(checkpoint, Mapping): - ckpt = checkpoint + if self.model is None: + raise ValueError("model is not defined") + if isinstance(checkpoint, (bytes, str, os.PathLike)): + if not os.path.exists(checkpoint): + raise FileNotFoundError(f"checkpoint is set to {checkpoint!r} but does not exist") + ckpt = self.load(checkpoint, *args, **kwargs) + elif isinstance(checkpoint, Mapping): + ckpt = checkpoint + else: + raise ValueError(f"checkpoint is set to {checkpoint!r} but is not a valid checkpoint") + + state_dict = ckpt + while "model" in state_dict or "module" in state_dict: + state_dict = state_dict.get("model", state_dict) + state_dict = state_dict.get("module", state_dict) + self.unwrap(self.model).load_state_dict(state_dict) + if self.optimizer is not None: + if "optimizer" in ckpt: + self.optimizer.load_state_dict(ckpt["optimizer"]) else: - raise ValueError(f"pretrained is set to {checkpoint!r} but is not a valid checkpoint") - elif auto_resume: - checkpoint = os.path.join(self.checkpoint_dir, "latest.pth") - if os.path.exists(checkpoint): - self.state.checkpoint = checkpoint - ckpt = self.load(checkpoint, *args, **kwargs) + warn("optimizer is not in checkpoint", category=RuntimeWarning, stacklevel=2) + if self.scheduler is not None: + if "scheduler" in ckpt: + self.scheduler.load_state_dict(ckpt["scheduler"]) else: - warn("latest checkpoint does not exits", category=RuntimeWarning, stacklevel=2) - return - else: - raise ValueError("checkpoint is not specified and auto_resume is not set to True") - - # TODO: Wrap state_dict in a dataclass - self.state.merge(ckpt["runner"], overwrite=override_state) - if self.model is not None and "model" in ckpt: - model = self.unwrap_model(self.model) - model.load_state_dict(ckpt["model"]) - if self.optimizer is not None and "optimizer" in ckpt: - self.optimizer.load_state_dict(ckpt["optimizer"]) - if self.scheduler is not None and "scheduler" in ckpt: - self.scheduler.load_state_dict(ckpt["scheduler"]) - self.state.iter_begin = self.state.iters + 1 - self.state.step_begin = self.state.steps + 1 - self.state.epoch_begin = self.state.epochs + 1 + warn("scheduler is not in checkpoint", category=RuntimeWarning, stacklevel=2) + self.config.checkpoint = checkpoint @classmethod def from_checkpoint(cls, checkpoint: Mapping | bytes | str | os.PathLike, *args, **kwargs) -> BaseRunner: @@ -713,33 +646,33 @@ def from_checkpoint(cls, checkpoint: Mapping | bytes | str | os.PathLike, *args, ckpt = checkpoint else: raise ValueError(f"checkpoint is set to {checkpoint} but is not a valid checkpoint") - runner = cls(**ckpt["runner"]) - runner.load_checkpoint(ckpt, override_state=False) + runner = cls(ckpt["runner"]) + runner.load_checkpoint(ckpt, override_config=False) return runner - def load_pretrained(self, checkpoint: Mapping | bytes | str | os.PathLike | None = None, *args, **kwargs) -> None: + def load_pretrained(self, checkpoint: Mapping | bytes | str | os.PathLike, *args, **kwargs) -> None: """ - Load parameters from pretrained checkpoint. + Load model from pretrained checkpoint. This method only loads the model weights. Args: checkpoint: Pretrained checkpoint (or its path) to load. - Defaults to `self.state.pretrained`. *args: Additional arguments to pass to `self.load`. **kwargs: Additional keyword arguments to pass to `self.load`. Raises: + ValueError: If `model` is not defined. + ValueError: If `checkpoint` is not a valid checkpoint. FileNotFoundError: If `checkpoint` does not exists. See Also: - [`load_checkpoint`][danling.BaseRunner.load_checkpoint]: Load info from checkpoint. + [`load_checkpoint`][danling.BaseRunner.load_checkpoint]: Load model, optimizer, and scheduler from + checkpoint. """ - # TODO: Support loading checkpoints in other format - checkpoint = checkpoint if checkpoint is not None else self.state.get("pretrained") - if checkpoint is None: - raise ValueError("pretrained is not specified") + if self.model is None: + raise ValueError("model is not defined") if isinstance(checkpoint, (bytes, str, os.PathLike)): if not os.path.exists(checkpoint): raise FileNotFoundError(f"pretrained is set to {checkpoint!r} but does not exist") @@ -748,11 +681,32 @@ def load_pretrained(self, checkpoint: Mapping | bytes | str | os.PathLike | None ckpt = checkpoint else: raise ValueError(f"pretrained is set to {checkpoint!r} but is not a valid checkpoint") - if self.model is not None and "model" in ckpt: - model = self.unwrap_model(self.model) - model.load_state_dict(ckpt["model"]) - else: - raise ValueError(f"Unable to find model weights in {checkpoint!r}") + + state_dict = ckpt + while "model" in state_dict or "module" in state_dict: + state_dict = state_dict.get("model", state_dict) + state_dict = state_dict.get("module", state_dict) + self.unwrap(self.model).load_state_dict(state_dict) + + def get_step_result(self) -> NestedDict: + result = self.meters.value() + if self.metrics is not None: + return self._merge_result(result, self.metrics.value()) + return result + + def get_epoch_result(self) -> NestedDict: + result = self.meters.average() + if self.metrics is not None: + return self._merge_result(result, self.metrics.average()) + return result + + def _merge_result(self, result, metric_result) -> NestedDict: + for key, value in metric_result.items(): + if isinstance(value, (Mapping)) and len(value) == 1: + value = next(iter(value.values())) + metric_result[key] = value + result.update(metric_result) + return result def append_result(self, result: NestedDict, index: int | None = None) -> None: r""" @@ -765,14 +719,14 @@ def append_result(self, result: NestedDict, index: int | None = None) -> None: """ if index is None: - index = self.state.epochs + index = self.config.epochs global __APPEND_RESULT_COUNTER__ # pylint: disable=global-statement __APPEND_RESULT_COUNTER__ += 1 if index == 0 and __APPEND_RESULT_COUNTER__ > 1: warn( """ - Automatically set index to `self.state.epochs`. - Please ensure `self.state.epochs` updates before calling `append_result` + Automatically set index to `self.config.epochs`. + Please ensure `self.config.epochs` updates before calling `append_result` """, category=RuntimeWarning, stacklevel=2, @@ -793,9 +747,7 @@ def print_result(self) -> None: def step_log(self, split: str, iteration: int, length: int | None = None): if length is None: length = len(self.dataloaders[split]) - 1 - result = self.meters.val - if self.metrics is not None: - result.merge(self.metrics.val) + result = self.get_step_result() print(self.format_step_result(result, split, iteration, length)) if self.mode == "train": self.write_result(result, split) @@ -804,7 +756,6 @@ def step_log(self, split: str, iteration: int, length: int | None = None): def format_step_result( self, result: NestedDict, split: str, steps: int, length: int, format_spec: str = ".4f" ) -> str: - result = NestedDict(result).clone() repr_str = "" if split is not None: if self.mode == "train": @@ -819,15 +770,13 @@ def format_step_result( def format_epoch_result( self, result: NestedDict, epochs: int | None = None, epoch_end: int | None = None, format_spec: str = ".4f" ) -> str: - result = NestedDict(result).clone() - epochs = epochs or self.state.epochs - epoch_end = epoch_end or self.state.epoch_end - repr_str = f"epoch [{epochs}/{epoch_end - 1}]\n" if epochs is not None and epoch_end else "" - repr_str += "\n".join([f"{k}:\t{self.format_result(v, format_spec=format_spec)}" for k, v in result.items()]) - return repr_str + epochs = epochs or self.config.epochs + epoch_end = epoch_end or self.config.epoch_end + repr_str = f"epoch [{epochs}/{epoch_end - 1}]" if epochs is not None and epoch_end else "" + return repr_str + self.format_result(result, format_spec=format_spec) - def format_result(self, result, format_spec: str = ".4f") -> str: - return "\t".join([f"{k}: {format(v, format_spec)}" for k, v in result.items()]) + def format_result(self, result: Mapping, format_spec: str = ".4f") -> str: + return format_result(result, format_spec=format_spec) def write_result(self, result: NestedDict, split: str, steps: int | None = None): if steps is None: @@ -883,20 +832,23 @@ def save_result(self) -> None: best_path = os.path.join(self.dir, "best.json") shutil.copy(latest_path, best_path) + def unwrap(self, model: Any) -> Any: + return model + @cached_property def name(self): - if "name" in self._state: - return self.state["name"] - return f"{self.state.experiment_name}-{self.state.run_name}" + if "name" in self.config: + return self.config["name"] + return f"{self.config.experiment_name}-{self.config.run_name}" @cached_property def id(self): - return f"{self.state.experiment_id:.8}{self.state.run_id:.8}" + return f"{self.config.experiment_id:.8}{self.config.run_id:.8}" @cached_property def uuid(self) -> UUID: r""" - UUID of the state. + UUID of the config. """ return uuid5(self.run_uuid, self.id) @@ -912,8 +864,8 @@ def mode(self, mode: str | RunnerMode) -> None: self._mode = mode @property - def state(self) -> RunnerState: - return self._state + def config(self) -> Config: + return self._config @property def batch_size(self) -> int: @@ -930,7 +882,7 @@ def batch_size(self) -> int: if self.dataloaders and self.split: return self.dataloaders[self.split].batch_size - batch_size = self.state.get("dataloader.batch_size") + batch_size = self.config.get("dataloader.batch_size") if batch_size: return batch_size raise AttributeError("batch_size could not be inferred and is not in config") @@ -947,32 +899,78 @@ def batch_size_equivalent(self) -> int: return self.batch_size * self.world_size * self.accum_steps @cached_property - def total_epochs(self) -> int: - if self.state.epoch_end: - return self.state.epoch_end - self.state.epoch_begin + 1 - raise ValueError("epoch_end is not specified") + def total_iters(self) -> int: + r""" + Number of training iters. + + An iter is defined by model forward and backward. + + Returns: + (int): + + See Also: + [`total_iters`][]: Number of training iters. + [`total_steps`][]: Number of training steps. + """ + if self.config.iter_end: + return self.config.iter_end - self.config.iter_begin + if "train" not in self.datasets: + return 0 + return self.total_epochs * ceil(len(self.datasets["train"]) / self.batch_size / self.world_size) @cached_property def total_steps(self) -> int: - if self.state.step_end: - return self.state.step_end - self.state.step_begin - dataset = self.datasets.get("train", next(iter(self.datasets.values()))) - return self.total_epochs * ceil(len(dataset) / self.batch_size / self.world_size) + r""" + Number of training steps. + + A step is defined by optimizer update. + + `total_steps` is equivalent to `total_iters` divided by `accum_steps`. + + Returns: + (int): + + See Also: + [`total_iters`][]: Number of training iters. + [`total_steps`][]: Number of training steps. + [`total_epochs`][]: Number of training epochs. + """ + if self.config.step_end: + return self.config.step_end - self.config.step_begin + return ceil(self.total_iters / self.accum_steps) @cached_property - def trainable_steps(self) -> int: - return ceil(self.total_steps / self.accum_steps) + def total_epochs(self) -> int: + r""" + Number of training epochs. + + An epoch is defined by one pass of the dataset. + + Returns: + (int): + + See Also: + [`total_iters`][]: Number of training iters. + [`total_steps`][]: Number of training steps. + """ + if self.config.epoch_end: + return self.config.epoch_end - self.config.epoch_begin + raise ValueError("epoch_end is not specified") @cached_property def accum_steps(self) -> int: r""" - Accumulated steps. + Number of steps to accumulate gradients. Returns: (int): + + See Also: + [`total_iters`][]: Number of training iters. + [`total_steps`][]: Number of training steps. """ - return self.state.get("accum_steps", 1) + return self.config.get("accum_steps", 1) @property def progress(self) -> float: @@ -986,15 +984,7 @@ def progress(self) -> float: RuntimeError: If no terminal is defined. """ - return self.steps / self.total_steps - - @property - def device(self) -> Any: - r""" - Device of runner. - """ - - return "cpu" + return self.config.steps / self.total_steps @property def world_size(self) -> int: @@ -1049,7 +1039,7 @@ def best_fn(self) -> Callable: r""" Function to determine the best score from a list of scores. - By default, the `best_fn` returns `min` if `self.state.score_name` is `loss`, + By default, the `best_fn` returns `min` if `self.config.score_name` is `loss`, otherwise, returns `max`. Subclass can override this method to accommodate needs, such as `min`. @@ -1058,7 +1048,7 @@ def best_fn(self) -> Callable: (callable): """ - return max if self.state.score_name != "loss" else min + return max if self.config.score_name != "loss" else min @property def best_index(self) -> int: @@ -1104,13 +1094,13 @@ def scores(self) -> FlatDict | None: r""" All scores. - Scores are extracted from results by `score_split` and `runner.state.score_name`, - following `[r[score_split][self.state.score_name] for r in self.results]`. + Scores are extracted from results by `score_split` and `runner.config.score_name`, + following `[r[score_split][self.config.score_name] for r in self.results]`. Scores are considered as the index of the performance of the model. It is useful to determine the best model and the best hyper-parameters. - `score_split` is defined in `self.state.score_split`. + `score_split` is defined in `self.config.score_split`. If it is not set, `DanLing` will use `val` or `validate` if they appear in the `latest_result`. If `DanLing` still could not find, it will fall back to the second key in the `latest_result` if it contains more that one element, or the first key. @@ -1121,14 +1111,14 @@ def scores(self) -> FlatDict | None: if not self.results: return None subsets = [i for i in self.latest_result.keys() if i not in IGNORED_SET_NAMES] # type: ignore - score_split = self.state.get("score_split") + score_split = self.config.get("score_split") if score_split is None and "val" in subsets: score_split = "val" if score_split is None and "validate" in subsets: score_split = "validate" if score_split is None: score_split = subsets[1] if len(subsets) > 1 else subsets[0] - return FlatDict({k: v[score_split][self.state.score_name] for k, v in self.results.items()}) + return FlatDict({k: v[score_split][self.config.score_name] for k, v in self.results.items()}) @property def latest_score(self) -> float | None: @@ -1172,8 +1162,8 @@ def dir(self) -> str: Directory of the run. """ - if "dir" in self.state: - return self.state.dir + if "dir" in self.config: + return self.config.dir return os.path.join(self.project_root, f"{self.name}-{self.id}", self.timestamp) @cached_property @@ -1182,8 +1172,8 @@ def log_path(self) -> str: Path of log file. """ - if "log_path" in self.state: - return self.state.log_path + if "log_path" in self.config: + return self.config.log_path return os.path.join(self.dir, "run.log") @property @@ -1193,9 +1183,9 @@ def checkpoint_dir(self) -> str: Directory of checkpoints. """ - if "checkpoint_dir" in self.state: - return self.state.checkpoint_dir - return os.path.join(self.dir, self.state.checkpoint_dir_name) + if "checkpoint_dir" in self.config: + return self.config.checkpoint_dir + return os.path.join(self.dir, self.config.checkpoint_dir_name) # def __getattribute__(self, name) -> Any: # if name in ("__class__", "__dict__"): @@ -1204,16 +1194,16 @@ def checkpoint_dir(self) -> str: # return self.__dict__[name] # if name in dir(self): # return super().__getattribute__(name) - # if "state" in self and name in self.state: - # return self.state[name] + # if "config" in self and name in self.config: + # return self.config[name] # return super().__getattribute__(name) def __getattr__(self, name) -> Any: if self.inited: - if name in self._state: - return self.state[name] - if name in dir(self.state): - return getattr(self.state, name) + if name in self.config: + return self.config[name] + if name in dir(self.config): + return getattr(self.config, name) return super().__getattribute__(name) def __setattr__(self, name, value) -> None: @@ -1230,19 +1220,19 @@ def __setattr__(self, name, value) -> None: object.__setattr__(self, name, value) return if self.inited: - if name in self.state: - if isinstance(self.state[name], Variable): - self.state[name].set(value) + if name in self.config: + if isinstance(self.config[name], Variable): + self.config[name].set(value) else: - self.state[name] = value + self.config[name] = value return - if name in dir(self.state): - setattr(self.state, name, value) + if name in dir(self.config): + setattr(self.config, name, value) return object.__setattr__(self, name, value) def __contains__(self, name) -> bool: - return name in dir(self) or ("state" in self.__dict__ and name in dir(self.state)) + return name in dir(self) or ("config" in self.__dict__ and name in dir(self.config)) def __repr__(self): lines = [] diff --git a/danling/runner/state.py b/danling/runner/config.py similarity index 71% rename from danling/runner/state.py rename to danling/runner/config.py index c59bfc44..f942bf52 100755 --- a/danling/runner/state.py +++ b/danling/runner/config.py @@ -17,27 +17,27 @@ from __future__ import annotations -from random import randint from typing import Optional from uuid import UUID, uuid5 -from chanfig import NestedDict +import chanfig + +from danling import defaults -from . import defaults from .utils import get_git_hash -class RunnerState(NestedDict): # pylint: disable=too-many-instance-attributes +class Config(chanfig.Config): # pylint: disable=too-many-instance-attributes r""" - `RunnerState` is a `NestedDict` that contains all states of a `Runner`. + `Config` is a [`Config`][chanfig.Config] that contains all states of a `Runner`. - `RunnerState` is designed to store all critical information of a Run so that you can resume a run + `Config` is designed to store all critical information of a Run so that you can resume a run from a state and corresponding weights or even restart a run from a state. - `RunnerState` is also designed to be serialisable and hashable, so that you can save it to a file. - `RunnerState` is saved in checkpoint together with weights by default. + `Config` is also designed to be serialisable and hashable, so that you can save it to a file. + `Config` is saved in checkpoint together with weights by default. - Since `RunnerState` is a [`NestedDict`][chanfig.NestedDict], you can access its attributes by + Since `Config` is a [`Config`][chanfig.Config], you can access its attributes by `state["key"]` or `state.key`. Attributes: General: @@ -59,13 +59,13 @@ class RunnerState(NestedDict): # pylint: disable=too-many-instance-attributes Attributes: Progress: iters (int): The number of data samples processed. equals to `steps` when `batch_size = 1`. - steps (int): The number of `step` calls. + steps (int): The number of `steps` calls. epochs (int): The number of complete passes over the datasets. - iter_end (int): End running iters. + iter_end (int): End running iter. Note that `step_end` not initialised since this variable may not apply to some Runners. - step_end (int): End running steps. + step_end (int): End running step. Note that `step_end` not initialised since this variable may not apply to some Runners. - epoch_end (int): End running epochs. + epoch_end (int): End running epoch. Note that `epoch_end` not initialised since this variable may not apply to some Runners. In general you should only use one of `iter_end`, `step_end`, `epoch_end` to indicate the length of running. @@ -96,7 +96,7 @@ class RunnerState(NestedDict): # pylint: disable=too-many-instance-attributes If <= 0, save only the latest and the best checkpoints. Notes: - `RunnerState` is a `NestedDict`, so you can access its attributes by `state["name"]` or `state.name`. + `Config` is a [`Config`][chanfig.Config], so you can access its attributes by `state["name"]` or `state.name`. See Also: [`BaseRunner`][danling.runner.BaseRunner]: The base runner class. @@ -104,12 +104,12 @@ class RunnerState(NestedDict): # pylint: disable=too-many-instance-attributes # DO NOT set default value in class, as they won't be stored in `__dict__`. - run_name: str = defaults.DEFAULT_RUN_NAME + run_name: str = defaults.RUN_NAME run_id: str - experiment_name: str = defaults.DEFAULT_EXPERIMENT_NAME + experiment_name: str = defaults.EXPERIMENT_NAME experiment_id: str - seed: int + seed: Optional[int] = None deterministic: bool = False iters: int = 0 @@ -132,24 +132,12 @@ class RunnerState(NestedDict): # pylint: disable=too-many-instance-attributes log_interval: Optional[int] = None save_interval: Optional[int] = None - distributed: Optional[bool] = None - dist_backend: Optional[str] = None - init_method: Optional[str] = None - master_addr: Optional[str] = None - master_port: Optional[int] = None - - def __init__(self, *args, **kwargs): - for k, v in self.__class__.__dict__.items(): - if not (k.startswith("__") and k.endswith("__")) and (not (isinstance(v, property) or callable(v))): - self.set(k, v) - if "seed" not in self: - self.seed = randint(0, 2**32 - 1) - super().__init__(*args, **kwargs) + def __post_init__(self): if "experiment_id" not in self: - self.experiment_id = get_git_hash() or defaults.DEFAULT_EXPERIMENT_ID + self.experiment_id = get_git_hash() or defaults.EXPERIMENT_ID if "run_id" not in self: self.run_id = self.run_uuid.hex - self.setattr("ignored_keys_in_hash", defaults.DEFAULT_IGNORED_KEYS_IN_HASH) + self.setattr("ignored_keys_in_hash", defaults.IGNORED_CONFIG_IN_HASH) @property def experiment_uuid(self) -> UUID: @@ -165,8 +153,8 @@ def run_uuid(self) -> UUID: UUID of the run. """ - ignored_keys_in_hash = self.getattr("ignored_keys_in_hash", defaults.DEFAULT_IGNORED_KEYS_IN_HASH) - state: NestedDict = NestedDict({k: v for k, v in self.dict().items() if k not in ignored_keys_in_hash}) + ignored_keys_in_hash = self.getattr("ignored_keys_in_hash", defaults.IGNORED_CONFIG_IN_HASH) + state: chanfig.Config = chanfig.Config({k: v for k, v in self.dict().items() if k not in ignored_keys_in_hash}) return uuid5(self.experiment_uuid, state.yamls()) def __hash__(self) -> int: diff --git a/danling/runner/deepspeed_runner.py b/danling/runner/deepspeed_runner.py new file mode 100644 index 00000000..e5606e73 --- /dev/null +++ b/danling/runner/deepspeed_runner.py @@ -0,0 +1,213 @@ +# DanLing +# Copyright (C) 2022-Present DanLing + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the following licenses: +# - The Unlicense +# - GNU Affero General Public License v3.0 or later +# - GNU General Public License v2.0 or later +# - BSD 4-Clause "Original" or "Old" License +# - MIT License +# - Apache License 2.0 + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +# See the LICENSE file for more details. + +from __future__ import annotations + +import os +import shutil + +import torch +from chanfig import NestedDict +from lazy_imports import try_import +from torch import distributed as dist +from torch import nn +from torch.nn.utils import clip_grad_value_ + +from danling.runner.config import Config +from danling.utils import catch + +from .torch_runner import TorchRunner + +with try_import() as ds: + import deepspeed + + +class DeepSpeedRunner(TorchRunner): + + def __init__(self, config: Config) -> None: + ds.check() + super().__init__(config) + + def init_distributed(self) -> None: + r""" + Set up distributed training. + + Initialise process group and set up DDP variables. + """ + + backend = self.config.get("backend", os.getenv("BACKEND")) + init_method = self.config.get("init_method", os.getenv("INIT_METHOD")) + world_size = int(self.config.get("world_size", os.getenv("WORLD_SIZE", "1"))) + rank = int(self.config.get("rank", os.getenv("RANK", "0"))) + if world_size > 1: + if torch.cuda.is_available(): + torch.cuda.set_device(self.get_local_rank()) + deepspeed.init_distributed(dist_backend=backend, init_method=init_method, world_size=world_size, rank=rank) + object_list = [self.id, self.timestamp] + dist.broadcast_object_list(object_list) + self.id, self.timestamp = object_list + + def __post_init__(self): + super().__post_init__() + self.config.deepspeed = self.get_deepspeed_config() + self.model, self.optimizer, _, self.scheduler = deepspeed.initialize( + model=self.model, + optimizer=self.optimizer, + lr_scheduler=self.scheduler, + config=self.config.deepspeed, + ) + + def advance(self, loss) -> None: + self.backward(loss) + if self.config.get("max_grad_value") is not None: + clip_grad_value_(self.model.parameters(), self.config["max_grad_value"]) + self.model.step() + if self.ema is not None: + self.ema.update() + self.config.steps = self.model.global_steps + self.config.iters += 1 + + def backward(self, loss: torch.Tensor) -> None: + return self.model.backward(loss) + + def get_local_rank(self) -> int: + local_rank = self.config.get("local_rank", os.getenv("LOCAL_RANK")) + if local_rank is not None: + return int(local_rank) + rank = self.config.get("rank", os.getenv("RANK")) + world_size = self.config.get("world_size", os.getenv("WORLD_SIZE")) + if world_size is None or rank is None: + raise ValueError("Please provide either `local_rank` or `world_size` and `rank`") + return int(world_size) % int(rank) + + def unwrap(self, model: nn.Module) -> nn.Module: + while isinstance(model, (deepspeed.DeepSpeedEngine, nn.parallel.DistributedDataParallel)): + model = model.module + return model + + @property + def deepspeed(self) -> NestedDict | None: + if isinstance(self.model, deepspeed.DeepSpeedEngine): + return self.model.config + return None + + @catch + def save_checkpoint(self, name: str = "latest", epoch: int | None = None, save_best: bool = True) -> None: + r""" + Save checkpoint to `self.checkpoint_dir`. + + Args: + name: Name of the checkpoint. Defaults to `"latest"`. + epoch: Epoch to save. Defaults to `self.config.epochs`. + save_best: If `True`, when `self.is_best` is `True`, the checkpoint will also be copied to + `self.checkpoint_dir/best`. + + If `self.config.save_interval` is positive and `epochs + 1` is a multiple of `save_interval`, + the checkpoint will also be copied to `self.checkpoint_dir/epoch-{epochs}`. + """ + + epoch = epoch or self.config.epochs + save_interval = self.config.get("save_interval", -1) + latest_path = os.path.join(self.checkpoint_dir, name) + os.makedirs(latest_path, exist_ok=True) + self.yaml(os.path.join(latest_path, "runner.yaml")) + self.model.save_checkpoint( + self.checkpoint_dir, tag=name, client_state={"runner": self.config.dict()}, save_latest=False + ) + if save_interval > 0 and (epoch + 1) % save_interval == 0: + save_path = os.path.join(self.checkpoint_dir, f"epoch-{epoch}") + shutil.copytree(latest_path, save_path, dirs_exist_ok=True) + if save_best and self.is_best: + best_path = os.path.join(self.checkpoint_dir, "best") + shutil.copytree(latest_path, best_path, dirs_exist_ok=True) + + def load_checkpoint(self, checkpoint: bytes | str | os.PathLike, *args, **kwargs) -> None: # type: ignore[override] + """ + Load model, optimizer, and scheduler from checkpoint. + + Args: + checkpoint: Checkpoint (or its path) to load. + *args: Additional arguments to pass to `self.load`. + **kwargs: Additional keyword arguments to pass to `self.load`. + + Raises: + ValueError: If `model` is not defined. + ValueError: If `model` is not an instance of `deepspeed.DeepSpeedEngine`. + + See Also: + [`from_checkpoint`][danling.BaseRunner.from_checkpoint]: Build runner from checkpoint. + [`load_pretrained`][danling.BaseRunner.load_pretrained]: Load model parameters from pretrained checkpoint. + """ + + if self.model is None: + raise ValueError("model is not defined") + if not isinstance(self.model, deepspeed.DeepSpeedEngine): + raise ValueError("model is not an instance of `deepspeed.DeepSpeedEngine`") + + self.model.load_checkpoint(checkpoint) + self.config.checkpoint = checkpoint + + def load_pretrained(self, checkpoint: bytes | str | os.PathLike, *args, **kwargs) -> None: # type: ignore[override] + """ + Load model from pretrained checkpoint. + + This method only loads the model weights. + + Args: + checkpoint: Pretrained checkpoint directory. + *args: Additional arguments to pass to `self.load`. + **kwargs: Additional keyword arguments to pass to `self.load`. + + Raises: + ValueError: If `model` is not defined. + + See Also: + [`load_checkpoint`][danling.BaseRunner.load_checkpoint]: Load model, optimizer, and scheduler from + checkpoint. + """ + + if self.model is None: + raise ValueError("model is not defined") + + self.model.load_checkpoint(checkpoint, load_module_only=True) + self.config.pretrained = checkpoint + + def load_config( + self, checkpoint: bytes | str | os.PathLike, overwrite: bool = False, *args, **kwargs # type: ignore[override] + ) -> None: + r""" + Load config from checkpoint. + + Args: + checkpoint: Checkpoint (or its path) to load. + overwrite: If `True`, overwrite the current config with the loaded config. + Defaults to `False`. + *args: Additional arguments to pass to `self.load`. + **kwargs: Additional keyword arguments to pass to `self.load`. + + Raises: + FileNotFoundError: If `checkpoint` does not exists. + """ + + if isinstance(checkpoint, bytes): + checkpoint = checkpoint.decode() + + config = self.load(os.path.join(checkpoint, "runner.yaml"), *args, **kwargs) + self.config.merge(config, overwrite=overwrite) + self.config.iter_begin = config["iters"] + 1 + self.config.step_begin = config["steps"] + 1 + self.config.epoch_begin = config["epochs"] + 1 diff --git a/danling/runner/torch_runner.py b/danling/runner/torch_runner.py index da7ea6b1..eb5eb1c9 100644 --- a/danling/runner/torch_runner.py +++ b/danling/runner/torch_runner.py @@ -15,6 +15,666 @@ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. # See the LICENSE file for more details. -from .accelerate_runner import AccelerateRunner as TorchRunner +from __future__ import annotations -__all__ = ["TorchRunner"] +import os +import random +from collections.abc import Mapping +from contextlib import contextmanager, nullcontext +from math import ceil +from time import time +from typing import Any, Callable, Tuple +from warnings import warn + +import torch +import torch.distributed +from chanfig import NestedDict +from torch import distributed as dist +from torch import nn, optim, utils +from torch.backends import cudnn +from torch.nn.utils import clip_grad_norm_, clip_grad_value_ +from tqdm import tqdm + +try: + from functools import cached_property +except ImportError: + from cached_property import cached_property # type: ignore + +try: + from numpy import random as np_random +except ImportError: + np_random = None + +try: + import deepspeed as ds +except ImportError: + ds = None + +from danling import defaults +from danling.utils import catch + +from .base_runner import BaseRunner +from .utils import RunnerMode, on_main_process, to_device + + +class TorchRunner(BaseRunner): + r""" + Set up everything for running a job. + + `TorchRunner` uses `torch.distributed` as distributed backend to provide + distributed training experience. + """ + + model: nn.Module + ema: nn.Module | None = None + criterion: nn.Module + optimizer: optim.Optimizer + scheduler: optim.lr_scheduler._LRScheduler + + def __post_init__(self): + super().__post_init__() + if self.datasets: + self.build_dataloaders() + if self.config.get("log_interval") is None: + self.config.log_interval = max(ceil(max(len(d) for d in self.dataloaders.values()) / 10), 1) + self.model = self.model.to(self.device) + if self.ema is not None: + self.ema = self.ema.to(self.device) + if self.distributed and not isinstance( + self.model, (nn.parallel.DistributedDataParallel, nn.parallel.DataParallel) + ): + self.model = nn.parallel.DistributedDataParallel(self.model) + + def train(self, train_splits: list[str] | None = None, evaluate_splits: list[str] | None = None) -> NestedDict: + r""" + Perform training on `split`. + + Args: + train_splits: list of split to run train. + Defaults to `["train"]`. + evaluate_splits: list of split to run evaluate. + Defaults to `self.dataloaders` except for those in `train_splits`. + + Return: + NestedDict: train results + """ + + early_stop_counter = 0 + if train_splits is None: + train_splits = ["train"] if "train" in self.dataloaders else [] + if not train_splits: + warn("No training split is found. Will only evaluate for one epoch.", stacklevel=2) + self.config.epoch_end = self.config.epoch_begin + 1 + if evaluate_splits is None: + evaluate_splits = [s for s in self.dataloaders if s not in train_splits] + train_splits.sort() + evaluate_splits.sort() + print(f"Begin training from {self.config.epoch_begin} to {self.config.epoch_end}") + print(f"Training splits: {train_splits}") + print(f"Evaluation splits: {evaluate_splits}") + patience = self.config.get("patience", float("inf")) + for epoch in range(self.config.epoch_begin, self.config.epoch_end): # type: ignore + self.config.epochs = epoch + result = NestedDict() + result.setattr("convert_mapping", True) + for split in train_splits: + result[split] = self.train_epoch(split) + for split in evaluate_splits: + result[split] = self.evaluate_epoch(split) + self.append_result(result) + print(self.format_epoch_result(result)) + self.save_result() + if self.config.save_interval is not None: + self.save_checkpoint(epoch) + """@nni.report_intermediate_result(self.latest_score)""" + early_stop_counter = 0 if self.is_best else early_stop_counter + 1 + if early_stop_counter > patience: + print("early stop") + break + """@nni.report_final_result(self.latest_score)""" + return self.results + + def train_epoch(self, split: str = "train") -> NestedDict: + r""" + Train one epoch on `split`. + + Args: + split (str): split to run train + + Return: + NestedDict: train result + """ + + self.mode = "train" # type: ignore + self.split = split + loader = self.dataloaders[split] + length = len(loader) - 1 + last_print_iteration = -1 + log_interval = self.config.get("log_interval", -1) + self.meters.reset() + if self.train_metrics is not None: + self.metrics = self.train_metrics + if self.metrics is not None: + self.metrics.reset() + batch_time = time() + if hasattr(loader.batch_sampler, "set_epoch"): + loader.batch_sampler.set_epoch(self.config.epochs) + if hasattr(loader.sampler, "set_epoch"): + loader.sampler.set_epoch(self.config.epochs) + + for iteration, data in enumerate(loader): + _, loss = self.train_step(data) + + if log_interval > 0 and (iteration > 0 and iteration % log_interval == 0 or iteration == length): + interval = iteration - last_print_iteration + if self.device == torch.device("cuda"): + torch.cuda.synchronize() + if self.scheduler is not None: + self.meters.lr.update(self.scheduler.get_last_lr()[0]) + self.meters.time.update((time() - batch_time) / interval) + batch_time = time() + reduced_loss = self.reduce(loss).item() + self.meters.loss.update(reduced_loss) + self.step_log(split, iteration, length) + last_print_iteration = iteration + + result = self.get_epoch_result() + return result + + def train_step(self, data) -> Tuple[Any, torch.Tensor]: + with self.autocast(), self.accumulate(): + input = data["input"] if isinstance(data, Mapping) else data[0] + target = data["target"] if isinstance(data, Mapping) else data[1] + pred = self.model(**input) if isinstance(input, Mapping) else self.model(input) + loss = self.criterion(pred, target) + if self.metrics is not None: + self.metrics.update(pred.squeeze(-1), target) + self.advance(loss) + return pred, loss + + def advance(self, loss) -> None: + r""" + Backward loss and step optimizer & scheduler. + + Args: + loss: The loss tensor from which to backpropagate. + """ + + self.backward(loss / self.accum_steps) + if self.accum_steps <= 1 or self.config.steps % self.accum_steps == 0: + if self.config.get("max_grad_value") is not None: + clip_grad_value_(self.model.parameters(), self.config["max_grad_value"]) + if self.config.get("max_grad_norm") is not None: + clip_grad_norm_(self.model.parameters(), self.config["max_grad_norm"]) + self.optimizer.step() + self.optimizer.zero_grad() + if self.ema is not None: + self.ema.update() + if self.scheduler is not None: + self.scheduler.step() + self.config.steps += 1 + self.config.iters += 1 + + def evaluate(self, evaluate_splits: list[str] | None = None) -> NestedDict: + r""" + Perform evaluation on `evaluate_splits`. + + Args: + evaluate_splits: list of split to run evaluate. + Defaults to `["eval"]`. + + Return: + NestedDict: evaluation result + """ + + if evaluate_splits is None: + evaluate_splits = ["eval"] + + print("Begin evaluation") + print(f"Evaluation splits: {evaluate_splits}") + result = NestedDict() + result.setattr("convert_mapping", True) + for split in evaluate_splits: + result[split] = self.evaluate_epoch(split=split) + print(self.format_epoch_result(result)) + return result + + # torch.inference_mode cause experiments to hang + # @torch.inference_mode() + def evaluate_epoch(self, split: str = "val") -> NestedDict: + r""" + Evaluate one epoch on `split`. + + Args: + split (str): split to run evaluate + + Return: + NestedDict: evaluation result + """ + + self.mode = "eval" # type: ignore + self.split = split + loader = self.dataloaders[split] + length = len(loader) - 1 + last_print_iteration = -1 + log_interval = self.config.get("log_interval", -1) + self.meters.reset() + if self.evaluate_metrics is not None: + self.metrics = self.evaluate_metrics + if self.metrics is not None: + self.metrics.reset() + batch_time = time() + + for iteration, data in enumerate(loader): + _, loss = self.evaluate_step(data) + + if log_interval > 0 and (iteration > 0 and iteration % log_interval == 0 or iteration == length): + interval = iteration - last_print_iteration + if self.device == torch.device("cuda"): + torch.cuda.synchronize() + self.meters.time.update((time() - batch_time) / interval) + batch_time = time() + reduced_loss = self.reduce(loss).item() + self.meters.loss.update(reduced_loss) + self.step_log(split, iteration, length) + last_print_iteration = iteration + + result = self.get_epoch_result() + self.write_result(result, split, self.config.epochs) + return result + + def evaluate_step(self, data) -> Tuple[Any, torch.Tensor]: + input = data["input"] if isinstance(data, Mapping) else data[0] + target = data["target"] if isinstance(data, Mapping) else data[1] + model = self.ema or self.model + pred = model(**input) if isinstance(input, Mapping) else model(input) + loss = self.criterion(pred, target) + if self.metrics is not None: + self.metrics.update(pred.squeeze(-1), target) + return pred, loss + + @torch.inference_mode() + def infer(self, split: str = "inf") -> list[float]: + r""" + Perform inference on `split`. + + Args: + split (str): split to run inference + + Return: + Tensor: inference outputs + """ + + self.mode = "inf" # type: ignore + loader = self.dataloaders[split] + output: list[float] = [] + model = self.ema or self.model + for _, data in tqdm(enumerate(loader), total=len(loader)): + input = data["input"] if isinstance(data, Mapping) else data[0] + pred = model(**input) if isinstance(input, Mapping) else model(input) + output.extend(pred.squeeze(-1).tolist()) + + if self.distributed: + torch.cuda.synchronize() + output = self.gather_for_metrics(output) + return output + + def backward(self, loss: torch.Tensor) -> None: + r""" + Backward loss. + + Args: + loss: Loss to backward. + """ + + loss.backward() + + def has_nan_inf_grad(self, model: nn.Module | None = None) -> bool: + r""" + Check if model has NaN or Inf gradients. + + Args: + model: Model to check. + Defaults to `self.model`. + + Return: + bool: True if NaN or Inf is detected in gradients. + """ + model = model or self.model + for name, param in model.named_parameters(): + if param.grad is not None: + if torch.isnan(param.grad).any(): + print(f"NaN detected in gradients of parameter: {name}") + return True + if torch.isinf(param.grad).any(): + print(f"Inf detected in gradients of parameter: {name}") + return True + return False + + def init_distributed(self) -> None: + r""" + Set up distributed training. + + Initialise process group and set up DDP variables. + """ + + backend = self.config.get("backend", os.getenv("BACKEND")) + init_method = self.config.get("init_method", os.getenv("INIT_METHOD")) + world_size = int(self.config.get("world_size", os.getenv("WORLD_SIZE", "1"))) + rank = int(self.config.get("rank", os.getenv("RANK", "0"))) + if world_size > 1: + if torch.cuda.is_available(): + torch.cuda.set_device(self.local_rank) + dist.init_process_group(backend, init_method, world_size=world_size, rank=rank) + object_list = [self.id, self.timestamp] + dist.broadcast_object_list(object_list) + self.id, self.timestamp = object_list + + @on_main_process + def init_tensorboard(self, *args, **kwargs) -> None: + r""" + Set up Tensoraoard SummaryWriter. + """ + from torch.utils.tensorboard.writer import SummaryWriter # pylint: disable=C0415 + + if "log_dir" not in kwargs: + kwargs["log_dir"] = self.dir + + self.writer = SummaryWriter(*args, **kwargs) + self.writer.add_scalar = catch(OSError, verbose=False)(self.writer.add_scalar) + + def set_seed(self, seed: int = None, bias: int = None) -> int: # type: ignore[assignment] + r""" + Set up random seed. + + Args: + seed: Random seed to set. + Defaults to `self.config.seed` (`config.seed`). + + bias: Make the seed different for each processes. + This is used to ensure the data augmentation are applied differently on every processes. + Defaults to `self.rank`. + Set to `False` to disable this feature. + Returns: + Random seed set. + """ + + seed = seed or self.config.seed # type: ignore[assignment] + if seed is None: + if self.inited: + seed = random.randint(0, 2**32 - 1) + if self.distributed: + object_list = [seed] + dist.broadcast_object_list(object_list) + seed = object_list[0] + self.config.seed = seed + else: + seed = defaults.SEED + bias = bias or self.rank + if bias: + seed += bias + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + if np_random is not None: + np_random.seed(seed) + random.seed(seed) + return seed + + def set_deterministic(self) -> None: + cudnn.benchmark = False + cudnn.deterministic = True + if torch.__version__ >= "1.8.0": + torch.use_deterministic_algorithms(True) + + def state_dict(self, cls: Callable = dict) -> Mapping: + if self.model is None: + raise ValueError("Model must be defined when calling state_dict") + return cls( + runner=self.config.dict(), + model=self.unwrap(self.model).state_dict(), + optimizer=self.optimizer.state_dict() if self.optimizer else None, + scheduler=self.scheduler.state_dict() if self.scheduler else None, + ) + + def unwrap(self, model: nn.Module) -> nn.Module: + if isinstance(model, (nn.parallel.DistributedDataParallel, nn.parallel.DataParallel)): + return model.module + return model + + def build_dataloaders(self): + datasets = {k: d for k, d in self.datasets.items() if k not in self.dataloaders} + default_kwargs = self.config.get("dataloader", NestedDict()) + dataloader_kwargs = NestedDict({k: default_kwargs.pop(k) for k in self.datasets if k in default_kwargs}) + for k, d in datasets.items(): + dataloader_kwargs.setdefault(k, NestedDict()) + dataloader_kwargs[k].merge(default_kwargs, overwrite=False) + shuffle = dataloader_kwargs[k].pop("shuffle", getattr(d, "train", True)) + if self.distributed: + sampler = utils.data.distributed.DistributedSampler(d, shuffle=shuffle) + else: + sampler = utils.data.RandomSampler(d) if shuffle else utils.data.SequentialSampler(d) + dataloader_kwargs[k].setdefault("drop_last", not getattr(d, "train", True)) + self.dataloaders[k] = utils.data.DataLoader( + d, sampler=sampler, collate_fn=self.collate_fn, **dataloader_kwargs[k] + ) + + def collate_fn(self, batch): + return to_device(utils.data.dataloader.default_collate(batch), self.device) + + @contextmanager + def autocast(self): + if self.config.get("precision") is None: + yield nullcontext() + else: + yield torch.autocast(self.device.type, dtype=get_precision(self.config.precision)) + + @contextmanager + def accumulate(self): + if self.accum_steps <= 1 or self.config.steps % self.accum_steps == 0: + yield nullcontext() + else: + yield self.model.no_sync() + + def get_optimizer(self, name: str): + if name.lower() == "sgd": + return optim.SGD + if name.lower() == "asgd": + return optim.ASGD + if name.lower() in {"torch_adam", "torch_adamw"}: + return optim.Adam + if ds is not None: + if name.lower() == "adagrad": + return ds.ops.adagrad.DeepSpeedCPUAdagrad + if name.lower() in {"adam", "adamw"}: + if torch.cuda.device_count() > 0: + return ds.ops.adam.FusedAdam + return ds.ops.adam.DeepSpeedCPUAdam + if name.lower() in {"cpu", "cpu_adam", "cpuadam", "cpu_adamw", "cpuadamw"}: + return ds.ops.adam.DeepSpeedCPUAdam + if name.lower() == "lamb": + if torch.cuda.device_count() > 0: + return ds.ops.lamb.FusedLamb + return ds.ops.lamb.DeepSpeedCPULamb + if name.lower() in {"cpulamb", "cpu_lamb"}: + return ds.ops.lamb.DeepSpeedCPULamb + if name.lower() == "lion": + if torch.cuda.device_count() > 0: + return ds.ops.lion.FusedLion + return ds.ops.lion.DeepSpeedCPULion + if name.lower() in {"cpulion", "cpu_lion"}: + return ds.ops.lion.DeepSpeedCPULion + if name.lower() in {"adam", "adamw"}: + return optim.AdamW + if name.lower() == "adadelta": + return optim.Adadelta + if name.lower() == "adafactor": + return optim.Adafactor + if name.lower() == "adagrad": + return optim.Adagrad + if name.lower() == "adamax": + return optim.Adamax + if name.lower() == "lbfgs": + return optim.LBFGS + if name.lower() == "nadam": + return optim.NAdam + if name.lower() == "radam": + return optim.RAdam + if name.lower() == "rmsprop": + return optim.RMSprop + if name.lower() == "rprop": + return optim.Rprop + + def get_deepspeed_config(self, config: NestedDict | str = None) -> NestedDict: # pylint: disable=R0912, R0915 + r""" + Preprocess DeepSpeed config. + """ + + if config is None and "deepspeed" in self.config: + config = self.config.deepspeed + if isinstance(config, str): + config = NestedDict(config) + if config is None: + config = NestedDict() + if config.get("steps_per_print", "auto") == "auto": + config["steps_per_print"] = self.config.log_interval + if config.get("train_micro_batch_size_per_gpu", "auto") == "auto": + config["train_micro_batch_size_per_gpu"] = self.batch_size + if config.get("gradient_accumulation_steps", "auto") == "auto": + if self.accum_steps > 1: + config["gradient_accumulation_steps"] = self.accum_steps + else: + config.pop("gradient_accumulation_steps", None) + if "amp" in config: + amp = config["amp"] + if amp.get("enabled", "auto") == "auto": + amp["enabled"] = "true" + if amp.get("opt_level", "auto") == "auto": + amp["opt_level"] = "O1" + if "zero_optimization" in config: + zero = config["zero_optimization"] + if zero.get("allgather_bucket_size") == "auto": + zero["allgather_bucket_size"] = 1e6 + if zero.get("reduce_bucket_size") == "auto": + zero["reduce_bucket_size"] = 1e6 + if zero.get("stage3_max_live_parameters") == "auto": + zero["stage3_max_live_parameters"] = 1e8 + if zero.get("stage3_max_live_gradients") == "auto": + zero["stage3_max_live_gradients"] = 1e8 + if zero.get("stage3_max_reuse_distance") == "auto": + zero["stage3_max_reuse_distance"] = 1e8 + if zero.get("stage3_prefetch_bucket_size") == "auto": + zero["stage3_prefetch_bucket_size"] = 1e6 + if zero.get("stage3_param_persistence_threshold") == "auto": + zero["stage3_param_persistence_threshold"] = 1e8 + if "amp" in config: + if "fp16" not in config: + config["fp16"] = NestedDict() + if config["fp16"].get("enabled", "auto"): + config["fp16"]["enabled"] = config["amp"]["enabled"] + warn( + f"AMP is not compatible with ZeRO. Automatically set 'fp16' to {config['amp']['enabled']}", + stacklevel=2, + ) + del config["amp"] + if "optimizer" in config: + if config["optimizer"].get("type", "auto") == "auto": + config["optimizer"]["type"] = "Adam" + if "params" not in config["optimizer"]: + config["optimizer"]["params"] = NestedDict() + optimizer = config["optimizer"]["params"] + if optimizer.get("lr", "auto") == "auto": + optimizer["lr"] = self.config.get("optim.lr", 1e-3) + if optimizer.get("weight_decay", "auto") == "auto": + optimizer["weight_decay"] = self.config.get("optim.weight_decay", 1e-2) + if optimizer.get("betas") == "auto": + optimizer["betas"] = (0.9, 0.999) + if optimizer.get("eps") == "auto": + optimizer["eps"] = 1e-8 + if "scheduler" in config: + if config["scheduler"].get("type", "auto") == "auto": + config["scheduler"]["type"] = "WarmupCosineLR" + if "params" not in config["scheduler"]: + config["scheduler"]["params"] = NestedDict() + scheduler = config["scheduler"]["params"] + if scheduler.get("total_num_steps", "auto") == "auto": + scheduler["total_num_steps"] = self.total_steps + if scheduler.get("warmup_num_steps", "auto") == "auto": + scheduler["warmup_num_steps"] = scheduler["total_num_steps"] // 20 + if config["scheduler"]["type"] in ("WarmupLR", "WarmupDecayLR"): + if scheduler.get("warmup_max_lr", "auto") == "auto": + if self.optimizer: + scheduler["warmup_max_lr"] = self.optimizer.param_groups[0]["lr"] + elif "optimizer" in config: + scheduler["warmup_max_lr"] = config["optimizer"]["params"]["lr"] + else: + scheduler["warmup_max_lr"] = self.config.get("optim.lr", 1e-3) + if scheduler.get("warmup_min_lr", "auto") == "auto": + scheduler["warmup_min_lr"] = 1e-9 + else: + scheduler.pop("warmup_max_lr", None) + scheduler.pop("warmup_min_lr", None) + if config.get("gradient_clipping", "auto") == "auto" and self.config.get("max_grad_norm") is not None: + config["gradient_clipping"] = self.config["max_grad_norm"] + return config + + @property + def device(self): + return torch.device("cuda" if torch.cuda.is_available() else "cpu", self.local_rank) + + @property + def mode(self) -> RunnerMode: + return self._mode + + @mode.setter + def mode(self, mode: str | RunnerMode) -> None: + if isinstance(mode, str): + mode = RunnerMode(mode) + self._mode = mode + if self.model is not None: + self.model.train(mode == RunnerMode.train) + if self.ema is not None: + self.ema.train(mode == RunnerMode.train) + + @property + def rank(self) -> int: + if self.distributed: + return dist.get_rank() + return 0 + + @property + def local_rank(self) -> int: + if local_rank := os.getenv("LOCAL_RANK"): + return int(local_rank) + return 0 + + @property + def world_size(self) -> int: + r""" + Number of Processes. + """ + + if torch.distributed.is_available() and torch.distributed.is_initialized(): + return dist.get_world_size() + return 1 + + @property + def distributed(self) -> bool: + return self.world_size > 1 + + @cached_property + def accum_steps(self) -> int: + return self.config.get("accum_steps", 1) + + @staticmethod + def reduce(tensor: torch.Tensor) -> torch.Tensor: + if torch.distributed.is_available() and torch.distributed.is_initialized(): + dist.all_reduce(tensor) + return tensor + + +def get_precision(precision: str) -> torch.dtype: + if precision in ("fp16", "float16", "half"): + return torch.float16 + if precision in ("bf16", "bfloat16"): + return torch.bfloat16 + raise ValueError(f"Precision {precision} is not supported") diff --git a/danling/runner/utils.py b/danling/runner/utils.py index 8ae9e721..3683eb95 100644 --- a/danling/runner/utils.py +++ b/danling/runner/utils.py @@ -19,13 +19,18 @@ import os import sys +from collections.abc import Mapping from contextlib import suppress from datetime import datetime from enum import auto from functools import wraps +from math import isnan from typing import Any from warnings import warn +import torch +from chanfig import FlatDict, NestedDict + from danling.utils import base62 try: @@ -120,3 +125,39 @@ def wrapper(self, *args, **kwargs) -> Any | None: return None return wrapper + + +def format_result(result, format_spec: str = ".4f", depth: int = 0): + longest_key = max(len(k) for k in result.keys()) + repr_list = [_format_result(result, format_spec)] + for k, v in result.items(): + if isinstance(v, Mapping): + initials = " " * (longest_key - len(k)) + "\t" * depth + repr_list.append(f"{initials}{k}: {format_result(v, format_spec, depth + 1)}") + return "\n".join(repr_list) + + +def _format_result(result, format_spec: str = ".4f"): + repr_list = [] + for k, v in result.items(): + if isinstance(v, (Mapping,)): + continue + if isinstance(v, (float,)): + repr_list.append(f"{k}: {format(v, format_spec)}" if not isnan(v) else f"{k}: NaN\t") + else: + repr_list.append(f"{k}: {v}\t") + return "\t".join(repr_list) + + +def to_device(data: Any, device: torch.device): + if isinstance(data, list): + return [to_device(i, device) for i in data] + if isinstance(data, tuple): + return tuple(to_device(i, device) for i in data) + if isinstance(data, NestedDict): + return NestedDict({k: to_device(v, device) for k, v in data.all_items()}) + if isinstance(data, dict): + return FlatDict({k: to_device(v, device) for k, v in data.items()}) + if hasattr(data, "to"): + return data.to(device) + return data diff --git a/danling/tensors/nested_tensor.py b/danling/tensors/nested_tensor.py index 5848b58a..544d61fb 100644 --- a/danling/tensors/nested_tensor.py +++ b/danling/tensors/nested_tensor.py @@ -1144,6 +1144,9 @@ def reshape(self, *shape) -> Tensor: return self.tensor.reshape(*shape) + def __iter__(self): + return iter(self._storage) + NestedTensorFunc = TorchFuncRegistry() diff --git a/danling/utils/__init__.py b/danling/utils/__init__.py index 9f75cef5..010914fd 100644 --- a/danling/utils/__init__.py +++ b/danling/utils/__init__.py @@ -15,7 +15,6 @@ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. # See the LICENSE file for more details. -from . import defaults from .basex import Base58, Base62, Base64, BaseX, base58, base62, base64 from .contextmanagers import debug from .decorators import catch, ensure_dir, flexible_decorator, method_cache @@ -43,5 +42,4 @@ "base58", "base62", "base64", - "defaults", ] diff --git a/demo/accelerate_imdb.py b/demo/accelerate_imdb.py new file mode 100644 index 00000000..c7c8b037 --- /dev/null +++ b/demo/accelerate_imdb.py @@ -0,0 +1,120 @@ +# DanLing +# Copyright (C) 2022-Present DanLing + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the following licenses: +# - The Unlicense +# - GNU Affero General Public License v3.0 or later +# - GNU General Public License v2.0 or later +# - BSD 4-Clause "Original" or "Old" License +# - MIT License +# - Apache License 2.0 + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +# See the LICENSE file for more details. + +import torch +from chanfig import Registry +from datasets import load_dataset +from torch import nn, optim +from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer + +import danling as dl + +OPTIMIZERS = Registry() +OPTIMIZERS.register(optim.AdamW, "adamw") +OPTIMIZERS.register(optim.SGD, "sgd") + + +class IMDBConfig(dl.Config): + epoch_end: int = 2 + log: bool = False + tensorboard: bool = False + log_interval: int = 1000 + score_split: str = "val" + score_name: str = "loss" + debug: bool = False + patience: int = 1 + + def __init__(self): + super().__init__() + self.pretrained = "prajjwal1/bert-tiny" + self.dataset.path = "stanfordnlp/imdb" + self.dataloader.batch_size = 8 + self.optim.name = "adamw" + self.optim.lr = 1e-3 + self.optim.weight_decay = 1e-4 + self.sched.strategy = "cosine" + + def post(self): + self.transformers = AutoConfig.from_pretrained(self.pretrained) + self.experiment_name = f"{self.network.name}_{self.optim.name}@{self.optim.lr}" + + +class IMDBRunner(dl.AccelerateRunner): + def __init__(self, config: dl.Config): + super().__init__(config) + + self.tokenizer = AutoTokenizer.from_pretrained(self.pretrained) + self.datasets.train = load_dataset(split="train", **self.dataset) + self.datasets.val = load_dataset(split="train", **self.dataset) + # only run on a few samples to speed up testing process + self.datasets.train._data = self.datasets.train._data[:64] + self.datasets.val._data = self.datasets.val._data[:64] + self.datasets.train = self.preprocess_data(self.datasets.train) + self.datasets.val = self.preprocess_data(self.datasets.val) + + self.model = AutoModelForSequenceClassification.from_config(self.config.transformers) + self.optimizer = OPTIMIZERS.build(params=self.model.parameters(), **self.optim) + self.scheduler = dl.optim.LRScheduler(self.optimizer, total_steps=self.total_steps, **self.sched) + self.criterion = nn.CrossEntropyLoss() + + self.metrics = dl.metrics.binary_metrics() + self.meters.loss.reset() + self.meters.time.reset() + + def preprocess_data(self, dataset): + def tokenization(example): + example["text"] = self.tokenizer(example["text"], truncation=True, max_length=510)["input_ids"] + return example + + def transform(data): + text = dl.NestedTensor(data.pop("text")) + data["input_ids"] = text.tensor + data["attention_mask"] = text.mask + data["labels"] = torch.tensor(data.pop("label")) + return data + + dataset = dataset.map(tokenization, batched=True) + dataset.set_transform(transform) + dataset.__getitems__ = dataset.__getitem__ + return dataset + + def train_step(self, data) -> torch.Tensor: + with self.autocast(), self.accumulate(): + pred = self.model(**data) + loss = pred["loss"] + self.advance(loss) + self.metrics.update(pred["logits"][:, 0], data["labels"]) + return pred, loss + + def evaluate_step(self, data) -> torch.Tensor: + pred = self.model(**data) + loss = pred["loss"] + self.metrics.update(pred["logits"][:, 0], data["labels"]) + return pred, loss + + @staticmethod + def collate_fn(batch): + return batch + + +if __name__ == "__main__": + config = IMDBConfig() + config.parse() + with dl.debug(config.get("debug", False)): + runner = IMDBRunner(config) + runner.train() + runner.evaluate(["val"]) diff --git a/demo/vision/torch_mnist.py b/demo/torch_mnist.py similarity index 88% rename from demo/vision/torch_mnist.py rename to demo/torch_mnist.py index ef492767..09bcbc28 100644 --- a/demo/vision/torch_mnist.py +++ b/demo/torch_mnist.py @@ -16,7 +16,7 @@ # See the LICENSE file for more details. import torchvision -from chanfig import Config, Registry +from chanfig import Registry from torch import nn, optim import danling as dl @@ -26,11 +26,10 @@ OPTIMIZERS.register(optim.SGD, "sgd") -class MNISTConfig(Config): +class MNISTConfig(dl.Config): epoch_end: int = 2 log: bool = False tensorboard: bool = False - log_interval: int = 1000 score_split: str = "val" score_name: str = "loss" debug: bool = False @@ -41,7 +40,7 @@ def __init__(self): self.network.name = "resnet18" self.dataset.download = True self.dataset.root = "data" - self.dataloader.batch_size = 8 + self.dataloader.batch_size = 256 self.optim.name = "adamw" self.optim.lr = 1e-3 self.optim.weight_decay = 1e-4 @@ -52,7 +51,7 @@ def post(self): class MNISTRunner(dl.TorchRunner): - def __init__(self, config: Config): + def __init__(self, config: dl.Config): super().__init__(config) self.dataset.transform = torchvision.transforms.Compose( @@ -64,13 +63,13 @@ def __init__(self, config: Config): self.datasets.train = torchvision.datasets.MNIST(train=True, **self.dataset) self.datasets.val = torchvision.datasets.MNIST(train=False, **self.dataset) # only run on a few samples to speed up testing process - self.datasets.train.data = self.datasets.train.data[:64] - self.datasets.val.data = self.datasets.val.data[:64] + self.datasets.train.data = self.datasets.train.data + self.datasets.val.data = self.datasets.val.data self.model = getattr(torchvision.models, self.network.name)(pretrained=False, num_classes=10) self.model.conv1 = nn.Conv2d(1, 64, 1, bias=False) self.optimizer = OPTIMIZERS.build(params=self.model.parameters(), **self.optim) - self.scheduler = dl.optim.LRScheduler(self.optimizer, total_steps=self.trainable_steps, **self.sched) + self.scheduler = dl.optim.LRScheduler(self.optimizer, total_steps=self.total_steps, **self.sched) self.criterion = nn.CrossEntropyLoss() self.metrics = dl.metrics.multiclass_metrics(num_classes=10) diff --git a/docs/docs/runner/state.md b/docs/docs/runner/config.md similarity index 57% rename from docs/docs/runner/state.md rename to docs/docs/runner/config.md index d646e72d..0982bfcb 100644 --- a/docs/docs/runner/state.md +++ b/docs/docs/runner/config.md @@ -4,6 +4,6 @@ authors: date: 2022-05-04 --- -# RunnerState +# Config -::: danling.runner.state +::: danling.runner.config diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index cf189c85..999ced85 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -11,7 +11,7 @@ nav: - DanLing: index.md - Runner: - runner/index.md - - RunnerState: runner/runner_state.md + - Config: runner/runner_state.md - BaseRunner: runner/base_runner.md - AccelerateRunner: runner/accelerate_runner.md - Utilities: runner/utils.md @@ -196,6 +196,7 @@ plugins: - https://pytorch.org/docs/stable/objects.inv - https://pytorch.org/torcheval/stable/objects.inv - https://huggingface.co/docs/transformers/master/en/objects.inv + - https://huggingface.co/docs/accelerate/master/en/objects.inv - https://chanfig.danling.org/objects.inv - https://lightning.ai/docs/torchmetrics/stable/objects.inv rendering: diff --git a/pyproject.toml b/pyproject.toml index 0ffc2783..35549709 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,18 @@ dependencies = [ "strenum; python_version<'3.11'", "tqdm", ] +optional-dependencies.accelerate = [ + "accelerate", + "torch", + "torcheval", + "torchmetrics", +] +optional-dependencies.deepspeed = [ + "deepspeed", + "torch", + "torcheval", + "torchmetrics", +] optional-dependencies.jax = [ "flax", "jax", @@ -55,7 +67,6 @@ optional-dependencies.tensorflow = [ "tensorflow", ] optional-dependencies.torch = [ - "accelerate", "torch", "torcheval", "torchmetrics", diff --git a/requirements.txt b/requirements.txt index 3397cc5c..7f72ba3b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,19 @@ +# This file is for testing purposes only. +# Please refer to pyproject.toml for the actual dependencies. + +accelerate cached-property; python_version < "3.8" chanfig >= 0.0.96 +datasets gitpython lazy-imports +portalocker>=2.0.0 strenum; python_version < "3.11" +torch +# torchdata torcheval +torchmetrics +# torchtext +torchvision tqdm +transformers diff --git a/tests/optim/test_lr_scheduler.py b/tests/optim/test_lr_scheduler.py index e8ac4ed4..5058c23a 100644 --- a/tests/optim/test_lr_scheduler.py +++ b/tests/optim/test_lr_scheduler.py @@ -32,11 +32,11 @@ class Test: optimizer = optim.SGD([{"params": torch.tensor([0])}], lr=1, momentum=0.9) - def _get_lrs(self, strategy, method, steps: int = 100, final_lr_ratio: float = 0.001): + def _get_lrs(self, strategy, method, total_steps: int = 100, final_lr_ratio: float = 0.001): lrs = [] scheduler = LRScheduler( self.optimizer, - total_steps=steps, + total_steps=total_steps, final_lr_ratio=final_lr_ratio, strategy=strategy, method=method, diff --git a/tests/runner/test_base_runner.py b/tests/runner/test_base_runner.py index f0653e07..c85ae264 100644 --- a/tests/runner/test_base_runner.py +++ b/tests/runner/test_base_runner.py @@ -15,7 +15,6 @@ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. # See the LICENSE file for more details. -from chanfig import Config as Config_ from chanfig import NestedDict import danling as dl @@ -28,7 +27,7 @@ def init_distributed(self) -> None: pass -class Config(Config_): +class Config(dl.Config): __test__ = False def __init__(self): @@ -105,7 +104,7 @@ def test_results(self): def test_conflict(self): runner = self.runner - state = runner.state + config = runner.config runner.conflict = False assert not runner.conflict - assert state.conflict == 1 + assert config.conflict == 1 diff --git a/danling/utils/defaults.py b/tests/runner/test_imdb.py similarity index 66% rename from danling/utils/defaults.py rename to tests/runner/test_imdb.py index 1fc1288e..d4423960 100644 --- a/danling/utils/defaults.py +++ b/tests/runner/test_imdb.py @@ -15,4 +15,19 @@ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. # See the LICENSE file for more details. -DEFAULT_EXCLUDE = (KeyboardInterrupt, SystemExit) +import sys + +sys.path.insert(0, "demo") + +from accelerate_imdb import IMDBConfig, IMDBRunner # noqa: E402 + + +class Test: + config = IMDBConfig().boot() + runner = IMDBRunner(config) + + def test_train(self): + self.runner.train() + + def test_evaluate(self): + self.runner.evaluate(["val"]) diff --git a/tests/runner/test_mnist.py b/tests/runner/test_mnist.py index 65c14724..75ec6d67 100644 --- a/tests/runner/test_mnist.py +++ b/tests/runner/test_mnist.py @@ -17,7 +17,7 @@ import sys -sys.path.insert(0, "demo/vision") +sys.path.insert(0, "demo") from torch_mnist import MNISTConfig, MNISTRunner # noqa: E402