
Merge pull request #10 from idiap/coqpit

Switch to forked coqpit

eginhard authored Nov 28, 2024
2 parents 8e4d2a1 + 5a2f065 · commit eb7aa2a
Showing 21 changed files with 314 additions and 214 deletions.
29 changes: 23 additions & 6 deletions pyproject.toml
@@ -7,7 +7,7 @@ include = ["trainer*"]
 
 [project]
 name = "coqui-tts-trainer"
-version = "0.1.7"
+version = "0.2.0"
 description = "General purpose model trainer for PyTorch that is more flexible than it should be, by 🐸Coqui."
 readme = "README.md"
 requires-python = ">=3.9, <3.13"
@@ -21,9 +21,7 @@ maintainers = [
 classifiers = [
     "Environment :: Console",
     "Natural Language :: English",
-    # How mature is this project? Common values are
-    # 3 - Alpha, 4 - Beta, 5 - Production/Stable
-    "Development Status :: 3 - Alpha",
+    "Development Status :: 4 - Beta",
     "Intended Audience :: Developers",
     "Intended Audience :: Science/Research",
     "License :: OSI Approved :: Apache Software License",
@@ -39,7 +37,7 @@ classifiers = [
     "Topic :: Scientific/Engineering :: Artificial Intelligence",
 ]
 dependencies = [
-    "coqpit>=0.0.17",
+    "coqpit-config>=0.1.1",
     "fsspec>=2023.6.0",
     "numpy>=1.24.3; python_version < '3.12'",
     "numpy>=1.26.0; python_version >= '3.12'",
@@ -58,11 +56,17 @@ dev = [
     "pytest>=8",
     "ruff==0.6.9",
 ]
 # Dependencies for running the tests
 test = [
     "accelerate>=0.20.0",
     "torchvision>=0.15.1",
 ]
+mypy = [
+    "matplotlib>=3.9.2",
+    "mlflow>=2.18.0",
+    "mypy>=1.13.0",
+    "types-psutil>=6.1.0.20241102",
+    "wandb>=0.18.7",
+]
 
 [tool.uv]
 default-groups = ["dev", "test"]
@@ -105,3 +109,16 @@ skip_empty = true
 [tool.coverage.run]
 source = ["trainer", "tests"]
 command_line = "-m pytest"
+
+[[tool.mypy.overrides]]
+module = [
+    "accelerate",
+    "aim",
+    "aim.sdk.run",
+    "apex",
+    "clearml",
+    "fsspec",
+    "plotly",
+    "soundfile",
+]
+ignore_missing_imports = true
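
Note: the dependency switches from coqpit to the coqpit-config fork, and the unchanged `from coqpit import Coqpit` in trainer/config.py below suggests the fork keeps the original import name. A minimal sketch of a config under that assumption (the field is illustrative):

    from dataclasses import dataclass

    from coqpit import Coqpit  # distributed on PyPI as coqpit-config

    @dataclass
    class MyTrainingConfig(Coqpit):
        lr: float = 1e-3  # hypothetical field for illustration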
5 changes: 5 additions & 0 deletions tests/test_train_mnist.py
@@ -1,3 +1,4 @@
+import pytest
 import torch
 
 from tests.utils.mnist import MnistModel, MnistModelConfig
@@ -22,6 +23,7 @@ def test_train_mnist(tmp_path):
 
     # Without parsing command line args
     args = TrainerArgs()
+    args.small_run = 4
 
     trainer2 = Trainer(
         args,
@@ -48,3 +50,6 @@ def test_train_mnist(tmp_path):
     loss4 = trainer3.keep_avg_train["avg_loss"]
 
     assert loss3 > loss4
+
+    with pytest.raises(ValueError, match="cannot both be None"):
+        Trainer(args, MnistModelConfig(), output_path=tmp_path, model=None)
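
Note: the new test asserts that Trainer rejects a call with no model. A hypothetical sketch of the kind of guard being exercised; the helper name and the assumption that Trainer also accepts a model factory are not confirmed by this diff:

    def _check_model(model, get_model) -> None:
        # Matches the error message the test targets via match="cannot both be None"
        if model is None and get_model is None:
            raise ValueError("`model` and `get_model` cannot both be None.")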
6 changes: 4 additions & 2 deletions trainer/__init__.py
@@ -1,7 +1,9 @@
 import importlib.metadata
 
 from trainer.config import TrainerArgs, TrainerConfig
-from trainer.model import *
-from trainer.trainer import *
+from trainer.model import TrainerModel
+from trainer.trainer import Trainer
 
 __version__ = importlib.metadata.version("coqui-tts-trainer")
+
+__all__ = ["TrainerArgs", "TrainerConfig", "Trainer", "TrainerModel"]
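
Note: with the wildcard imports gone, the package root now re-exports exactly the four names listed in __all__, so downstream code imports them directly:

    from trainer import Trainer, TrainerArgs, TrainerConfig, TrainerModel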
25 changes: 14 additions & 11 deletions trainer/callbacks.py
@@ -1,17 +1,20 @@
-from typing import Callable
+from typing import TYPE_CHECKING, Callable
+
+if TYPE_CHECKING:
+    from trainer import Trainer
 
 
 class TrainerCallback:
     def __init__(self) -> None:
-        self.callbacks_on_init_start = []
-        self.callbacks_on_init_end = []
-        self.callbacks_on_epoch_start = []
-        self.callbacks_on_epoch_end = []
-        self.callbacks_on_train_epoch_start = []
-        self.callbacks_on_train_epoch_end = []
-        self.callbacks_on_train_step_start = []
-        self.callbacks_on_train_step_end = []
-        self.callbacks_on_keyboard_interrupt = []
+        self.callbacks_on_init_start: list[Callable] = []
+        self.callbacks_on_init_end: list[Callable] = []
+        self.callbacks_on_epoch_start: list[Callable] = []
+        self.callbacks_on_epoch_end: list[Callable] = []
+        self.callbacks_on_train_epoch_start: list[Callable] = []
+        self.callbacks_on_train_epoch_end: list[Callable] = []
+        self.callbacks_on_train_step_start: list[Callable] = []
+        self.callbacks_on_train_step_end: list[Callable] = []
+        self.callbacks_on_keyboard_interrupt: list[Callable] = []
 
     def parse_callbacks_dict(self, callbacks_dict: dict[str, Callable]) -> None:
         for key, value in callbacks_dict.items():
@@ -36,7 +39,7 @@ def parse_callbacks_dict(self, callbacks_dict: dict[str, Callable]) -> None:
             else:
                 raise ValueError(f"Invalid callback key: {key}")
 
-    def on_init_start(self, trainer) -> None:
+    def on_init_start(self, trainer: "Trainer") -> None:
         if hasattr(trainer.model, "module"):
             if hasattr(trainer.model.module, "on_init_start"):
                 trainer.model.module.on_init_start(trainer)
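
Note: the TYPE_CHECKING import gives the hooks a "Trainer" annotation without a runtime circular import. A sketch of registering a custom hook, assuming the keys accepted by parse_callbacks_dict mirror the attribute names above:

    def log_step(trainer: "Trainer") -> None:
        print("train step starting")  # illustrative hook body

    cb = TrainerCallback()
    cb.parse_callbacks_dict({"on_train_step_start": log_step})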
6 changes: 3 additions & 3 deletions trainer/config.py
@@ -1,5 +1,5 @@
 from dataclasses import dataclass, field
-from typing import Optional, Union
+from typing import Any, Optional, Union
 
 from coqpit import Coqpit
 
@@ -192,13 +192,13 @@ class TrainerConfig(Coqpit):
     optimizer: Optional[Union[str, list[str]]] = field(
         default=None, metadata={"help": "Optimizer(s) to use. Defaults to None"}
     )
-    optimizer_params: Union[dict, list[dict]] = field(
+    optimizer_params: Union[dict[str, Any], list[dict[str, Any]]] = field(
         default_factory=dict, metadata={"help": "Optimizer(s) arguments. Defaults to {}"}
     )
     lr_scheduler: Optional[Union[str, list[str]]] = field(
         default=None, metadata={"help": "Learning rate scheduler(s) to use. Defaults to None"}
     )
-    lr_scheduler_params: dict = field(
+    lr_scheduler_params: dict[str, Any] = field(
         default_factory=dict, metadata={"help": "Learning rate scheduler(s) arguments. Defaults to {}"}
     )
     use_grad_scaler: bool = field(
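
Note: a hedged usage sketch of the retyped fields. The field names come straight from the diff; the optimizer and scheduler values are illustrative, and TrainerConfig accepts them as keyword arguments because Coqpit configs are dataclasses:

    config = TrainerConfig(
        optimizer="Adam",
        optimizer_params={"lr": 1e-3, "weight_decay": 1e-6},
        lr_scheduler="StepLR",
        lr_scheduler_params={"step_size": 100, "gamma": 0.5},
    )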
5 changes: 3 additions & 2 deletions trainer/distribute.py
@@ -5,10 +5,11 @@
 import subprocess
 import time
 
-from trainer import TrainerArgs, logger
+from trainer import TrainerArgs
+from trainer.logger import logger
 
 
-def distribute():
+def distribute() -> None:
     """
     Call 👟Trainer training script in DDP mode.
     """
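
Note: since trainer/__init__.py now re-exports only the four __all__ names, logger is imported from its defining module instead of the package root. A minimal sketch (the log message is illustrative; per trainer/logging/__init__.py below, logger is a standard logging.Logger named "trainer"):

    from trainer.logger import logger

    logger.info("spawning DDP workers")  # illustrative usage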
26 changes: 14 additions & 12 deletions trainer/generic_utils.py
@@ -1,12 +1,14 @@
 import datetime
 import os
 import subprocess
+from collections.abc import ItemsView
 from typing import Any, Union
 
 import fsspec
 import torch
 from packaging.version import Version
 
+from trainer.config import TrainerConfig
 from trainer.logger import logger
 
 
@@ -20,7 +22,7 @@ def is_pytorch_at_least_2_4() -> bool:
     return Version(torch.__version__) >= Version("2.4")
 
 
-def isimplemented(obj, method_name) -> bool:
+def isimplemented(obj: Any, method_name: str) -> bool:
     """Check if a method is implemented in a class."""
     if method_name in dir(obj) and callable(getattr(obj, method_name)):
         try:
@@ -43,7 +45,7 @@ def to_cuda(x: torch.Tensor) -> torch.Tensor:
     return x
 
 
-def get_cuda():
+def get_cuda() -> tuple[bool, torch.device]:
     use_cuda = torch.cuda.is_available()
     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
     return use_cuda, device
@@ -97,7 +99,7 @@ def count_parameters(model: torch.nn.Module) -> int:
     return sum(p.numel() for p in model.parameters() if p.requires_grad)
 
 
-def set_partial_state_dict(model_dict, checkpoint_state, c):
+def set_partial_state_dict(model_dict: dict, checkpoint_state: dict, c: TrainerConfig) -> dict:
     # Partial initialization: if there is a mismatch with new and old layer, it is skipped.
     for k in checkpoint_state:
         if k not in model_dict:
@@ -123,21 +125,21 @@ def set_partial_state_dict(model_dict, checkpoint_state, c):
 
 
 class KeepAverage:
-    def __init__(self):
-        self.avg_values = {}
-        self.iters = {}
+    def __init__(self) -> None:
+        self.avg_values: dict[str, float] = {}
+        self.iters: dict[str, int] = {}
 
-    def __getitem__(self, key):
+    def __getitem__(self, key: str) -> Any:
         return self.avg_values[key]
 
-    def items(self):
+    def items(self) -> ItemsView[str, Any]:
         return self.avg_values.items()
 
-    def add_value(self, name, init_val=0, init_iter=0):
+    def add_value(self, name: str, init_val: float = 0, init_iter: int = 0) -> None:
         self.avg_values[name] = init_val
         self.iters[name] = init_iter
 
-    def update_value(self, name, value, weighted_avg=False):
+    def update_value(self, name: str, value: float, weighted_avg: bool = False) -> None:
         if name not in self.avg_values:
             # add value if not exist before
             self.add_value(name, init_val=value)
@@ -151,10 +153,10 @@ def update_value(self, name, value, weighted_avg=False):
             self.iters[name] += 1
             self.avg_values[name] /= self.iters[name]
 
-    def add_values(self, name_dict):
+    def add_values(self, name_dict: dict[str, float]) -> None:
         for key, value in name_dict.items():
             self.add_value(key, init_val=value)
 
-    def update_values(self, value_dict):
+    def update_values(self, value_dict: dict[str, float]) -> None:
         for key, value in value_dict.items():
             self.update_value(key, value)
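
Note: a small usage sketch of KeepAverage (values illustrative): update_value registers a key on first sight and folds later values into a running, optionally weighted, average, which is what the Trainer reads back as keep_avg_train["avg_loss"] in the test above:

    ka = KeepAverage()
    ka.update_values({"avg_loss": 1.5})  # first value registers the key
    ka.update_values({"avg_loss": 0.5})  # later values update the average
    print(ka["avg_loss"], ka.iters["avg_loss"])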
16 changes: 8 additions & 8 deletions trainer/io.py
@@ -10,6 +10,7 @@
 import fsspec
 import torch
 from coqpit import Coqpit
+from torch.types import Storage
 
 from trainer.generic_utils import is_pytorch_at_least_2_4
 from trainer.logger import logger
@@ -60,7 +61,7 @@ def copy_model_files(config: Coqpit, out_path: Union[str, os.PathLike[Any]], new
 
 def load_fsspec(
     path: Union[str, os.PathLike[Any]],
-    map_location: Union[str, Callable, torch.device, dict[Union[str, torch.device], Union[str, torch.device]]] = None,
+    map_location: Optional[Union[str, Callable[[Storage, str], Storage], torch.device, dict[str, str]]] = None,
     cache: bool = True,
     **kwargs,
 ) -> Any:
@@ -195,7 +196,7 @@ def save_checkpoint(
 
 def save_best_model(
     current_loss: Union[dict, float],
-    best_loss: Union[dict, float],
+    best_loss: Union[dict[str, Optional[float]], float],
     config: Union[dict, Coqpit],
     model: torch.nn.Module,
     optimizer: torch.optim.Optimizer,
@@ -208,12 +209,13 @@ def save_best_model(
     save_func: Optional[Callable] = None,
     **kwargs,
 ) -> Union[dict, float]:
-    if isinstance(current_loss, dict):
+    if isinstance(current_loss, dict) and isinstance(best_loss, dict):
         use_eval_loss = current_loss["eval_loss"] is not None and best_loss["eval_loss"] is not None
         is_save_model = (use_eval_loss and current_loss["eval_loss"] < best_loss["eval_loss"]) or (
             not use_eval_loss and current_loss["train_loss"] < best_loss["train_loss"]
         )
     else:
+        assert isinstance(current_loss, float) and isinstance(best_loss, float)
         is_save_model = current_loss < best_loss
 
     is_save_model = is_save_model and current_step > keep_after
@@ -249,7 +251,7 @@ def save_best_model(
     return best_loss
 
 
-def get_last_checkpoint(path: Union[str, os.PathLike]) -> tuple[str, str]:
+def get_last_checkpoint(path: Union[str, os.PathLike[Any]]) -> tuple[str, str]:
     """Get latest checkpoint or/and best model in path.
 
     It is based on globbing for `*.pth` and the RegEx
# back if it exists on the path
file_names = [scheme + "://" + file_name for file_name in file_names]
last_models = {}
last_model_nums = {}
last_model_nums: dict[str, int] = {}
for key in ["checkpoint", "best_model"]:
last_model_num = None
last_model = None
@@ -357,6 +359,4 @@ def sort_checkpoints(
         if regex_match is not None and regex_match.groups() is not None:
             ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))
 
-    checkpoints_sorted = sorted(ordering_and_checkpoint_path)
-    checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
-    return checkpoints_sorted
+    return [checkpoint[1] for checkpoint in sorted(ordering_and_checkpoint_path)]
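
Note: the tightened best_loss type encodes the dict convention visible in the function body: a "train_loss" entry plus an optional "eval_loss". A sketch of the selection rule in isolation (not the full save_best_model function):

    def _is_better(current: dict, best: dict) -> bool:
        # Prefer eval loss when both sides have one; otherwise compare train loss.
        use_eval = current["eval_loss"] is not None and best["eval_loss"] is not None
        key = "eval_loss" if use_eval else "train_loss"
        return current[key] < best[key]

    assert _is_better({"train_loss": 0.5, "eval_loss": None},
                      {"train_loss": 1.0, "eval_loss": None})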
12 changes: 8 additions & 4 deletions trainer/logging/__init__.py
@@ -1,31 +1,35 @@
 import logging
 import os
+from typing import Union
 
+from trainer.config import TrainerConfig
+from trainer.logging.base_dash_logger import BaseDashboardLogger
 from trainer.logging.console_logger import ConsoleLogger
 from trainer.logging.dummy_logger import DummyLogger
 
-# pylint: disable=import-outside-toplevel
+__all__ = ["ConsoleLogger", "DummyLogger"]
 
 
 logger = logging.getLogger("trainer")
 
 
-def get_mlflow_tracking_url():
+def get_mlflow_tracking_url() -> Union[str, None]:
     if "MLFLOW_TRACKING_URI" in os.environ:
         return os.environ["MLFLOW_TRACKING_URI"]
     return None
 
 
-def get_ai_repo_url():
+def get_ai_repo_url() -> Union[str, None]:
     if "AIM_TRACKING_URI" in os.environ:
         return os.environ["AIM_TRACKING_URI"]
     return None
 
 
-def logger_factory(config, output_path):
+def logger_factory(config: TrainerConfig, output_path: str) -> BaseDashboardLogger:
     run_name = config.run_name
     project_name = config.project_name
     log_uri = config.logger_uri if config.logger_uri else output_path
+    dashboard_logger: BaseDashboardLogger
 
     if config.dashboard_logger == "tensorboard":
         from trainer.logging.tensorboard_logger import TensorboardLogger
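
Note: the env-var helpers are now typed; each simply surfaces a tracking URI when the variable is set. A short sketch (the URI is a hypothetical local server):

    import os

    os.environ["MLFLOW_TRACKING_URI"] = "http://localhost:5000"  # hypothetical URI
    assert get_mlflow_tracking_url() == "http://localhost:5000"
    assert get_ai_repo_url() is None  # AIM_TRACKING_URI not set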
