diff --git a/src/graphnet/training/callbacks.py b/src/graphnet/training/callbacks.py index a66255ca6..c319dccc1 100644 --- a/src/graphnet/training/callbacks.py +++ b/src/graphnet/training/callbacks.py @@ -1,20 +1,26 @@ """Callback class(es) for using during model training.""" import logging -from typing import Dict, List +import os +from typing import Dict, List, TYPE_CHECKING, Any, Optional import warnings import numpy as np +import torch from tqdm.std import Bar from pytorch_lightning import LightningModule, Trainer -from pytorch_lightning.callbacks import TQDMProgressBar +from pytorch_lightning.callbacks import TQDMProgressBar, EarlyStopping from pytorch_lightning.utilities import rank_zero_only from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler from graphnet.utilities.logging import Logger +if TYPE_CHECKING: + from graphnet.models import Model + import pytorch_lightning as pl + class PiecewiseLinearLR(_LRScheduler): """Interpolate learning rate linearly between milestones.""" @@ -152,3 +158,92 @@ def on_train_epoch_end( h.setLevel(logging.ERROR) logger.info(str(super().train_progress_bar)) h.setLevel(level) + + +class GraphnetEarlyStopping(EarlyStopping): + """Early stopping callback for graphnet.""" + + def __init__(self, save_dir: str, **kwargs: Dict[str, Any]) -> None: + """Construct `GraphnetEarlyStopping` Callback. + + Args: + save_dir: Path to directory to save best model and config. + **kwargs: Keyword arguments to pass to `EarlyStopping`. See + `pytorch_lightning.callbacks.EarlyStopping` for details. + """ + self.save_dir = save_dir + super().__init__(**kwargs) + + def setup( + self, + trainer: "pl.Trainer", + graphnet_model: "Model", + stage: Optional[str] = None, + ) -> None: + """Call at setup stage of training. + + Args: + trainer: The trainer. + graphnet_model: The model. + stage: The stage of training. + """ + super().setup(trainer, graphnet_model, stage) + os.makedirs(self.save_dir, exist_ok=True) + graphnet_model.save_config(os.path.join(self.save_dir, "config.yml")) + + def on_train_epoch_end( + self, trainer: "pl.Trainer", graphnet_model: "Model" + ) -> None: + """Call after each train epoch. + + Args: + trainer: Trainer object. + graphnet_model: Graphnet Model. + + Returns: None. + """ + if not self._check_on_train_epoch_end or self._should_skip_check( + trainer + ): + return + current_best = self.best_score + self._run_early_stopping_check(trainer) + if self.best_score != current_best: + graphnet_model.save_state_dict( + os.path.join(self.save_dir, "best_model.pth") + ) + + def on_validation_end( + self, trainer: "pl.Trainer", graphnet_model: "Model" + ) -> None: + """Call after each validation epoch. + + Args: + trainer: Trainer object. + graphnet_model: Graphnet Model. + + Returns: None. + """ + if self._check_on_train_epoch_end or self._should_skip_check(trainer): + return + current_best = self.best_score + self._run_early_stopping_check(trainer) + if self.best_score != current_best: + graphnet_model.save_state_dict( + os.path.join(self.save_dir, "best_model.pth") + ) + + def on_fit_end( + self, trainer: "pl.Trainer", graphnet_model: "Model" + ) -> None: + """Call at the end of training. + + Args: + trainer: Trainer object. + graphnet_model: Graphnet Model. + + Returns: None. + """ + graphnet_model.load_state_dict( + os.path.join(self.save_dir, "best_model.pth") + )