diff --git a/pyproject.toml b/pyproject.toml index 8872993..ad24c49 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ include = ["trainer*"] [project] name = "coqui-tts-trainer" -version = "0.1.7" +version = "0.2.0" description = "General purpose model trainer for PyTorch that is more flexible than it should be, by 🐸Coqui." readme = "README.md" requires-python = ">=3.9, <3.13" @@ -21,9 +21,7 @@ maintainers = [ classifiers = [ "Environment :: Console", "Natural Language :: English", - # How mature is this project? Common values are - # 3 - Alpha, 4 - Beta, 5 - Production/Stable - "Development Status :: 3 - Alpha", + "Development Status :: 4 - Beta", "Intended Audience :: Developers", "Intended Audience :: Science/Research", "License :: OSI Approved :: Apache Software License", @@ -39,7 +37,7 @@ classifiers = [ "Topic :: Scientific/Engineering :: Artificial Intelligence", ] dependencies = [ - "coqpit>=0.0.17", + "coqpit-config>=0.1.1", "fsspec>=2023.6.0", "numpy>=1.24.3; python_version < '3.12'", "numpy>=1.26.0; python_version >= '3.12'", @@ -58,11 +56,17 @@ dev = [ "pytest>=8", "ruff==0.6.9", ] -# Dependencies for running the tests test = [ "accelerate>=0.20.0", "torchvision>=0.15.1", ] +mypy = [ + "matplotlib>=3.9.2", + "mlflow>=2.18.0", + "mypy>=1.13.0", + "types-psutil>=6.1.0.20241102", + "wandb>=0.18.7", +] [tool.uv] default-groups = ["dev", "test"] @@ -105,3 +109,16 @@ skip_empty = true [tool.coverage.run] source = ["trainer", "tests"] command_line = "-m pytest" + +[[tool.mypy.overrides]] +module = [ + "accelerate", + "aim", + "aim.sdk.run", + "apex", + "clearml", + "fsspec", + "plotly", + "soundfile", +] +ignore_missing_imports = true diff --git a/tests/test_train_mnist.py b/tests/test_train_mnist.py index 5383b39..d3a851c 100644 --- a/tests/test_train_mnist.py +++ b/tests/test_train_mnist.py @@ -1,3 +1,4 @@ +import pytest import torch from tests.utils.mnist import MnistModel, MnistModelConfig @@ -22,6 +23,7 @@ def test_train_mnist(tmp_path): # Without parsing command line args args = TrainerArgs() + args.small_run = 4 trainer2 = Trainer( args, @@ -48,3 +50,6 @@ def test_train_mnist(tmp_path): loss4 = trainer3.keep_avg_train["avg_loss"] assert loss3 > loss4 + + with pytest.raises(ValueError, match="cannot both be None"): + Trainer(args, MnistModelConfig(), output_path=tmp_path, model=None) diff --git a/trainer/__init__.py b/trainer/__init__.py index 40c72da..093ffe8 100644 --- a/trainer/__init__.py +++ b/trainer/__init__.py @@ -1,7 +1,9 @@ import importlib.metadata from trainer.config import TrainerArgs, TrainerConfig -from trainer.model import * -from trainer.trainer import * +from trainer.model import TrainerModel +from trainer.trainer import Trainer __version__ = importlib.metadata.version("coqui-tts-trainer") + +__all__ = ["TrainerArgs", "TrainerConfig", "Trainer", "TrainerModel"] diff --git a/trainer/callbacks.py b/trainer/callbacks.py index 505fdac..d0fff8f 100644 --- a/trainer/callbacks.py +++ b/trainer/callbacks.py @@ -1,17 +1,20 @@ -from typing import Callable +from typing import TYPE_CHECKING, Callable + +if TYPE_CHECKING: + from trainer import Trainer class TrainerCallback: def __init__(self) -> None: - self.callbacks_on_init_start = [] - self.callbacks_on_init_end = [] - self.callbacks_on_epoch_start = [] - self.callbacks_on_epoch_end = [] - self.callbacks_on_train_epoch_start = [] - self.callbacks_on_train_epoch_end = [] - self.callbacks_on_train_step_start = [] - self.callbacks_on_train_step_end = [] - self.callbacks_on_keyboard_interrupt = [] + self.callbacks_on_init_start: list[Callable] = [] + self.callbacks_on_init_end: list[Callable] = [] + self.callbacks_on_epoch_start: list[Callable] = [] + self.callbacks_on_epoch_end: list[Callable] = [] + self.callbacks_on_train_epoch_start: list[Callable] = [] + self.callbacks_on_train_epoch_end: list[Callable] = [] + self.callbacks_on_train_step_start: list[Callable] = [] + self.callbacks_on_train_step_end: list[Callable] = [] + self.callbacks_on_keyboard_interrupt: list[Callable] = [] def parse_callbacks_dict(self, callbacks_dict: dict[str, Callable]) -> None: for key, value in callbacks_dict.items(): @@ -36,7 +39,7 @@ def parse_callbacks_dict(self, callbacks_dict: dict[str, Callable]) -> None: else: raise ValueError(f"Invalid callback key: {key}") - def on_init_start(self, trainer) -> None: + def on_init_start(self, trainer: "Trainer") -> None: if hasattr(trainer.model, "module"): if hasattr(trainer.model.module, "on_init_start"): trainer.model.module.on_init_start(trainer) diff --git a/trainer/config.py b/trainer/config.py index 872c9a7..d85d23b 100644 --- a/trainer/config.py +++ b/trainer/config.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import Optional, Union +from typing import Any, Optional, Union from coqpit import Coqpit @@ -192,13 +192,13 @@ class TrainerConfig(Coqpit): optimizer: Optional[Union[str, list[str]]] = field( default=None, metadata={"help": "Optimizer(s) to use. Defaults to None"} ) - optimizer_params: Union[dict, list[dict]] = field( + optimizer_params: Union[dict[str, Any], list[dict[str, Any]]] = field( default_factory=dict, metadata={"help": "Optimizer(s) arguments. Defaults to {}"} ) lr_scheduler: Optional[Union[str, list[str]]] = field( default=None, metadata={"help": "Learning rate scheduler(s) to use. Defaults to None"} ) - lr_scheduler_params: dict = field( + lr_scheduler_params: dict[str, Any] = field( default_factory=dict, metadata={"help": "Learning rate scheduler(s) arguments. Defaults to {}"} ) use_grad_scaler: bool = field( diff --git a/trainer/distribute.py b/trainer/distribute.py index f1505f5..3b99d8a 100644 --- a/trainer/distribute.py +++ b/trainer/distribute.py @@ -5,10 +5,11 @@ import subprocess import time -from trainer import TrainerArgs, logger +from trainer import TrainerArgs +from trainer.logger import logger -def distribute(): +def distribute() -> None: """ Call 👟Trainer training script in DDP mode. """ diff --git a/trainer/generic_utils.py b/trainer/generic_utils.py index 212c5d7..d186426 100644 --- a/trainer/generic_utils.py +++ b/trainer/generic_utils.py @@ -1,12 +1,14 @@ import datetime import os import subprocess +from collections.abc import ItemsView from typing import Any, Union import fsspec import torch from packaging.version import Version +from trainer.config import TrainerConfig from trainer.logger import logger @@ -20,7 +22,7 @@ def is_pytorch_at_least_2_4() -> bool: return Version(torch.__version__) >= Version("2.4") -def isimplemented(obj, method_name) -> bool: +def isimplemented(obj: Any, method_name: str) -> bool: """Check if a method is implemented in a class.""" if method_name in dir(obj) and callable(getattr(obj, method_name)): try: @@ -43,7 +45,7 @@ def to_cuda(x: torch.Tensor) -> torch.Tensor: return x -def get_cuda(): +def get_cuda() -> tuple[bool, torch.device]: use_cuda = torch.cuda.is_available() device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") return use_cuda, device @@ -97,7 +99,7 @@ def count_parameters(model: torch.nn.Module) -> int: return sum(p.numel() for p in model.parameters() if p.requires_grad) -def set_partial_state_dict(model_dict, checkpoint_state, c): +def set_partial_state_dict(model_dict: dict, checkpoint_state: dict, c: TrainerConfig) -> dict: # Partial initialization: if there is a mismatch with new and old layer, it is skipped. for k in checkpoint_state: if k not in model_dict: @@ -123,21 +125,21 @@ def set_partial_state_dict(model_dict, checkpoint_state, c): class KeepAverage: - def __init__(self): - self.avg_values = {} - self.iters = {} + def __init__(self) -> None: + self.avg_values: dict[str, float] = {} + self.iters: dict[str, int] = {} - def __getitem__(self, key): + def __getitem__(self, key: str) -> Any: return self.avg_values[key] - def items(self): + def items(self) -> ItemsView[str, Any]: return self.avg_values.items() - def add_value(self, name, init_val=0, init_iter=0): + def add_value(self, name: str, init_val: float = 0, init_iter: int = 0) -> None: self.avg_values[name] = init_val self.iters[name] = init_iter - def update_value(self, name, value, weighted_avg=False): + def update_value(self, name: str, value: float, weighted_avg: bool = False) -> None: if name not in self.avg_values: # add value if not exist before self.add_value(name, init_val=value) @@ -151,10 +153,10 @@ def update_value(self, name, value, weighted_avg=False): self.iters[name] += 1 self.avg_values[name] /= self.iters[name] - def add_values(self, name_dict): + def add_values(self, name_dict: dict[str, float]) -> None: for key, value in name_dict.items(): self.add_value(key, init_val=value) - def update_values(self, value_dict): + def update_values(self, value_dict: dict[str, float]) -> None: for key, value in value_dict.items(): self.update_value(key, value) diff --git a/trainer/io.py b/trainer/io.py index 532c721..ca0dd2e 100644 --- a/trainer/io.py +++ b/trainer/io.py @@ -10,6 +10,7 @@ import fsspec import torch from coqpit import Coqpit +from torch.types import Storage from trainer.generic_utils import is_pytorch_at_least_2_4 from trainer.logger import logger @@ -60,7 +61,7 @@ def copy_model_files(config: Coqpit, out_path: Union[str, os.PathLike[Any]], new def load_fsspec( path: Union[str, os.PathLike[Any]], - map_location: Union[str, Callable, torch.device, dict[Union[str, torch.device], Union[str, torch.device]]] = None, + map_location: Optional[Union[str, Callable[[Storage, str], Storage], torch.device, dict[str, str]]] = None, cache: bool = True, **kwargs, ) -> Any: @@ -195,7 +196,7 @@ def save_checkpoint( def save_best_model( current_loss: Union[dict, float], - best_loss: Union[dict, float], + best_loss: Union[dict[str, Optional[float]], float], config: Union[dict, Coqpit], model: torch.nn.Module, optimizer: torch.optim.Optimizer, @@ -208,12 +209,13 @@ def save_best_model( save_func: Optional[Callable] = None, **kwargs, ) -> Union[dict, float]: - if isinstance(current_loss, dict): + if isinstance(current_loss, dict) and isinstance(best_loss, dict): use_eval_loss = current_loss["eval_loss"] is not None and best_loss["eval_loss"] is not None is_save_model = (use_eval_loss and current_loss["eval_loss"] < best_loss["eval_loss"]) or ( not use_eval_loss and current_loss["train_loss"] < best_loss["train_loss"] ) else: + assert isinstance(current_loss, float) and isinstance(best_loss, float) is_save_model = current_loss < best_loss is_save_model = is_save_model and current_step > keep_after @@ -249,7 +251,7 @@ def save_best_model( return best_loss -def get_last_checkpoint(path: Union[str, os.PathLike]) -> tuple[str, str]: +def get_last_checkpoint(path: Union[str, os.PathLike[Any]]) -> tuple[str, str]: """Get latest checkpoint or/and best model in path. It is based on globbing for `*.pth` and the RegEx @@ -274,7 +276,7 @@ def get_last_checkpoint(path: Union[str, os.PathLike]) -> tuple[str, str]: # back if it exists on the path file_names = [scheme + "://" + file_name for file_name in file_names] last_models = {} - last_model_nums = {} + last_model_nums: dict[str, int] = {} for key in ["checkpoint", "best_model"]: last_model_num = None last_model = None @@ -357,6 +359,4 @@ def sort_checkpoints( if regex_match is not None and regex_match.groups() is not None: ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path)) - checkpoints_sorted = sorted(ordering_and_checkpoint_path) - checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted] - return checkpoints_sorted + return [checkpoint[1] for checkpoint in sorted(ordering_and_checkpoint_path)] diff --git a/trainer/logging/__init__.py b/trainer/logging/__init__.py index 48592c2..e6888f5 100644 --- a/trainer/logging/__init__.py +++ b/trainer/logging/__init__.py @@ -1,31 +1,35 @@ import logging import os +from typing import Union +from trainer.config import TrainerConfig +from trainer.logging.base_dash_logger import BaseDashboardLogger from trainer.logging.console_logger import ConsoleLogger from trainer.logging.dummy_logger import DummyLogger -# pylint: disable=import-outside-toplevel +__all__ = ["ConsoleLogger", "DummyLogger"] logger = logging.getLogger("trainer") -def get_mlflow_tracking_url(): +def get_mlflow_tracking_url() -> Union[str, None]: if "MLFLOW_TRACKING_URI" in os.environ: return os.environ["MLFLOW_TRACKING_URI"] return None -def get_ai_repo_url(): +def get_ai_repo_url() -> Union[str, None]: if "AIM_TRACKING_URI" in os.environ: return os.environ["AIM_TRACKING_URI"] return None -def logger_factory(config, output_path): +def logger_factory(config: TrainerConfig, output_path: str) -> BaseDashboardLogger: run_name = config.run_name project_name = config.project_name log_uri = config.logger_uri if config.logger_uri else output_path + dashboard_logger: BaseDashboardLogger if config.dashboard_logger == "tensorboard": from trainer.logging.tensorboard_logger import TensorboardLogger diff --git a/trainer/logging/base_dash_logger.py b/trainer/logging/base_dash_logger.py index 5e20e45..d0d42c1 100644 --- a/trainer/logging/base_dash_logger.py +++ b/trainer/logging/base_dash_logger.py @@ -1,9 +1,15 @@ from abc import ABC, abstractmethod -from typing import Union +from typing import TYPE_CHECKING, Union +from trainer.config import TrainerConfig from trainer.io import save_fsspec from trainer.utils.distributed import rank_zero_only +if TYPE_CHECKING: + import matplotlib + import numpy as np + import plotly + # pylint: disable=too-many-public-methods class BaseDashboardLogger(ABC): @@ -21,7 +27,7 @@ def add_figure( pass @abstractmethod - def add_config(self, config): + def add_config(self, config: TrainerConfig) -> None: pass @abstractmethod @@ -37,53 +43,53 @@ def add_artifact(self, file_or_dir: str, name: str, artifact_type: str, aliases= pass @abstractmethod - def add_scalars(self, scope_name: str, scalars: dict, step: int): + def add_scalars(self, scope_name: str, scalars: dict, step: int) -> None: pass @abstractmethod - def add_figures(self, scope_name: str, figures: dict, step: int): + def add_figures(self, scope_name: str, figures: dict, step: int) -> None: pass @abstractmethod - def add_audios(self, scope_name: str, audios: dict, step: int, sample_rate: int): + def add_audios(self, scope_name: str, audios: dict, step: int, sample_rate: int) -> None: pass @abstractmethod - def flush(self): + def flush(self) -> None: pass @abstractmethod - def finish(self): + def finish(self) -> None: pass @staticmethod @rank_zero_only - def save_model(state: dict, path: str): + def save_model(state: dict, path: str) -> None: save_fsspec(state, path) - def train_step_stats(self, step, stats): + def train_step_stats(self, step: int, stats) -> None: self.add_scalars(scope_name="TrainIterStats", scalars=stats, step=step) - def train_epoch_stats(self, step, stats): + def train_epoch_stats(self, step: int, stats) -> None: self.add_scalars(scope_name="TrainEpochStats", scalars=stats, step=step) - def train_figures(self, step, figures): + def train_figures(self, step: int, figures) -> None: self.add_figures(scope_name="TrainFigures", figures=figures, step=step) - def train_audios(self, step, audios, sample_rate): + def train_audios(self, step: int, audios, sample_rate) -> None: self.add_audios(scope_name="TrainAudios", audios=audios, step=step, sample_rate=sample_rate) - def eval_stats(self, step, stats): + def eval_stats(self, step: int, stats) -> None: self.add_scalars(scope_name="EvalStats", scalars=stats, step=step) - def eval_figures(self, step, figures): + def eval_figures(self, step: int, figures) -> None: self.add_figures(scope_name="EvalFigures", figures=figures, step=step) - def eval_audios(self, step, audios, sample_rate): + def eval_audios(self, step: int, audios, sample_rate: int) -> None: self.add_audios(scope_name="EvalAudios", audios=audios, step=step, sample_rate=sample_rate) - def test_audios(self, step, audios, sample_rate): + def test_audios(self, step: int, audios, sample_rate: int) -> None: self.add_audios(scope_name="TestAudios", audios=audios, step=step, sample_rate=sample_rate) - def test_figures(self, step, figures): + def test_figures(self, step: int, figures) -> None: self.add_figures(scope_name="TestFigures", figures=figures, step=step) diff --git a/trainer/logging/console_logger.py b/trainer/logging/console_logger.py index a15fde8..167f6de 100644 --- a/trainer/logging/console_logger.py +++ b/trainer/logging/console_logger.py @@ -1,6 +1,7 @@ import datetime import logging from dataclasses import dataclass +from typing import Optional from trainer.utils.distributed import rank_zero_only @@ -20,15 +21,15 @@ class tcolors: class ConsoleLogger: - def __init__(self): + def __init__(self) -> None: # TODO: color code for value changes # use these to compare values between iterations self.old_train_loss_dict = None self.old_epoch_loss_dict = None - self.old_eval_loss_dict = None + self.old_eval_loss_dict: dict[str, float] = {} @staticmethod - def log_with_flush(msg: str): + def log_with_flush(msg: str) -> None: if logger is not None: logger.info(msg) for handler in logger.handlers: @@ -37,12 +38,12 @@ def log_with_flush(msg: str): print(msg, flush=True) @staticmethod - def get_time(): + def get_time() -> str: now = datetime.datetime.now() return now.strftime("%Y-%m-%d %H:%M:%S") @rank_zero_only - def print_epoch_start(self, epoch, max_epoch, output_path=None): + def print_epoch_start(self, epoch: int, max_epoch: int, output_path: Optional[str] = None) -> None: self.log_with_flush( "\n{}{} > EPOCH: {}/{}{}".format(tcolors.UNDERLINE, tcolors.BOLD, epoch, max_epoch, tcolors.ENDC), ) @@ -50,11 +51,13 @@ def print_epoch_start(self, epoch, max_epoch, output_path=None): self.log_with_flush(f" --> {output_path}") @rank_zero_only - def print_train_start(self): + def print_train_start(self) -> None: self.log_with_flush(f"\n{tcolors.BOLD} > TRAINING ({self.get_time()}) {tcolors.ENDC}") @rank_zero_only - def print_train_step(self, batch_steps, step, global_step, loss_dict, avg_loss_dict): + def print_train_step( + self, batch_steps: int, step: int, global_step: int, loss_dict: dict, avg_loss_dict: dict + ) -> None: indent = " | > " self.log_with_flush("") log_text = "{} --> TIME: {} -- STEP: {}/{} -- GLOBAL_STEP: {}{}\n".format( @@ -70,7 +73,7 @@ def print_train_step(self, batch_steps, step, global_step, loss_dict, avg_loss_d # pylint: disable=unused-argument @rank_zero_only - def print_train_epoch_end(self, global_step, epoch, epoch_time, print_dict): + def print_train_epoch_end(self, global_step: int, epoch: int, epoch_time, print_dict: dict) -> None: indent = " | > " log_text = f"\n{tcolors.BOLD} --> TRAIN PERFORMACE -- EPOCH TIME: {epoch_time:.2f} sec -- GLOBAL_STEP: {global_step}{tcolors.ENDC}\n" for key, value in print_dict.items(): @@ -78,11 +81,11 @@ def print_train_epoch_end(self, global_step, epoch, epoch_time, print_dict): self.log_with_flush(log_text) @rank_zero_only - def print_eval_start(self): + def print_eval_start(self) -> None: self.log_with_flush(f"\n{tcolors.BOLD} > EVALUATION {tcolors.ENDC}\n") @rank_zero_only - def print_eval_step(self, step, loss_dict, avg_loss_dict): + def print_eval_step(self, step: int, loss_dict: dict, avg_loss_dict: dict) -> None: indent = " | > " log_text = f"{tcolors.BOLD} --> STEP: {step}{tcolors.ENDC}\n" for key, value in loss_dict.items(): @@ -94,7 +97,7 @@ def print_eval_step(self, step, loss_dict, avg_loss_dict): self.log_with_flush(log_text) @rank_zero_only - def print_epoch_end(self, epoch, avg_loss_dict): + def print_epoch_end(self, epoch: int, avg_loss_dict: dict) -> None: indent = " | > " log_text = "\n {}--> EVAL PERFORMANCE{}\n".format(tcolors.BOLD, tcolors.ENDC) for key, value in avg_loss_dict.items(): @@ -102,7 +105,7 @@ def print_epoch_end(self, epoch, avg_loss_dict): color = "" sign = "+" diff = 0 - if self.old_eval_loss_dict is not None and key in self.old_eval_loss_dict: + if key in self.old_eval_loss_dict: diff = value - self.old_eval_loss_dict[key] if diff < 0: color = tcolors.OKGREEN diff --git a/trainer/logging/dummy_logger.py b/trainer/logging/dummy_logger.py index beea20f..4742d51 100644 --- a/trainer/logging/dummy_logger.py +++ b/trainer/logging/dummy_logger.py @@ -1,7 +1,12 @@ -from typing import Union +from typing import TYPE_CHECKING, Union from trainer.logging.base_dash_logger import BaseDashboardLogger +if TYPE_CHECKING: + import matplotlib + import numpy as np + import plotly + class DummyLogger(BaseDashboardLogger): """DummyLogger that implements the API but does nothing""" diff --git a/trainer/logging/tensorboard_logger.py b/trainer/logging/tensorboard_logger.py index cb18e60..0c2390d 100644 --- a/trainer/logging/tensorboard_logger.py +++ b/trainer/logging/tensorboard_logger.py @@ -1,16 +1,18 @@ import traceback +import torch from torch.utils.tensorboard import SummaryWriter +from trainer.config import TrainerConfig from trainer.logging.base_dash_logger import BaseDashboardLogger class TensorboardLogger(BaseDashboardLogger): - def __init__(self, log_dir, model_name): + def __init__(self, log_dir: str, model_name: str) -> None: self.model_name = model_name self.writer = SummaryWriter(log_dir) - def model_weights(self, model, step): + def model_weights(self, model: torch.nn.Module, step: int) -> None: layer_num = 1 for name, param in model.named_parameters(): if param.numel() == 1: @@ -24,33 +26,33 @@ def model_weights(self, model, step): self.writer.add_histogram("layer{}-{}/grad".format(layer_num, name), param.grad, step) layer_num += 1 - def add_config(self, config): + def add_config(self, config: TrainerConfig) -> None: self.add_text("model-config", f"
{config.to_json()}
", 0) def add_scalar(self, title: str, value: float, step: int) -> None: self.writer.add_scalar(title, value, step) - def add_audio(self, title, audio, step, sample_rate): + def add_audio(self, title: str, audio, step: int, sample_rate: int) -> None: self.writer.add_audio(title, audio, step, sample_rate=sample_rate) - def add_text(self, title, text, step): + def add_text(self, title: str, text: str, step: int) -> None: self.writer.add_text(title, text, step) - def add_figure(self, title, figure, step): + def add_figure(self, title: str, figure, step: int) -> None: self.writer.add_figure(title, figure, step) - def add_artifact(self, file_or_dir, name, artifact_type, aliases=None): # pylint: disable=W0613 - yield + def add_artifact(self, file_or_dir: str, name: str, artifact_type, aliases=None) -> None: + pass - def add_scalars(self, scope_name, scalars, step): + def add_scalars(self, scope_name: str, scalars, step: int) -> None: for key, value in scalars.items(): self.add_scalar("{}/{}".format(scope_name, key), value, step) - def add_figures(self, scope_name, figures, step): + def add_figures(self, scope_name: str, figures, step: int) -> None: for key, value in figures.items(): self.writer.add_figure("{}/{}".format(scope_name, key), value, step) - def add_audios(self, scope_name, audios, step, sample_rate): + def add_audios(self, scope_name: str, audios, step: int, sample_rate: int) -> None: for key, value in audios.items(): if value.dtype == "float16": value = value.astype("float32") @@ -64,8 +66,8 @@ def add_audios(self, scope_name, audios, step, sample_rate): except RuntimeError: traceback.print_exc() - def flush(self): + def flush(self) -> None: self.writer.flush() - def finish(self): + def finish(self) -> None: self.writer.close() diff --git a/trainer/logging/wandb_logger.py b/trainer/logging/wandb_logger.py index 903ee89..96e50bf 100644 --- a/trainer/logging/wandb_logger.py +++ b/trainer/logging/wandb_logger.py @@ -3,7 +3,7 @@ import traceback from collections import defaultdict from pathlib import Path -from typing import Union +from typing import TYPE_CHECKING, Union from trainer.logging.base_dash_logger import BaseDashboardLogger from trainer.trainer_utils import is_wandb_available @@ -12,6 +12,11 @@ if is_wandb_available(): import wandb # pylint: disable=import-error +if TYPE_CHECKING: + import matplotlib + import numpy as np + import plotly + class WandbLogger(BaseDashboardLogger): def __init__(self, **kwargs): diff --git a/trainer/model.py b/trainer/model.py index 9dfd642..14ee778 100644 --- a/trainer/model.py +++ b/trainer/model.py @@ -1,15 +1,17 @@ from abc import ABC, abstractmethod -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, Optional import torch from torch import nn -from trainer.trainer import Trainer from trainer.trainer_utils import is_apex_available if is_apex_available(): from apex import amp +if TYPE_CHECKING: + from trainer.trainer import Trainer + # pylint: skip-file @@ -108,7 +110,7 @@ def get_data_loader(*args: Any, **kwargs: Any) -> torch.utils.data.DataLoader: """Get data loader for the model. Args: - config (Coqpit): Configuration object. + config (TrainerConfig): Configuration object. assets (Dict): Additional assets to be used for data loading. is_eval (bool): If True, returns evaluation data loader. samples (Union[List[Dict], List[List]]): List of samples to be used for data loading. @@ -145,7 +147,12 @@ def optimize(self, *args: Any, **kwargs: Any) -> tuple[dict, dict, float]: raise NotImplementedError(" [!] `optimize()` is not implemented.") def scaled_backward( - self, loss: torch.Tensor, trainer: Trainer, optimizer: torch.optim.Optimizer, *args: Any, **kwargs: Any + self, + loss: torch.Tensor, + trainer: "Trainer", + optimizer: torch.optim.Optimizer, + *args: Any, + **kwargs: Any, ) -> tuple[float, bool]: """Backward pass with gradient scaling for custom `optimize` calls. diff --git a/trainer/torch.py b/trainer/torch.py index 17f3489..e8f17e2 100644 --- a/trainer/torch.py +++ b/trainer/torch.py @@ -1,3 +1,4 @@ +from collections.abc import Iterator from typing import Optional import numpy as np @@ -35,7 +36,7 @@ def __init__( rank: Optional[int] = None, shuffle: bool = True, seed: int = 0, - ): + ) -> None: super().__init__( sampler, num_replicas=num_replicas, @@ -44,7 +45,7 @@ def __init__( seed=seed, ) - def __iter__(self): + def __iter__(self) -> Iterator: indices = list(self.dataset)[: self.total_size] # Add extra samples to make it evenly divisible @@ -58,27 +59,27 @@ def __iter__(self): return iter(indices) - def set_epoch(self, epoch): + def set_epoch(self, epoch: int) -> None: super().set_epoch(epoch) if hasattr(self.dataset, "set_epoch"): self.dataset.set_epoch(epoch) elif hasattr(self.dataset, "generator"): self.dataset.generator = torch.Generator().manual_seed(self.seed + epoch) - def state_dict(self): + def state_dict(self) -> dict: return self.dataset.state_dict() - def load_state_dict(self, state_dict): + def load_state_dict(self, state_dict: dict) -> None: self.dataset.load_state_dict(state_dict) # pylint: disable=protected-access class NoamLR(torch.optim.lr_scheduler._LRScheduler): - def __init__(self, optimizer, warmup_steps=0.1, last_epoch=-1): + def __init__(self, optimizer: torch.optim.Optimizer, warmup_steps: float = 0.1, last_epoch: int = -1): self.warmup_steps = float(warmup_steps) super().__init__(optimizer, last_epoch) - def get_lr(self): + def get_lr(self) -> list[float]: step = max(self.last_epoch, 1) return [ base_lr * self.warmup_steps**0.5 * min(step * self.warmup_steps**-1.5, step**-0.5) @@ -91,7 +92,7 @@ class StepwiseGradualLR(torch.optim.lr_scheduler._LRScheduler): """Hardcoded step-wise learning rate scheduling. Necessary for CapacitronVAE""" - def __init__(self, optimizer, gradual_learning_rates, last_epoch=-1): + def __init__(self, optimizer: torch.optim.Optimizer, gradual_learning_rates, last_epoch: int = -1) -> None: self.gradual_learning_rates = gradual_learning_rates super().__init__(optimizer, last_epoch) diff --git a/trainer/trainer.py b/trainer/trainer.py index 04affe6..ff2f068 100644 --- a/trainer/trainer.py +++ b/trainer/trainer.py @@ -1,6 +1,6 @@ import functools import gc -import importlib +import importlib.util import logging import os import platform @@ -8,13 +8,13 @@ import sys import time import traceback +from collections.abc import Generator, Iterable from contextlib import nullcontext from inspect import signature from typing import Any, Callable, Optional, Union import torch import torch.distributed as dist -from coqpit import Coqpit from torch import nn from torch.nn.parallel import DistributedDataParallel as DDP_th from torch.utils.data import DataLoader @@ -42,6 +42,7 @@ ) from trainer.logging import ConsoleLogger, DummyLogger, logger_factory from trainer.logging.base_dash_logger import BaseDashboardLogger +from trainer.model import TrainerModel from trainer.trainer_utils import ( get_optimizer, get_scheduler, @@ -77,14 +78,14 @@ def __init__( # pylint: disable=dangerous-default-value *, c_logger: Optional[ConsoleLogger] = None, dashboard_logger: Optional[BaseDashboardLogger] = None, - model: Optional[nn.Module] = None, + model: Optional[TrainerModel] = None, get_model: Optional[Callable] = None, get_data_samples: Optional[Callable] = None, train_samples: Optional[list] = None, eval_samples: Optional[list] = None, test_samples: Optional[list] = None, - train_loader: DataLoader = None, - eval_loader: DataLoader = None, + train_loader: Optional[DataLoader] = None, + eval_loader: Optional[DataLoader] = None, training_assets: Optional[dict] = None, parse_command_line_args: bool = True, callbacks: Optional[dict[str, Callable]] = None, @@ -101,10 +102,10 @@ def __init__( # pylint: disable=dangerous-default-value Args: - args (Union[Coqpit, Namespace]): Training arguments parsed either from console by `argparse` or `TrainerArgs` + args (TrainerArgs): Training arguments parsed either from console by `argparse` or `TrainerArgs` config object. - config (Coqpit): Model config object. It includes all the values necessary for initializing, training, evaluating + config (TrainerConfig): Model config object. It includes all the values necessary for initializing, training, evaluating and testing the model. output_path (str or Path, optional): Path to the output training folder. All @@ -116,7 +117,7 @@ def __init__( # pylint: disable=dangerous-default-value dashboard_logger Union[TensorboardLogger, WandbLogger]: Dashboard logger. If not provided, the tensorboard logger is used. Defaults to None. - model (nn.Module, optional): Initialized and ready-to-train model. If it is not defined, `Trainer` + model (TrainerModel, optional): Initialized and ready-to-train model. If it is not defined, `Trainer` initializes a model from the provided config. Defaults to None. get_model (Callable): @@ -189,7 +190,7 @@ def __init__( # pylint: disable=dangerous-default-value # get ready for training and parse command-line arguments to override the model config config, new_fields = self.init_training(args, coqpit_overrides, config) elif args.continue_path or args.restore_path: - config, new_fields = self.init_training(args, {}, config) + config, new_fields = self.init_training(args, [], config) else: new_fields = {} @@ -241,13 +242,16 @@ def __init__( # pylint: disable=dangerous-default-value self.epochs_done = 0 self.restore_step = 0 self.restore_epoch = 0 - self.best_loss = {"train_loss": float("inf"), "eval_loss": float("inf") if self.config.run_eval else None} - self.train_loader = None - self.test_loader = None - self.eval_loader = None + self.best_loss: Union[float, dict[str, Optional[float]]] = { + "train_loss": float("inf"), + "eval_loss": float("inf") if self.config.run_eval else None, + } + self.train_loader: Optional[DataLoader] = None + self.test_loader: Optional[DataLoader] = None + self.eval_loader: Optional[DataLoader] = None - self.keep_avg_train = None - self.keep_avg_eval = None + self.keep_avg_train: Optional[KeepAverage] = None + self.keep_avg_eval: Optional[KeepAverage] = None self.use_amp_scaler = ( self.use_cuda @@ -281,12 +285,12 @@ def __init__( # pylint: disable=dangerous-default-value self.setup_small_run(args.small_run) # init the model - if model is None and get_model is None: - raise ValueError("[!] `model` and `get_model` cannot both be None.") if model is not None: self.model = model - else: + elif get_model is not None: self.run_get_model(self.config, get_model) + else: + raise ValueError("[!] `model` and `get_model` cannot both be None.") # init model's training assets if isimplemented(self.model, "init_for_training"): @@ -310,9 +314,9 @@ def __init__( # pylint: disable=dangerous-default-value self.model.cuda() if isinstance(self.criterion, list): for criterion in self.criterion: - if isinstance(criterion, torch.nn.Module): + if isinstance(criterion, nn.Module): criterion.cuda() - elif isinstance(self.criterion, torch.nn.Module): + elif isinstance(self.criterion, nn.Module): self.criterion.cuda() # setup optimizer @@ -396,14 +400,22 @@ def setup_accelerate(self) -> None: precision=self.config.precision, ) - def prepare_accelerate_loader(self, data_loader): + def prepare_accelerate_loader(self, data_loader: DataLoader) -> DataLoader: """Prepare the accelerator for the training.""" if self.use_accelerate: return self.accelerator.prepare_data_loader(data_loader) return data_loader @staticmethod - def init_accelerate(model, optimizer, training_dataloader, scheduler, grad_accum_steps, mixed_precision, precision): + def init_accelerate( + model: TrainerModel, + optimizer: torch.optim.Optimizer, + training_dataloader: DataLoader, + scheduler, + grad_accum_steps, + mixed_precision: bool, + precision, + ) -> tuple: """Setup HF Accelerate for the training.""" # check if accelerate is installed @@ -420,7 +432,7 @@ def init_accelerate(model, optimizer, training_dataloader, scheduler, grad_accum elif _precision == "bfloat16": _precision = "bf16" accelerator = Accelerator(gradient_accumulation_steps=grad_accum_steps, mixed_precision=_precision) - if isinstance(model, torch.nn.Module): + if isinstance(model, nn.Module): model = accelerator.prepare_model(model) if isinstance(optimizer, dict): @@ -457,7 +469,12 @@ def save_training_script(self) -> None: shutil.copyfile(file_path, os.path.join(self.output_path, file_name)) @staticmethod - def init_loggers(config: "Coqpit", output_path: str, dashboard_logger=None, c_logger=None): + def init_loggers( + config: TrainerConfig, + output_path: str, + dashboard_logger: Optional[BaseDashboardLogger] = None, + c_logger: Optional[ConsoleLogger] = None, + ) -> tuple[BaseDashboardLogger, ConsoleLogger]: """Init console and dashboard loggers. Use the given logger if passed externally else use config values to pick the right logger. @@ -465,7 +482,7 @@ def init_loggers(config: "Coqpit", output_path: str, dashboard_logger=None, c_lo Define a console logger for each process in DDP Args: - config (Coqpit): Model config. + config (TrainerConfig): Model config. output_path (str): Output path to save the training artifacts. dashboard_logger (DashboardLogger): Object passed to the trainer from outside. c_logger (ConsoleLogger): Object passed to the trained from outside. @@ -492,30 +509,34 @@ def setup_small_run(self, small_run: Optional[int] = None) -> None: @staticmethod def init_training( - args: TrainerArgs, coqpit_overrides: dict, config: Coqpit = None - ) -> tuple[Coqpit, dict[str, str]]: + args: TrainerArgs, coqpit_overrides: list[str], config: Optional[TrainerConfig] = None + ) -> tuple[TrainerConfig, dict[str, str]]: """Initialize training and update model configs from command line arguments. Args: - args (argparse.Namespace or dict like): Parsed trainer arguments. - config_overrides (argparse.Namespace or dict like): Parsed config overriding arguments. - config (Coqpit): Model config. If none, it is generated from `args`. Defaults to None. + args: Parsed trainer arguments. + config_overrides: Parsed config overriding arguments. + config: Model config. If none, it is generated from `args`. Defaults to None. Returns: - config (Coqpit): Config paramaters. + config (TrainerConfig): Config paramaters. """ # set arguments for continuing training if args.continue_path: - args.config_path = os.path.join(args.continue_path, "config.json") + config_path = os.path.join(args.continue_path, "config.json") args.restore_path, best_model = get_last_checkpoint(args.continue_path) if not args.best_path: args.best_path = best_model # use the same config if config: - config.load_json(args.config_path) + config.load_json(config_path) else: - coqpit = Coqpit() - coqpit.load_json(args.config_path) + config = TrainerConfig() + config.load_json(config_path) + + if config is None: + msg = "Config or continue_path containing Config not provided" + raise ValueError(msg) # override config values from command-line args # TODO: Maybe it is better to do it outside @@ -531,7 +552,7 @@ def init_training( return config, new_fields @staticmethod - def setup_training_environment(args, config, gpu) -> tuple[bool, int]: + def setup_training_environment(args: TrainerArgs, config: TrainerConfig, gpu: Optional[int]) -> tuple[bool, int]: if platform.system() != "Windows": # https://github.com/pytorch/pytorch/issues/973 import resource # pylint: disable=import-outside-toplevel @@ -555,25 +576,27 @@ def setup_training_environment(args, config, gpu) -> tuple[bool, int]: return use_cuda, num_gpus @staticmethod - def run_get_model(config: Coqpit, get_model: Callable) -> nn.Module: + def run_get_model( + config: TrainerConfig, get_model: Union[Callable[[TrainerConfig], TrainerModel], Callable[[], TrainerModel]] + ) -> TrainerModel: """Run the `get_model` function and return the model. Args: - config (Coqpit): Model config. + config (TrainerConfig): Model config. Returns: - nn.Module: initialized model. + TrainerModel: initialized model. """ - if len(signature(get_model).sig.parameters) == 1: + if len(signature(get_model).parameters) == 1: model = get_model(config) else: model = get_model() return model @staticmethod - def run_get_data_samples(config: Coqpit, get_data_samples: Callable) -> nn.Module: + def run_get_data_samples(config: TrainerConfig, get_data_samples: Callable) -> tuple[Iterable, Iterable]: if callable(get_data_samples): - if len(signature(get_data_samples).sig.parameters) == 1: + if len(signature(get_data_samples).parameters) == 1: train_samples, eval_samples = get_data_samples(config) else: train_samples, eval_samples = get_data_samples() @@ -582,23 +605,23 @@ def run_get_data_samples(config: Coqpit, get_data_samples: Callable) -> nn.Modul def restore_model( self, - config: Coqpit, + config: TrainerConfig, restore_path: Union[str, os.PathLike[Any]], - model: nn.Module, + model: TrainerModel, optimizer: torch.optim.Optimizer, scaler: Optional["torch.GradScaler"] = None, - ) -> tuple[nn.Module, torch.optim.Optimizer, "torch.GradScaler", int]: + ) -> tuple[TrainerModel, torch.optim.Optimizer, "torch.GradScaler", int, int]: """Restore training from an old run. It restores model, optimizer, AMP scaler and training stats. Args: - config (Coqpit): Model config. + config (TrainerConfig): Model config. restore_path (str): Path to the restored training run. - model (nn.Module): Model to restored. + model (TrainerModel): Model to restored. optimizer (torch.optim.Optimizer): Optimizer to restore. scaler (torch.GradScaler, optional): AMP scaler to restore. Defaults to None. Returns: - Tuple[nn.Module, torch.optim.Optimizer, torch.GradScaler, int]: [description] + Tuple[TrainerModel, torch.optim.Optimizer, torch.GradScaler, int, int]: [description] """ def _restore_list_objs(states, obj): @@ -641,7 +664,9 @@ def _restore_list_objs(states, obj): torch.cuda.empty_cache() return model, optimizer, scaler, restore_step, restore_epoch - def restore_lr(self, config, args, model, optimizer): + def restore_lr( + self, config: TrainerConfig, args: TrainerArgs, model: TrainerModel, optimizer: torch.optim.Optimizer + ) -> torch.optim.Optimizer: # use the same lr if continue training if not args.continue_path: if isinstance(optimizer, list): @@ -663,8 +688,8 @@ def restore_lr(self, config, args, model, optimizer): def _get_loader( self, - model: nn.Module, - config: Coqpit, + model: TrainerModel, + config: TrainerConfig, assets: dict, is_eval: bool, samples: list, @@ -803,7 +828,7 @@ def get_test_dataloader(self, training_assets: dict, samples: list, verbose: boo self.num_gpus, ) - def format_batch(self, batch: list) -> dict: + def format_batch(self, batch: Union[dict[str, Any], list]) -> dict: """Format the dataloader output and return a batch. 1. Call ```model.format_batch```. @@ -844,7 +869,7 @@ def format_batch(self, batch: list) -> dict: ###################### @staticmethod - def master_params(optimizer: torch.optim.Optimizer): + def master_params(optimizer: torch.optim.Optimizer) -> Generator: """Generator over parameters owned by the optimizer. Used to select parameters used by the optimizer for gradient clipping. @@ -857,13 +882,13 @@ def master_params(optimizer: torch.optim.Optimizer): @staticmethod def _model_train_step( - batch: dict, model: nn.Module, criterion: nn.Module, optimizer_idx: Optional[int] = None + batch: dict, model: TrainerModel, criterion: nn.Module, optimizer_idx: Optional[int] = None ) -> tuple[dict, dict]: """Perform a trainig forward step. Compute model outputs and losses. Args: batch (Dict): [description] - model (nn.Module): [description] + model (TrainerModel): [description] criterion (nn.Module): [description] optimizer_idx (int, optional): [description]. Defaults to None. @@ -878,7 +903,7 @@ def _model_train_step( return model.module.train_step(*input_args) return model.train_step(*input_args) - def _get_autocast_args(self, mixed_precision: bool, precision: str): + def _get_autocast_args(self, mixed_precision: bool, precision: str) -> tuple[str, torch.dtype]: device = "cpu" if is_pytorch_at_least_2_4(): dtype = torch.get_autocast_dtype("cpu") @@ -918,7 +943,14 @@ def detach_loss_dict( loss_dict_detached["grad_norm"] = grad_norm return loss_dict_detached - def _compute_loss(self, batch: dict, model: nn.Module, criterion: nn.Module, config: Coqpit, optimizer_idx: int): + def _compute_loss( + self, + batch: dict, + model: TrainerModel, + criterion: nn.Module, + config: TrainerConfig, + optimizer_idx: Optional[int], + ) -> tuple[dict, dict]: device, dtype = self._get_autocast_args(config.mixed_precision, config.precision) with torch.autocast(device_type=device, dtype=dtype, enabled=config.mixed_precision): if optimizer_idx is not None: @@ -928,7 +960,7 @@ def _compute_loss(self, batch: dict, model: nn.Module, criterion: nn.Module, con return outputs, loss_dict @staticmethod - def _set_grad_clip_per_optimizer(config: Coqpit, optimizer_idx: int): + def _set_grad_clip_per_optimizer(config: TrainerConfig, optimizer_idx: Optional[int]) -> float: # set gradient clipping threshold grad_clip = 0.0 # meaning no gradient clipping if "grad_clip" in config and config.grad_clip is not None: @@ -958,26 +990,26 @@ def _grad_clipping(self, grad_clip: float, optimizer: torch.optim.Optimizer, sca def optimize( self, batch: dict, - model: nn.Module, + model: TrainerModel, optimizer: torch.optim.Optimizer, scaler: "torch.GradScaler", criterion: nn.Module, scheduler: Union[torch.optim.lr_scheduler._LRScheduler, list, dict], # pylint: disable=protected-access - config: Coqpit, + config: TrainerConfig, optimizer_idx: Optional[int] = None, step_optimizer: bool = True, num_optimizers: int = 1, - ) -> tuple[dict, dict, int]: + ) -> tuple[dict, dict, float]: """Perform a forward - backward pass and run the optimizer. Args: batch (Dict): Input batch. If - model (nn.Module): Model for training. Defaults to None. + model (TrainerModel): Model for training. Defaults to None. optimizer (Union[nn.optim.Optimizer, List]): Model's optimizer. If it is a list then, `optimizer_idx` must be defined to indicate the optimizer in use. scaler (AMPScaler): AMP scaler. criterion (nn.Module): Model's criterion. scheduler (torch.optim.lr_scheduler._LRScheduler): LR scheduler used by the optimizer. - config (Coqpit): Model config. + config (TrainerConfig): Model config. optimizer_idx (int, optional): Target optimizer being used. Defaults to None. step_optimizer (bool, optional): Whether step the optimizer. If False, gradients are accumulated and model parameters are not updated. Defaults to True. @@ -1141,7 +1173,7 @@ def train_step(self, batch: dict, batch_n_steps: int, step: int, loader_start_ti else: # auto training with multiple optimizers (e.g. GAN) outputs_per_optimizer = [None] * len(self.optimizer) - total_step_time = 0 + total_step_time = 0.0 for idx, optimizer in enumerate(self.optimizer): criterion = self.criterion # scaler = self.scaler[idx] if self.use_amp_scaler else None @@ -1315,13 +1347,13 @@ def train_epoch(self) -> None: ####################### def _model_eval_step( - self, batch: dict, model: nn.Module, criterion: nn.Module, optimizer_idx: Optional[int] = None + self, batch: dict, model: TrainerModel, criterion: nn.Module, optimizer_idx: Optional[int] = None ) -> tuple[dict, dict]: """Perform a evaluation forward pass. Compute model outputs and losses with no gradients. Args: batch (Dict): IBatch of inputs. - model (nn.Module): Model to call evaluation. + model (TrainerModel): Model to call evaluation. criterion (nn.Module): Model criterion. optimizer_idx (int, optional): Optimizer ID to define the closure in multi-optimizer training. Defaults to None. @@ -1342,7 +1374,7 @@ def _model_eval_step( return model.eval_step(*input_args) - def eval_step(self, batch: dict, step: int) -> tuple[dict, dict]: + def eval_step(self, batch: dict, step: int) -> tuple[Optional[dict], Optional[dict]]: """Perform a evaluation step on a batch of inputs and log the process. Args: @@ -1354,7 +1386,7 @@ def eval_step(self, batch: dict, step: int) -> tuple[dict, dict]: """ with torch.no_grad(): outputs = [] - loss_dict = {} + loss_dict: dict[str, Any] = {} if not isinstance(self.optimizer, list) or isimplemented(self.model, "optimize"): outputs, loss_dict = self._model_eval_step(batch, self.model, self.criterion) if outputs is None: @@ -1505,14 +1537,14 @@ def _restore_best_loss(self) -> None: self.best_loss = {"train_loss": ch["model_loss"], "eval_loss": None} logger.info(" > Starting with loaded last best loss %s", self.best_loss) - def test(self, model=None, test_samples=None) -> None: + def test(self, model: Optional[TrainerModel] = None, test_samples: Optional[list[str]] = None) -> None: """Run evaluation steps on the test data split. You can either provide the model and the test samples explicitly or the trainer uses values from the initialization. Args: - model (nn.Module, optional): Model to use for testing. If None, use the model given in the initialization. + model (TrainerModel, optional): Model to use for testing. If None, use the model given in the initialization. Defaults to None. test_samples (List[str], optional): List of test samples to use for testing. If None, use the test samples @@ -1751,13 +1783,13 @@ def update_training_dashboard_logger(self, batch=None, outputs=None) -> None: ##################### @staticmethod - def get_optimizer(model: nn.Module, config: Coqpit) -> Union[torch.optim.Optimizer, list]: + def get_optimizer(model: TrainerModel, config: TrainerConfig) -> Union[torch.optim.Optimizer, list]: """Receive the optimizer from the model if model implements `get_optimizer()` else check the optimizer parameters in the config and try initiating the optimizer. Args: - model (nn.Module): Training model. - config (Coqpit): Training configuration. + model (TrainerModel): Training model. + config (TrainerConfig): Training configuration. Returns: Union[torch.optim.Optimizer, List]: A optimizer or a list of optimizers. GAN models define a list. @@ -1775,13 +1807,13 @@ def get_optimizer(model: nn.Module, config: Coqpit) -> Union[torch.optim.Optimiz return optimizer @staticmethod - def get_lr(model: nn.Module, config: Coqpit) -> Union[float, list[float]]: + def get_lr(model: TrainerModel, config: TrainerConfig) -> Union[float, list[float]]: """Set the initial learning rate by the model if model implements `get_lr()` else try setting the learning rate fromthe config. Args: - model (nn.Module): Training model. - config (Coqpit): Training configuration. + model (TrainerModel): Training model. + config (TrainerConfig): Training configuration. Returns: Union[float, List[float]]: A single learning rate or a list of learning rates, one for each optimzier. @@ -1798,14 +1830,14 @@ def get_lr(model: nn.Module, config: Coqpit) -> Union[float, list[float]]: @staticmethod def get_scheduler( - model: nn.Module, config: Coqpit, optimizer: Union[torch.optim.Optimizer, list, dict] + model: TrainerModel, config: TrainerConfig, optimizer: Union[torch.optim.Optimizer, list, dict] ) -> Union[torch.optim.lr_scheduler._LRScheduler, list]: # pylint: disable=protected-access """Receive the scheduler from the model if model implements `get_scheduler()` else check the config and try initiating the scheduler. Args: - model (nn.Module): Training model. - config (Coqpit): Training configuration. + model (TrainerModel): Training model. + config (TrainerConfig): Training configuration. Returns: Union[torch.optim.Optimizer, List, Dict]: A scheduler or a list of schedulers, one for each optimizer. @@ -1829,8 +1861,8 @@ def get_scheduler( @staticmethod def restore_scheduler( scheduler: Union[torch.optim.lr_scheduler._LRScheduler, list, dict], - args: Coqpit, - config: Coqpit, + args: TrainerArgs, + config: TrainerConfig, restore_epoch: int, restore_step: int, ) -> Union[torch.optim.lr_scheduler._LRScheduler, list]: @@ -1857,11 +1889,11 @@ def restore_scheduler( return scheduler @staticmethod - def get_criterion(model: nn.Module) -> nn.Module: + def get_criterion(model: TrainerModel) -> nn.Module: """Receive the criterion from the model. Model must implement `get_criterion()`. Args: - model (nn.Module): Training model. + model (TrainerModel): Training model. Returns: nn.Module: Criterion layer. @@ -1890,7 +1922,7 @@ def _detach_loss_dict(loss_dict: dict) -> dict: loss_dict_detached[key] = value.detach().cpu().item() return loss_dict_detached - def _pick_target_avg_loss(self, keep_avg_target: KeepAverage) -> dict: + def _pick_target_avg_loss(self, keep_avg_target: Optional[KeepAverage]) -> Optional[dict]: """Pick the target loss to compare models""" # if the keep_avg_target is None or empty return None diff --git a/trainer/trainer_utils.py b/trainer/trainer_utils.py index 2a11c19..276714c 100644 --- a/trainer/trainer_utils.py +++ b/trainer/trainer_utils.py @@ -1,38 +1,41 @@ import importlib +import importlib.util import os import random +from collections.abc import Iterator from typing import Optional import numpy as np import torch +from torch.nn import Parameter -from trainer.config import TrainerArgs +from trainer.config import TrainerArgs, TrainerConfig from trainer.logger import logger from trainer.torch import NoamLR, StepwiseGradualLR from trainer.utils.distributed import rank_zero_logger_info -def is_apex_available(): +def is_apex_available() -> bool: return importlib.util.find_spec("apex") is not None -def is_mlflow_available(): +def is_mlflow_available() -> bool: return importlib.util.find_spec("mlflow") is not None -def is_aim_available(): +def is_aim_available() -> bool: return importlib.util.find_spec("aim") is not None -def is_wandb_available(): +def is_wandb_available() -> bool: return importlib.util.find_spec("wandb") is not None -def is_clearml_available(): +def is_clearml_available() -> bool: return importlib.util.find_spec("clearml") is not None -def print_training_env(args, config): +def print_training_env(args: TrainerArgs, config: TrainerConfig) -> None: """Print training environment.""" rank_zero_logger_info(" > Training Environment:", logger) @@ -67,7 +70,7 @@ def setup_torch_training_env( cudnn_benchmark: bool, cudnn_deterministic: bool, use_ddp: bool = False, - training_seed=54321, + training_seed: int = 54321, allow_tf32: bool = False, gpu=None, ) -> tuple[bool, int]: @@ -134,6 +137,7 @@ def get_scheduler( """ if lr_scheduler is None: return None + scheduler: type[torch.optim.lr_scheduler._LRScheduler] if lr_scheduler.lower() == "noamlr": scheduler = NoamLR elif lr_scheduler.lower() == "stepwisegraduallr": @@ -147,8 +151,8 @@ def get_optimizer( optimizer_name: str, optimizer_params: dict, lr: float, - model: torch.nn.Module = None, - parameters: Optional[list] = None, + model: Optional[torch.nn.Module] = None, + parameters: Optional[Iterator[Parameter]] = None, ) -> torch.optim.Optimizer: """Find, initialize and return a Torch optimizer. diff --git a/trainer/utils/cpu_memory.py b/trainer/utils/cpu_memory.py index 9d019ed..0f940b7 100644 --- a/trainer/utils/cpu_memory.py +++ b/trainer/utils/cpu_memory.py @@ -32,7 +32,7 @@ def set_cpu_memory_limit(num_gigabytes): pass -def is_out_of_cpu_memory(exception): +def is_out_of_cpu_memory(exception: Exception) -> bool: return ( isinstance(exception, RuntimeError) and len(exception.args) == 1 diff --git a/trainer/utils/cuda_memory.py b/trainer/utils/cuda_memory.py index 5e9c310..11a74c8 100644 --- a/trainer/utils/cuda_memory.py +++ b/trainer/utils/cuda_memory.py @@ -12,33 +12,33 @@ from trainer.utils.cpu_memory import is_out_of_cpu_memory -def gc_cuda(): +def gc_cuda() -> None: """Gargage collect Torch (CUDA) memory.""" gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() -def get_cuda_total_memory(): +def get_cuda_total_memory() -> int: if torch.cuda.is_available(): return torch.cuda.get_device_properties(0).total_memory return 0 -def get_cuda_assumed_available_memory(): +def get_cuda_assumed_available_memory() -> int: if torch.cuda.is_available(): return get_cuda_total_memory() - torch.cuda.memory_reserved() return 0 -def get_cuda_available_memory(): +def get_cuda_available_memory() -> int: # Always allow for 1 GB overhead. if torch.cuda.is_available(): return get_cuda_assumed_available_memory() - get_cuda_blocked_memory() return 0 -def get_cuda_blocked_memory(): +def get_cuda_blocked_memory() -> int: if not torch.cuda.is_available(): return 0 @@ -60,7 +60,7 @@ def get_cuda_blocked_memory(): return available_memory - current_block -def is_cuda_out_of_memory(exception): +def is_cuda_out_of_memory(exception: Exception) -> bool: return ( isinstance(exception, (RuntimeError, torch.cuda.OutOfMemoryError)) and len(exception.args) == 1 @@ -68,7 +68,7 @@ def is_cuda_out_of_memory(exception): ) -def is_cudnn_snafu(exception): +def is_cudnn_snafu(exception: Exception) -> bool: # For/because of https://github.com/pytorch/pytorch/issues/4107 return ( isinstance(exception, RuntimeError) @@ -77,7 +77,7 @@ def is_cudnn_snafu(exception): ) -def cuda_meminfo(): +def cuda_meminfo() -> None: if not torch.cuda.is_available(): return @@ -91,5 +91,5 @@ def cuda_meminfo(): ) -def should_reduce_batch_size(exception): +def should_reduce_batch_size(exception: Exception) -> bool: return is_cuda_out_of_memory(exception) or is_cudnn_snafu(exception) or is_out_of_cpu_memory(exception) diff --git a/trainer/utils/distributed.py b/trainer/utils/distributed.py index bec31cb..3a06c67 100644 --- a/trainer/utils/distributed.py +++ b/trainer/utils/distributed.py @@ -1,4 +1,5 @@ # edited from https://github.com/fastai/imagenet-fast/blob/master/imagenet_nv/distributed.py +import logging import os from functools import wraps from typing import Any, Callable, Optional @@ -7,7 +8,7 @@ import torch.distributed as dist -def is_dist_avail_and_initialized(): +def is_dist_avail_and_initialized() -> bool: if not dist.is_available(): return False if not dist.is_initialized(): @@ -15,7 +16,7 @@ def is_dist_avail_and_initialized(): return True -def get_rank(): +def get_rank() -> int: rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK") for key in rank_keys: rank = os.environ.get(key) @@ -24,7 +25,7 @@ def get_rank(): return 0 -def is_main_process(): +def is_main_process() -> bool: return get_rank() == 0 @@ -44,18 +45,18 @@ def rank_zero_print(message: str, *args, **kwargs) -> None: # pylint: disable=u @rank_zero_only -def rank_zero_logger_info(message: str, logger: "Logger", *args, **kwargs) -> None: # pylint: disable=unused-argument +def rank_zero_logger_info(message: str, logger: logging.Logger, *args, **kwargs) -> None: # pylint: disable=unused-argument logger.info(message) -def reduce_tensor(tensor, num_gpus): +def reduce_tensor(tensor: torch.Tensor, num_gpus: int) -> torch.Tensor: rt = tensor.clone() dist.all_reduce(rt, op=dist.reduce_op.SUM) rt /= num_gpus return rt -def init_distributed(rank, num_gpus, group_name, dist_backend, dist_url): +def init_distributed(rank: int, num_gpus: int, group_name: str, dist_backend, dist_url) -> None: assert torch.cuda.is_available(), "Distributed mode requires CUDA." # Set cuda device so everything is done on the right GPU.