diff --git a/pvnet_summation/callbacks.py b/pvnet_summation/callbacks.py
index 31a672c..71c0906 100644
--- a/pvnet_summation/callbacks.py
+++ b/pvnet_summation/callbacks.py
@@ -17,27 +17,28 @@
     lightning.pytorch.callbacks.StochasticWeightAveraging
 """
 
-from typing import Any, Callable, cast, List, Optional, Union
-
-import torch
-from torch import Tensor
-from torch.optim.swa_utils import SWALR
+from typing import Any, Callable, List, Optional, Union, cast
 
 import lightning.pytorch as pl
+import torch
 from lightning.fabric.utilities.types import LRScheduler
+from lightning.pytorch.callbacks import StochasticWeightAveraging as _StochasticWeightAveraging
 from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn
 from lightning.pytorch.utilities.types import LRSchedulerConfig
-from lightning.pytorch.callbacks import StochasticWeightAveraging as _StochasticWeightAveraging
+from torch import Tensor
+from torch.optim.swa_utils import SWALR
 
 _AVG_FN = Callable[[Tensor, Tensor, Tensor], Tensor]
 _DEFAULT_DEVICE = torch.device("cpu")
 
+
 class StochasticWeightAveraging(_StochasticWeightAveraging):
     """Stochastic weight averaging callback
 
     Modified from:
         lightning.pytorch.callbacks.StochasticWeightAveraging
     """
+
     def __init__(
         self,
         swa_lrs: Union[float, List[float]],
@@ -62,13 +63,12 @@ def __init__(
 
         .. warning:: This is an :ref:`experimental ` feature.
 
-        .. warning:: ``StochasticWeightAveraging`` is currently not supported for multiple
+        .. warning:: ``StochasticWeightAveraging`` is currently not supported for multiple
            optimizers/schedulers.
 
        .. warning:: ``StochasticWeightAveraging`` is currently only supported on every epoch.
 
        Arguments:
-
            swa_lrs: The SWA learning rate to use:
 
                - ``float``. Use this value for all parameter groups of the optimizer.
@@ -98,10 +98,14 @@ def __init__(
         """
         # Add this so we can use iterative datapipe
         self._train_batches = 0
-
+
         super().__init__(
-            swa_lrs, swa_epoch_start, annealing_epochs,
-            annealing_strategy, avg_fn, device,
+            swa_lrs,
+            swa_epoch_start,
+            annealing_epochs,
+            annealing_strategy,
+            avg_fn,
+            device,
         )
 
     def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
@@ -164,9 +168,7 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
 
         if self.n_averaged is None:
             self.n_averaged = torch.tensor(
-                self._init_n_averaged,
-                dtype=torch.long,
-                device=pl_module.device
+                self._init_n_averaged, dtype=torch.long, device=pl_module.device
             )
 
         if (self.swa_start <= trainer.current_epoch <= self.swa_end) and (
@@ -188,7 +190,7 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
 
             # There is no need to perform either backward or optimizer.step as we are
             # performing only one pass over the train data-loader to compute activation statistics
-            # Therefore, we will virtually increase the number of training batches by 1 and
+            # Therefore, we will virtually increase the number of training batches by 1 and
             # skip backward.
             trainer.fit_loop.max_batches += 1
             trainer.fit_loop._skip_backward = True
@@ -197,9 +199,6 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
 
     def on_train_epoch_end(self, trainer: "pl.Trainer", *args: Any) -> None:
         """Run at end of each train epoch"""
-        if trainer.current_epoch==0:
+        if trainer.current_epoch == 0:
             self._train_batches = trainer.global_step
         trainer.fit_loop._skip_backward = False
-
-
-
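For reference, a minimal usage sketch of the modified callback (not part of the diff). The hyperparameter values are illustrative, and model / train_dataloader stand in for whatever objects the training script actually builds. Because the subclass keeps the upstream constructor signature, it attaches to a Lightning Trainer exactly like lightning.pytorch.callbacks.StochasticWeightAveraging:

import lightning.pytorch as pl

from pvnet_summation.callbacks import StochasticWeightAveraging

# Illustrative values; the annealing settings mirror the upstream defaults.
swa_callback = StochasticWeightAveraging(
    swa_lrs=1e-4,             # SWA learning rate applied to all parameter groups
    swa_epoch_start=0.8,      # begin averaging after 80% of max_epochs
    annealing_epochs=10,      # epochs over which the LR is annealed towards swa_lrs
    annealing_strategy="cos",
)

trainer = pl.Trainer(max_epochs=20, callbacks=[swa_callback])
trainer.fit(model, train_dataloaders=train_dataloader)  # model / dataloader defined elsewhere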