Merge pull request #13 from idiap/ruff
Ruff lint fixes
eginhard authored Nov 29, 2024
2 parents eb7aa2a + 6b11596 · commit 965321d
Showing 28 changed files with 361 additions and 369 deletions.
9 changes: 0 additions & 9 deletions MANIFEST.in

This file was deleted.

22 changes: 9 additions & 13 deletions examples/train_mnist.py
@@ -1,9 +1,7 @@
-"""
-This example shows training of a simple Conv model with MNIST dataset using Auto Training mode of 👟.
-"""
+"""This example shows training of a simple Conv model with MNIST dataset using Auto Training mode of 👟."""
 
-import os
 from dataclasses import dataclass
+from pathlib import Path
 
 import torch
 from torch import nn
@@ -27,7 +25,7 @@ class MnistModelConfig(TrainerConfig):
 
 
 class MnistModel(TrainerModel):
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
 
         # mnist images are (1, 28, 28) (channels, height, width)
@@ -46,8 +44,7 @@ def forward(self, x):
         x = F.relu(x)
         x = self.layer_3(x)
 
-        x = F.log_softmax(x, dim=1)
-        return x
+        return F.log_softmax(x, dim=1)
 
     def train_step(self, batch, criterion):
         x, y = batch
@@ -65,13 +62,12 @@ def eval_step(self, batch, criterion):
     def get_criterion():
         return torch.nn.NLLLoss()
 
-    def get_data_loader(self, config, assets, is_eval, samples, verbose, num_gpus, rank=0):  # pylint: disable=unused-argument
+    def get_data_loader(self, config, assets, *, is_eval, samples=None, verbose=False, num_gpus=1, rank=0):  # pylint: disable=unused-argument
         transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
-        dataset = MNIST(os.getcwd(), train=not is_eval, download=True, transform=transform)
+        dataset = MNIST(Path.cwd(), train=not is_eval, download=True, transform=transform)
         dataset.data = dataset.data[:256]
         dataset.targets = dataset.targets[:256]
-        dataloader = DataLoader(dataset, batch_size=config.batch_size)
-        return dataloader
+        return DataLoader(dataset, batch_size=config.batch_size)
 
 
 def main():
@@ -88,8 +84,8 @@ def main():
         train_args,
         config,
         model=model,
-        train_samples=model.get_data_loader(config, None, False, None, None, None),
-        eval_samples=model.get_data_loader(config, None, True, None, None, None),
+        train_samples=model.get_data_loader(config, None, is_eval=False),
+        eval_samples=model.get_data_loader(config, None, is_eval=True),
         parse_command_line_args=True,
     )
     trainer.fit()
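Several of the edits above are mechanical applications of ruff's RET rules (selected in pyproject.toml further down), which flag a temporary that is assigned only to be returned on the next line. A minimal sketch of the before/after pattern — the function names are illustrative, not from this diff:

    import torch
    import torch.nn.functional as F

    def log_probs_before(x: torch.Tensor) -> torch.Tensor:
        out = F.log_softmax(x, dim=1)  # RET504: assigned once, returned immediately
        return out

    def log_probs_after(x: torch.Tensor) -> torch.Tensor:
        return F.log_softmax(x, dim=1)  # return the expression directly

    x = torch.randn(2, 3)
    assert torch.allclose(log_probs_before(x), log_probs_after(x))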
30 changes: 13 additions & 17 deletions examples/train_simple_gan.py
@@ -1,10 +1,10 @@
-"""
-This example shows training of a simple GAN model with MNIST dataset using Gradient Accumulation and Advanced
-Optimization where you call optimizer steps manually.
-"""
+"""This example shows training of a simple GAN model on the MNIST dataset.
+
+Using Gradient Accumulation and Advanced Optimization where you call optimizer steps manually.
+"""
 
-import os
 from dataclasses import dataclass
+from pathlib import Path
 
 import numpy as np
 import torch
@@ -22,11 +22,11 @@
 
 
 class Generator(nn.Module):
-    def __init__(self, latent_dim, img_shape):
+    def __init__(self, latent_dim, img_shape) -> None:
         super().__init__()
         self.img_shape = img_shape
 
-        def block(in_feat, out_feat, normalize=True):
+        def block(in_feat, out_feat, *, normalize=True):
             layers = [nn.Linear(in_feat, out_feat)]
             if normalize:
                 layers.append(nn.BatchNorm1d(out_feat, 0.8))
@@ -44,12 +44,11 @@ def block(in_feat, out_feat, normalize=True):
 
     def forward(self, z):
         img = self.model(z)
-        img = img.view(img.size(0), *self.img_shape)
-        return img
+        return img.view(img.size(0), *self.img_shape)
 
 
 class Discriminator(nn.Module):
-    def __init__(self, img_shape):
+    def __init__(self, img_shape) -> None:
         super().__init__()
 
         self.model = nn.Sequential(
@@ -63,9 +62,7 @@ def __init__(self, img_shape):
 
     def forward(self, img):
         img_flat = img.view(img.size(0), -1)
-        validity = self.model(img_flat)
-
-        return validity
+        return self.model(img_flat)
 
 
 @dataclass
@@ -76,7 +73,7 @@ class GANModelConfig(TrainerConfig):
 
 
 class GANModel(TrainerModel):
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         data_shape = (1, 28, 28)
         self.generator = Generator(latent_dim=100, img_shape=data_shape)
@@ -153,11 +150,10 @@ def get_criterion(self):
 
     def get_data_loader(self, config, assets, is_eval, samples, verbose, num_gpus, rank=0):  # pylint: disable=unused-argument
         transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
-        dataset = MNIST(os.getcwd(), train=not is_eval, download=True, transform=transform)
+        dataset = MNIST(Path.cwd(), train=not is_eval, download=True, transform=transform)
         dataset.data = dataset.data[:64]
         dataset.targets = dataset.targets[:64]
-        dataloader = DataLoader(dataset, batch_size=config.batch_size, drop_last=True, shuffle=True)
-        return dataloader
+        return DataLoader(dataset, batch_size=config.batch_size, drop_last=True, shuffle=True)
 
 
 if __name__ == "__main__":
@@ -166,6 +162,6 @@ def get_data_loader(self, config, assets, is_eval, samples, verbose, num_gpus, rank=0):
     config.grad_clip = None
 
     model = GANModel()
-    trainer = Trainer(TrainerArgs(), config, model=model, gpu=0 if is_cuda else None)
+    trainer = Trainer(TrainerArgs(), config, model=model, output_path=Path.cwd(), gpu=0 if is_cuda else None)
     trainer.config.epochs = 10
     trainer.fit()
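The bare `*` added to `get_data_loader` in the MNIST example comes from ruff's FBT ("boolean trap") rules: boolean flags become keyword-only, so every call site states its intent — which is why the calls changed to `is_eval=False`/`is_eval=True`. A self-contained sketch of the pattern with made-up names, not the Trainer API:

    from pathlib import Path

    def load_split(root: Path, *, is_eval: bool = False) -> Path:
        """Return the directory for the requested split (illustrative only)."""
        return root / ("eval" if is_eval else "train")

    print(load_split(Path.cwd(), is_eval=True))  # explicit keyword reads clearly
    # load_split(Path.cwd(), True) now raises TypeError, which is the point of FBT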
77 changes: 62 additions & 15 deletions pyproject.toml
@@ -1,9 +1,28 @@
-[build-system]
-requires = ["setuptools", "setuptools-scm"]
-build-backend = "setuptools.build_meta"
+#
+# ,*++++++*, ,*++++++*,
+# *++. .+++ *++. .++*
+# *+* ,++++* *+* *+* ,++++, *+*
+# ,+, .++++++++++* ,++,,,,*+, ,++++++++++. *+,
+# *+. .++++++++++++..++ *+.,++++++++++++. .+*
+# .+* ++++++++++++.*+, .+*.++++++++++++ *+,
+# .++ *++++++++* ++, .++.*++++++++* ++,
+# ,+++*. . .*++, ,++*. .*+++*
+# *+, .,*++**. .**++**. ,+*
+# .+* *+,
+# *+. Coqui .+*
+# *+* +++ Trainer +++ *+*
+# .+++*. . . *+++.
+# ,+* *+++*... ...*+++* *+,
+# .++. .""""+++++++****+++++++"""". ++.
+# ,++. **** .++,
+# .++* *++.
+# *+++, ,+++*
+# .,*++++::::::++++*,.
+#
 
-[tool.setuptools.packages.find]
-include = ["trainer*"]
+[build-system]
+requires = ["hatchling"]
+build-backend = "hatchling.build"
 
 [project]
 name = "coqui-tts-trainer"
@@ -76,32 +95,60 @@ Homepage = "https://github.com/idiap/coqui-ai-Trainer"
 Repository = "https://github.com/idiap/coqui-ai-Trainer"
 Issues = "https://github.com/idiap/coqui-ai-Trainer/issues"
 
+[tool.hatch.build]
+exclude = [
+    "/.github",
+    "/.gitignore",
+    "/.pre-commit-config.yaml",
+    "/CODE_OF_CONDUCT.md",
+    "/CONTRIBUTING.md",
+    "/Makefile",
+    "/tests",
+    "/trainer/TODO.txt",
+]
+
+[tool.hatch.build.targets.wheel]
+packages = ["trainer"]
+
 [tool.ruff]
 line-length = 120
 target-version = "py39"
 lint.extend-select = [
-    "B", # bugbear
-    "I", # import sorting
+    "ANN204", # type hints
+    "B", # bugbear
+    "D2", # docs
+    "D412",
+    "D415",
+    "EM", # error messages
+    "FBT", # boolean arguments
+    "FLY",
+    "I", # import sorting
     "PIE",
-    "PLC",
-    "PLE",
-    "PLW",
-    "RUF",
-    "UP", # pyupgrade
+    "PL", # pylint
+    "RET", # return statements
+    "RUF", # ruff-specific
+    "UP", # pyupgrade
+    "SIM", # simplify
 ]
 
 lint.ignore = [
     "F821", # TODO: enable
     "PLW2901", # TODO: enable
     "UP032", # TODO: enable
+    "PLR09",
 ]
 
 [tool.ruff.lint.per-file-ignores]
 "**/__init__.py" = [
-  "F401", # init files may have "unused" imports for now
-  "F403", # init files may have star imports for now
+    "F401", # init files may have "unused" imports for now
+    "F403", # init files may have star imports for now
 ]
+"tests/**" = [
+    "PLR2004", # magic values are ok in tests
+]
+
+[tool.ruff.lint.pydocstyle]
+convention = "google"
 
 [tool.coverage.report]
 show_missing = true
 skip_empty = true
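Of the newly selected rule families, EM ("error messages") is the least self-explanatory: it requires binding an exception message to a variable before raising, so the traceback does not repeat a long literal. A hedged sketch with an invented validator, not code from this repository:

    def check_batch_size(batch_size: int) -> None:
        if batch_size <= 0:
            # EM102 forbids raise ValueError(f"...") with an inline f-string;
            # bind the message first, then raise.
            msg = f"batch_size must be positive, got {batch_size}"
            raise ValueError(msg)

    check_batch_size(32)  # ok; check_batch_size(0) raises ValueError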
25 changes: 0 additions & 25 deletions setup.py

This file was deleted.
