Skip to content

Commit

Permalink
train_model.py draft
Browse files Browse the repository at this point in the history
  • Loading branch information
dilyabareeva committed May 31, 2024
1 parent 74bb3a5 commit 5f45083
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 0 deletions.
Empty file added src/utils/functions/__init__.py
Empty file.
Empty file added src/utils/training/__init__.py
Empty file.
77 changes: 77 additions & 0 deletions src/utils/training/training.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
from typing import Optional

import torch


def train_model(
model: torch.nn.Module,
train_loader: torch.utils.data.dataloader.DataLoader,
optimizer: torch.optim.Optimizer,
criterion: torch.nn.modules.loss._Loss,
device: str = "cpu",
max_epochs: int = 100,
val_loader: torch.utils.data.dataloader.DataLoader = None,
early_stopping: bool = False,
early_stopping_kwargs: Optional[dict] = {"patience": 10},
verbose: bool = False,
*args,
**kwargs,
):
"""
Function to train a model.
Args:
model: torch.nn.Module: Model to train.
train_loader: torch.utils.data.dataloader.DataLoader: DataLoader for training data.
optimizer: torch.optim.Optimizer: Optimizer to use for training.
criterion: torch.nn.modules.loss._Loss: Loss function to use for training.
device: str: Device to use for training.
max_epochs: int: Maximum number of epochs to train for.
val_loader: torch.utils.data.dataloader.DataLoader: DataLoader for validation data.
early_stopping: bool: Whether to use early stopping.
patience: int: Patience for early stopping.
metric: str: Metric to use for early stopping.
verbose: bool: Whether to print training information.
*args: Additional arguments.
**kwargs: Additional keyword arguments.
Returns:
model: torch.nn.Module: Trained model.
"""
model.to(device)
if early_stopping:
assert val_loader is not None, "Validation loader is required for early stopping."
assert "metric" in early_stopping_kwargs, "Metric is required for early stopping."
assert "patience" in early_stopping_kwargs, "Patience is required for early stopping."
patience = early_stopping_kwargs["patience"]

no_improvement = 0
best_metric = None

for epoch in range(max_epochs):
model.train()
train_loss = 0

for i, (x, y) in enumerate(train_loader):
x, y = x.to(device), y.to(device)
optimizer.zero_grad()
y_pred = model(x)
loss = criterion(y_pred, y)
loss.backward()
optimizer.step()

train_loss += loss.item()

if train_loss < best_metric:
best_metric = loss
no_improvement = 0
else:
no_improvement += 1

if early_stopping and no_improvement >= patience:
if verbose:
print(f"Early stopping at epoch {epoch}.")
break

if verbose: # TODO: tqdm
print(f"Epoch: {epoch}, Train Loss: {train_loss}")

0 comments on commit 5f45083

Please sign in to comment.