Skip to content

Commit

Permalink
Merge pull request #632 from AMHermansen/add-improved-earlystopping
Browse files Browse the repository at this point in the history
Add improved earlystopping
  • Loading branch information
AMHermansen authored Dec 1, 2023
2 parents 1b88096 + 1697e06 commit 800ebd9
Showing 1 changed file with 97 additions and 2 deletions.
99 changes: 97 additions & 2 deletions src/graphnet/training/callbacks.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,26 @@
"""Callback class(es) for using during model training."""

import logging
from typing import Dict, List
import os
from typing import Dict, List, TYPE_CHECKING, Any, Optional
import warnings

import numpy as np
import torch
from tqdm.std import Bar

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import TQDMProgressBar
from pytorch_lightning.callbacks import TQDMProgressBar, EarlyStopping
from pytorch_lightning.utilities import rank_zero_only
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler

from graphnet.utilities.logging import Logger

if TYPE_CHECKING:
from graphnet.models import Model
import pytorch_lightning as pl


class PiecewiseLinearLR(_LRScheduler):
"""Interpolate learning rate linearly between milestones."""
Expand Down Expand Up @@ -152,3 +158,92 @@ def on_train_epoch_end(
h.setLevel(logging.ERROR)
logger.info(str(super().train_progress_bar))
h.setLevel(level)


class GraphnetEarlyStopping(EarlyStopping):
    """Early stopping callback for graphnet.

    Extends `pytorch_lightning.callbacks.EarlyStopping` so that, in
    addition to stopping training, the best-performing model state dict
    and the model config are persisted to `save_dir`, and the best
    weights are restored into the model at the end of the fit.
    """

    def __init__(self, save_dir: str, **kwargs: Any) -> None:
        """Construct `GraphnetEarlyStopping` Callback.

        Args:
            save_dir: Path to directory to save best model and config.
            **kwargs: Keyword arguments to pass to `EarlyStopping`. See
                `pytorch_lightning.callbacks.EarlyStopping` for details.
        """
        self.save_dir = save_dir
        super().__init__(**kwargs)

    def setup(
        self,
        trainer: "pl.Trainer",
        graphnet_model: "Model",
        stage: Optional[str] = None,
    ) -> None:
        """Call at setup stage of training.

        Creates `save_dir` (if needed) and saves the model config there.

        Args:
            trainer: The trainer.
            graphnet_model: The model.
            stage: The stage of training.
        """
        super().setup(trainer, graphnet_model, stage)
        os.makedirs(self.save_dir, exist_ok=True)
        graphnet_model.save_config(os.path.join(self.save_dir, "config.yml"))

    def on_train_epoch_end(
        self, trainer: "pl.Trainer", graphnet_model: "Model"
    ) -> None:
        """Call after each train epoch.

        Runs the early-stopping check (when configured to run on train
        epochs) and saves the model's state dict whenever the monitored
        metric improved.

        Args:
            trainer: Trainer object.
            graphnet_model: Graphnet Model.

        Returns: None.
        """
        # Mirror the parent class's gating: only one of the train/val
        # hooks performs the check, controlled by `check_on_train_epoch_end`.
        if not self._check_on_train_epoch_end or self._should_skip_check(
            trainer
        ):
            return
        # Snapshot the best score so an improvement can be detected after
        # the early-stopping check updates it.
        current_best = self.best_score
        self._run_early_stopping_check(trainer)
        if self.best_score != current_best:
            graphnet_model.save_state_dict(
                os.path.join(self.save_dir, "best_model.pth")
            )

    def on_validation_end(
        self, trainer: "pl.Trainer", graphnet_model: "Model"
    ) -> None:
        """Call after each validation epoch.

        Runs the early-stopping check (when configured to run on
        validation) and saves the model's state dict whenever the
        monitored metric improved.

        Args:
            trainer: Trainer object.
            graphnet_model: Graphnet Model.

        Returns: None.
        """
        if self._check_on_train_epoch_end or self._should_skip_check(trainer):
            return
        current_best = self.best_score
        self._run_early_stopping_check(trainer)
        if self.best_score != current_best:
            graphnet_model.save_state_dict(
                os.path.join(self.save_dir, "best_model.pth")
            )

    def on_fit_end(
        self, trainer: "pl.Trainer", graphnet_model: "Model"
    ) -> None:
        """Call at the end of training.

        Restores the best-performing weights saved during training.

        Args:
            trainer: Trainer object.
            graphnet_model: Graphnet Model.

        Returns: None.
        """
        # NOTE(review): a path is passed here, so this relies on graphnet's
        # `Model.load_state_dict` accepting a file path (plain
        # `torch.nn.Module.load_state_dict` expects a state-dict mapping)
        # — confirm against `graphnet.models.Model`. Also assumes
        # `best_model.pth` was written at least once during training;
        # otherwise this raises FileNotFoundError.
        graphnet_model.load_state_dict(
            os.path.join(self.save_dir, "best_model.pth")
        )

0 comments on commit 800ebd9

Please sign in to comment.