From 891cec232cee2b17830c77601e3533ea2aed4c87 Mon Sep 17 00:00:00 2001
From: James Fulton
Date: Thu, 21 Sep 2023 12:20:12 +0000
Subject: [PATCH] tidy SWA callback

---
 pvnet_summation/callbacks.py | 251 ++++++-----------------------------
 1 file changed, 38 insertions(+), 213 deletions(-)

diff --git a/pvnet_summation/callbacks.py b/pvnet_summation/callbacks.py
index e9ed5b7..31a672c 100644
--- a/pvnet_summation/callbacks.py
+++ b/pvnet_summation/callbacks.py
@@ -11,26 +11,33 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-r"""Stochastic Weight Averaging Callback ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^"""
-from copy import deepcopy
-from typing import Any, Callable, Dict, List, Optional, Union, cast
+"""Stochastic weight averaging callback
+
+Modified from:
+    lightning.pytorch.callbacks.StochasticWeightAveraging
+"""
+
+from typing import Any, Callable, cast, List, Optional, Union
 
-import lightning.pytorch as pl
 import torch
+from torch import Tensor
+from torch.optim.swa_utils import SWALR
+
+import lightning.pytorch as pl
 from lightning.fabric.utilities.types import LRScheduler
-from lightning.pytorch.callbacks import StochasticWeightAveraging
-from lightning.pytorch.strategies import DeepSpeedStrategy
-from lightning.pytorch.strategies.fsdp import FSDPStrategy
-from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn
 from lightning.pytorch.utilities.types import LRSchedulerConfig
-from torch import Tensor, nn
-from torch.optim.swa_utils import SWALR
+from lightning.pytorch.callbacks import StochasticWeightAveraging as _StochasticWeightAveraging
 
 _AVG_FN = Callable[[Tensor, Tensor, Tensor], Tensor]
+_DEFAULT_DEVICE = torch.device("cpu")
 
+class StochasticWeightAveraging(_StochasticWeightAveraging):
+    """Stochastic weight averaging callback
 
-class StochasticWeightAveraging(StochasticWeightAveraging):
+    Modified from:
+        lightning.pytorch.callbacks.StochasticWeightAveraging
+    """
     def __init__(
         self,
         swa_lrs: Union[float, List[float]],
@@ -38,7 +45,7 @@ def __init__(
         annealing_epochs: int = 10,
         annealing_strategy: str = "cos",
         avg_fn: Optional[_AVG_FN] = None,
-        device: Optional[Union[torch.device, str]] = torch.device("cpu"),
+        device: Optional[Union[torch.device, str]] = _DEFAULT_DEVICE,
     ):
         r"""Implements the Stochastic Weight Averaging (SWA) Callback to average a model.
 
@@ -55,13 +62,13 @@ def __init__(
         .. warning:: This is an :ref:`experimental ` feature.
 
-        .. warning:: ``StochasticWeightAveraging`` is currently not supported for multiple optimizers/schedulers.
+        .. warning:: ``StochasticWeightAveraging`` is currently not supported for multiple
+            optimizers/schedulers.
 
         .. warning:: ``StochasticWeightAveraging`` is currently only supported on every epoch.
 
-        See also how to :ref:`enable it directly on the Trainer `
-
         Arguments:
+
             swa_lrs: The SWA learning rate to use:
 
                 - ``float``. Use this value for all parameter groups of the optimizer.
                 - ``List[float]``. A list values for each parameter group of the optimizer.
@@ -89,96 +96,16 @@ def __init__(
                 (default: ``"cpu"``)
 
        """
-
-        err_msg = "swa_epoch_start should be a >0 integer or a float between 0 and 1."
-        if isinstance(swa_epoch_start, int) and swa_epoch_start < 1:
-            raise MisconfigurationException(err_msg)
-        if isinstance(swa_epoch_start, float) and not (0 <= swa_epoch_start <= 1):
-            raise MisconfigurationException(err_msg)
-
-        wrong_type = not isinstance(swa_lrs, (float, list))
-        wrong_float = isinstance(swa_lrs, float) and swa_lrs <= 0
-        wrong_list = isinstance(swa_lrs, list) and not all(
-            lr > 0 and isinstance(lr, float) for lr in swa_lrs
-        )
-        if wrong_type or wrong_float or wrong_list:
-            raise MisconfigurationException(
-                "The `swa_lrs` should a positive float, or a list of positive floats"
-            )
-
-        if avg_fn is not None and not callable(avg_fn):
-            raise MisconfigurationException("The `avg_fn` should be callable.")
-
-        if device is not None and not isinstance(device, (torch.device, str)):
-            raise MisconfigurationException(
-                f"device is expected to be a torch.device or a str. Found {device}"
-            )
-
-        self.n_averaged: Optional[Tensor] = None
-        self._swa_epoch_start = swa_epoch_start
-        self._swa_lrs = swa_lrs
-        self._annealing_epochs = annealing_epochs
-        self._annealing_strategy = annealing_strategy
-        self._avg_fn = avg_fn or self.avg_fn
-        self._device = device
-        self._model_contains_batch_norm: Optional[bool] = None
-        self._average_model: Optional["pl.LightningModule"] = None
-        self._initialized = False
-        self._swa_scheduler: Optional[LRScheduler] = None
-        self._scheduler_state: Optional[Dict] = None
-        self._init_n_averaged = 0
-        self._latest_update_epoch = -1
-        self.momenta: Dict[nn.modules.batchnorm._BatchNorm, Optional[float]] = {}
-        self._max_epochs: int
+        # Add this so we can use iterative datapipe
         self._train_batches = 0
-
-    @property
-    def swa_start(self) -> int:
-        assert isinstance(self._swa_epoch_start, int)
-        return max(self._swa_epoch_start - 1, 0)  # 0-based
-
-    @property
-    def swa_end(self) -> int:
-        return self._max_epochs - 1  # 0-based
-
-    @staticmethod
-    def pl_module_contains_batch_norm(pl_module: "pl.LightningModule") -> bool:
-        return any(
-            isinstance(module, nn.modules.batchnorm._BatchNorm) for module in pl_module.modules()
+
+        super().__init__(
+            swa_lrs, swa_epoch_start, annealing_epochs,
+            annealing_strategy, avg_fn, device,
         )
 
-    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
-        if isinstance(trainer.strategy, (FSDPStrategy, DeepSpeedStrategy)):
-            raise MisconfigurationException("SWA does not currently support sharded models.")
-
-        # copy the model before moving it to accelerator device.
-        self._average_model = deepcopy(pl_module)
-
-    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
-        if len(trainer.optimizers) != 1:
-            raise MisconfigurationException("SWA currently works with 1 `optimizer`.")
-
-        if len(trainer.lr_scheduler_configs) > 1:
-            raise MisconfigurationException(
-                "SWA currently not supported for more than 1 `lr_scheduler`."
-            )
-
-        assert trainer.max_epochs is not None
-        if isinstance(self._swa_epoch_start, float):
-            self._swa_epoch_start = int(trainer.max_epochs * self._swa_epoch_start)
-
-        self._model_contains_batch_norm = self.pl_module_contains_batch_norm(pl_module)
-
-        self._max_epochs = trainer.max_epochs
-        if self._model_contains_batch_norm:
-            # virtually increase max_epochs to perform batch norm update on latest epoch.
-            assert trainer.fit_loop.max_epochs is not None
-            trainer.fit_loop.max_epochs += 1
-
-        if self._scheduler_state is not None:
-            self._clear_schedulers(trainer)
-
     def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        """Run at start of each train epoch"""
         if (not self._initialized) and (self.swa_start <= trainer.current_epoch <= self.swa_end):
             self._initialized = True
 
@@ -190,7 +117,7 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
             if isinstance(self._swa_lrs, float):
                 self._swa_lrs = [self._swa_lrs] * len(optimizer.param_groups)
 
-            for lr, group in zip(self._swa_lrs, optimizer.param_groups):
+            for lr, group in zip(self._swa_lrs, optimizer.param_groups, strict=True):
                 group["initial_lr"] = lr
 
             assert trainer.max_epochs is not None
@@ -212,7 +139,8 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
                 # as behaviour will be different compared to having checkpoint data.
                 rank_zero_warn(
                     "SWA is initializing after swa_start without any checkpoint data. "
-                    "This may be caused by loading a checkpoint from an older version of PyTorch Lightning."
+                    "This may be caused by loading a checkpoint from an older version of PyTorch"
+                    " Lightning."
                 )
 
             # We assert that there is only one optimizer on fit start
@@ -236,7 +164,9 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
 
         if self.n_averaged is None:
             self.n_averaged = torch.tensor(
-                self._init_n_averaged, dtype=torch.long, device=pl_module.device
+                self._init_n_averaged,
+                dtype=torch.long,
+                device=pl_module.device
             )
 
         if (self.swa_start <= trainer.current_epoch <= self.swa_end) and (
@@ -258,123 +188,18 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
 
             # There is no need to perform either backward or optimizer.step as we are
             # performing only one pass over the train data-loader to compute activation statistics
-            # Therefore, we will virtually increase the number of training batches by 1 and skip backward.
+            # Therefore, we will virtually increase the number of training batches by 1 and
+            # skip backward.
             trainer.fit_loop.max_batches += 1
             trainer.fit_loop._skip_backward = True
             self._accumulate_grad_batches = trainer.accumulate_grad_batches
 
             trainer.accumulate_grad_batches = self._train_batches
 
     def on_train_epoch_end(self, trainer: "pl.Trainer", *args: Any) -> None:
+        """Run at end of each train epoch"""
         if trainer.current_epoch == 0:
             self._train_batches = trainer.global_step
         trainer.fit_loop._skip_backward = False
-
-    def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
-        # the trainer increases the current epoch before this hook is called
-        if self._model_contains_batch_norm and trainer.current_epoch - 1 == self.swa_end + 1:
-            # BatchNorm epoch update. Reset state
-            trainer.accumulate_grad_batches = self._accumulate_grad_batches
-            trainer.fit_loop.max_batches -= 1
-            assert trainer.fit_loop.max_epochs is not None
-            trainer.fit_loop.max_epochs -= 1
-            self.reset_momenta()
-        elif trainer.current_epoch - 1 == self.swa_end:
-            # Last SWA epoch. Transfer weights from average model to pl_module
-            assert self._average_model is not None
-            self.transfer_weights(self._average_model, pl_module)
-
-    @staticmethod
-    def transfer_weights(
-        src_pl_module: "pl.LightningModule", dst_pl_module: "pl.LightningModule"
-    ) -> None:
-        for src_param, dst_param in zip(src_pl_module.parameters(), dst_pl_module.parameters()):
-            dst_param.detach().copy_(src_param.to(dst_param.device))
-
-    def reset_batch_norm_and_save_state(self, pl_module: "pl.LightningModule") -> None:
-        """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L140-L154."""
-        self.momenta = {}
-        for module in pl_module.modules():
-            if not isinstance(module, nn.modules.batchnorm._BatchNorm):
-                continue
-            assert module.running_mean is not None
-            module.running_mean = torch.zeros_like(
-                module.running_mean,
-                device=pl_module.device,
-                dtype=module.running_mean.dtype,
-            )
-            assert module.running_var is not None
-            module.running_var = torch.ones_like(
-                module.running_var,
-                device=pl_module.device,
-                dtype=module.running_var.dtype,
-            )
-            self.momenta[module] = module.momentum
-            module.momentum = None  # type: ignore[assignment]
-            assert module.num_batches_tracked is not None
-            module.num_batches_tracked *= 0
-
-    def reset_momenta(self) -> None:
-        """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L164-L165."""
-        for bn_module in self.momenta:
-            bn_module.momentum = self.momenta[bn_module]  # type: ignore[assignment]
-
-    @staticmethod
-    def update_parameters(
-        average_model: "pl.LightningModule",
-        model: "pl.LightningModule",
-        n_averaged: Tensor,
-        avg_fn: _AVG_FN,
-    ) -> None:
-        """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L104-L112."""
-        for p_swa, p_model in zip(average_model.parameters(), model.parameters()):
-            device = p_swa.device
-            p_swa_ = p_swa.detach()
-            p_model_ = p_model.detach().to(device)
-            src = p_model_ if n_averaged == 0 else avg_fn(p_swa_, p_model_, n_averaged.to(device))
-            p_swa_.copy_(src)
-        n_averaged += 1
-
-    @staticmethod
-    def avg_fn(
-        averaged_model_parameter: Tensor, model_parameter: Tensor, num_averaged: Tensor
-    ) -> Tensor:
-        """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L95-L97."""
-        return averaged_model_parameter + (model_parameter - averaged_model_parameter) / (
-            num_averaged + 1
-        )
-
-    def state_dict(self) -> Dict[str, Any]:
-        return {
-            "n_averaged": 0 if self.n_averaged is None else self.n_averaged.item(),
-            "latest_update_epoch": self._latest_update_epoch,
-            "scheduler_state": None
-            if self._swa_scheduler is None
-            else self._swa_scheduler.state_dict(),
-            "average_model_state": None
-            if self._average_model is None
-            else self._average_model.state_dict(),
-        }
-
-    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
-        self._init_n_averaged = state_dict["n_averaged"]
-        self._latest_update_epoch = state_dict["latest_update_epoch"]
-        self._scheduler_state = state_dict["scheduler_state"]
-        self._load_average_model_state(state_dict["average_model_state"])
-
-    @staticmethod
-    def _clear_schedulers(trainer: "pl.Trainer") -> None:
-        # If we have scheduler state saved, clear the scheduler configs so that we don't try to
-        # load state into the wrong type of schedulers when restoring scheduler checkpoint state.
-        # We'll configure the scheduler and re-load its state in on_train_epoch_start.
-        # Note that this relies on the callback state being restored before the scheduler state is
-        # restored, and doesn't work if restore_checkpoint_after_setup is True, but at the time of
-        # writing that is only True for deepspeed which is already not supported by SWA.
-        # See https://github.com/Lightning-AI/lightning/issues/11665 for background.
-        if trainer.lr_scheduler_configs:
-            assert len(trainer.lr_scheduler_configs) == 1
-            trainer.lr_scheduler_configs.clear()
-
-    def _load_average_model_state(self, model_state: Any) -> None:
-        if self._average_model is None:
-            return
-        self._average_model.load_state_dict(model_state)
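
Not part of the patch: a minimal usage sketch for reviewers showing how the slimmed-down callback plugs into a lightning.pytorch Trainer. TinyModule and the random tensors below are stand-ins invented for illustration; pvnet_summation would use its own model and (possibly iterable) datapipe.

import torch
import lightning.pytorch as pl
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

from pvnet_summation.callbacks import StochasticWeightAveraging


class TinyModule(pl.LightningModule):
    """Minimal stand-in model, only here to show how the callback is wired up."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


if __name__ == "__main__":
    # Random data; the real project would plug in its own dataloader here
    x, y = torch.randn(256, 8), torch.randn(256, 1)
    train_loader = DataLoader(TensorDataset(x, y), batch_size=32)

    trainer = pl.Trainer(
        max_epochs=10,
        callbacks=[
            # Start averaging at 80% of training; anneal to the SWA learning rate over 5 epochs
            StochasticWeightAveraging(swa_lrs=1e-4, swa_epoch_start=0.8, annealing_epochs=5)
        ],
    )
    trainer.fit(TinyModule(), train_dataloaders=train_loader)

After this change only __init__, on_train_epoch_start and on_train_epoch_end remain overridden in the hunks above; argument validation, weight transfer and checkpoint state handling come from the parent lightning.pytorch.callbacks.StochasticWeightAveraging, while the batch count recorded at the end of epoch 0 (self._train_batches) is what allows the batch-norm update pass to work with iterable datapipes that cannot report their length up front.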