From 0ba00122e0fe7514d2d1d473be04229505255a33 Mon Sep 17 00:00:00 2001 From: James Fulton Date: Thu, 21 Sep 2023 11:35:40 +0000 Subject: [PATCH 1/4] add stochastic weight averaging and rescale inputs --- configs/callbacks/default.yaml | 27 +-- configs/model/default.yaml | 4 +- configs/trainer/default.yaml | 2 +- pvnet_summation/callbacks.py | 356 ++++++++++++++++++++++++++++++++ pvnet_summation/models/model.py | 14 +- pvnet_summation/training.py | 3 + 6 files changed, 377 insertions(+), 29 deletions(-) create mode 100644 pvnet_summation/callbacks.py diff --git a/configs/callbacks/default.yaml b/configs/callbacks/default.yaml index d3e6224..2a5abeb 100644 --- a/configs/callbacks/default.yaml +++ b/configs/callbacks/default.yaml @@ -1,21 +1,3 @@ -#pretrain_early_stopping: -# _target_: pvnet.callbacks.PretrainEarlyStopping -# monitor: "MAE/val" # name of the logged metric which determines when model is improving -# mode: "min" # can be "max" or "min" -# patience: 10 # how many epochs (or val check periods) of not improving until training stops -# min_delta: 0.001 # minimum change in the monitored metric needed to qualify as an improvement - -#pretrain_encoder_freezing: -# _target_: pvnet.callbacks.PretrainFreeze - -early_stopping: - _target_: pvnet.callbacks.MainEarlyStopping - # name of the logged metric which determines when model is improving - monitor: "${resolve_monitor_loss:${model.output_quantiles}}" - mode: "min" # can be "max" or "min" - patience: 10 # how many epochs (or val check periods) of not improving until training stops - min_delta: 0 # minimum change in the monitored metric needed to qualify as an improvement - learning_rate_monitor: _target_: lightning.pytorch.callbacks.LearningRateMonitor logging_interval: "epoch" @@ -37,6 +19,9 @@ model_checkpoint: dirpath: "checkpoints/pvnet_summation/${model_name}" #${..model_name} auto_insert_metric_name: False save_on_train_epoch_end: False -#device_stats_monitor: -# _target_: lightning.pytorch.callbacks.DeviceStatsMonitor -# cpu_stats: True + +stochastic_weight_averaging: + _target_: pvnet_summation.callbacks.StochasticWeightAveraging + swa_lrs: 0.0000001 + swa_epoch_start: 0.8 + annealing_epochs: 5 \ No newline at end of file diff --git a/configs/model/default.yaml b/configs/model/default.yaml index a2e8beb..0c05838 100644 --- a/configs/model/default.yaml +++ b/configs/model/default.yaml @@ -26,6 +26,6 @@ optimizer: lr: 0.0001 weight_decay: 0.25 amsgrad: True - patience: 5 + patience: 20 factor: 0.1 - threshold: 0.002 + threshold: 0.00 diff --git a/configs/trainer/default.yaml b/configs/trainer/default.yaml index dd0d7c4..3c75e63 100644 --- a/configs/trainer/default.yaml +++ b/configs/trainer/default.yaml @@ -5,7 +5,7 @@ accelerator: gpu devices: auto min_epochs: null -max_epochs: null +max_epochs: 100 reload_dataloaders_every_n_epochs: 0 num_sanity_val_steps: 8 fast_dev_run: false diff --git a/pvnet_summation/callbacks.py b/pvnet_summation/callbacks.py new file mode 100644 index 0000000..dc34461 --- /dev/null +++ b/pvnet_summation/callbacks.py @@ -0,0 +1,356 @@ +# Copyright The Lightning AI team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +r"""Stochastic Weight Averaging Callback ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^""" +from copy import deepcopy +from typing import Any, Callable, cast, Dict, List, Optional, Union + +import torch +from torch import nn, Tensor +from torch.optim.swa_utils import SWALR + +import lightning.pytorch as pl +from lightning.fabric.utilities.types import LRScheduler +from lightning.pytorch.callbacks.callback import Callback +from lightning.pytorch.strategies import DeepSpeedStrategy +from lightning.pytorch.strategies.fsdp import FSDPStrategy +from lightning.pytorch.utilities.exceptions import MisconfigurationException +from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn +from lightning.pytorch.utilities.types import LRSchedulerConfig +from lightning.pytorch.callbacks import StochasticWeightAveraging + +_AVG_FN = Callable[[Tensor, Tensor, Tensor], Tensor] + + +class StochasticWeightAveraging(StochasticWeightAveraging): + def __init__( + self, + swa_lrs: Union[float, List[float]], + swa_epoch_start: Union[int, float] = 0.8, + annealing_epochs: int = 10, + annealing_strategy: str = "cos", + avg_fn: Optional[_AVG_FN] = None, + device: Optional[Union[torch.device, str]] = torch.device("cpu"), + ): + r"""Implements the Stochastic Weight Averaging (SWA) Callback to average a model. + + Stochastic Weight Averaging was proposed in ``Averaging Weights Leads to + Wider Optima and Better Generalization`` by Pavel Izmailov, Dmitrii + Podoprikhin, Timur Garipov, Dmitry Vetrov and Andrew Gordon Wilson + (UAI 2018). + + This documentation is highly inspired by PyTorch's work on SWA. + The callback arguments follow the scheme defined in PyTorch's ``swa_utils`` package. + + For a SWA explanation, please take a look + `here `_. + + .. warning:: This is an :ref:`experimental ` feature. + + .. warning:: ``StochasticWeightAveraging`` is currently not supported for multiple optimizers/schedulers. + + .. warning:: ``StochasticWeightAveraging`` is currently only supported on every epoch. + + See also how to :ref:`enable it directly on the Trainer ` + + Arguments: + + swa_lrs: The SWA learning rate to use: + + - ``float``. Use this value for all parameter groups of the optimizer. + - ``List[float]``. A list values for each parameter group of the optimizer. + + swa_epoch_start: If provided as int, the procedure will start from + the ``swa_epoch_start``-th epoch. If provided as float between 0 and 1, + the procedure will start from ``int(swa_epoch_start * max_epochs)`` epoch + + annealing_epochs: number of epochs in the annealing phase (default: 10) + + annealing_strategy: Specifies the annealing strategy (default: "cos"): + + - ``"cos"``. For cosine annealing. + - ``"linear"`` For linear annealing + + avg_fn: the averaging function used to update the parameters; + the function must take in the current value of the + :class:`AveragedModel` parameter, the current value of :attr:`model` + parameter and the number of models already averaged; if None, + equally weighted average is used (default: ``None``) + + device: if provided, the averaged model will be stored on the ``device``. 
+ When None is provided, it will infer the `device` from ``pl_module``. + (default: ``"cpu"``) + + """ + + err_msg = "swa_epoch_start should be a >0 integer or a float between 0 and 1." + if isinstance(swa_epoch_start, int) and swa_epoch_start < 1: + raise MisconfigurationException(err_msg) + if isinstance(swa_epoch_start, float) and not (0 <= swa_epoch_start <= 1): + raise MisconfigurationException(err_msg) + + wrong_type = not isinstance(swa_lrs, (float, list)) + wrong_float = isinstance(swa_lrs, float) and swa_lrs <= 0 + wrong_list = isinstance(swa_lrs, list) and not all(lr > 0 and isinstance(lr, float) for lr in swa_lrs) + if wrong_type or wrong_float or wrong_list: + raise MisconfigurationException("The `swa_lrs` should a positive float, or a list of positive floats") + + if avg_fn is not None and not callable(avg_fn): + raise MisconfigurationException("The `avg_fn` should be callable.") + + if device is not None and not isinstance(device, (torch.device, str)): + raise MisconfigurationException(f"device is expected to be a torch.device or a str. Found {device}") + + self.n_averaged: Optional[Tensor] = None + self._swa_epoch_start = swa_epoch_start + self._swa_lrs = swa_lrs + self._annealing_epochs = annealing_epochs + self._annealing_strategy = annealing_strategy + self._avg_fn = avg_fn or self.avg_fn + self._device = device + self._model_contains_batch_norm: Optional[bool] = None + self._average_model: Optional["pl.LightningModule"] = None + self._initialized = False + self._swa_scheduler: Optional[LRScheduler] = None + self._scheduler_state: Optional[Dict] = None + self._init_n_averaged = 0 + self._latest_update_epoch = -1 + self.momenta: Dict[nn.modules.batchnorm._BatchNorm, Optional[float]] = {} + self._max_epochs: int + self._train_batches = 0 + + @property + def swa_start(self) -> int: + assert isinstance(self._swa_epoch_start, int) + return max(self._swa_epoch_start - 1, 0) # 0-based + + @property + def swa_end(self) -> int: + return self._max_epochs - 1 # 0-based + + @staticmethod + def pl_module_contains_batch_norm(pl_module: "pl.LightningModule") -> bool: + return any(isinstance(module, nn.modules.batchnorm._BatchNorm) for module in pl_module.modules()) + + def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None: + if isinstance(trainer.strategy, (FSDPStrategy, DeepSpeedStrategy)): + raise MisconfigurationException("SWA does not currently support sharded models.") + + # copy the model before moving it to accelerator device. + self._average_model = deepcopy(pl_module) + + def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + if len(trainer.optimizers) != 1: + raise MisconfigurationException("SWA currently works with 1 `optimizer`.") + + if len(trainer.lr_scheduler_configs) > 1: + raise MisconfigurationException("SWA currently not supported for more than 1 `lr_scheduler`.") + + assert trainer.max_epochs is not None + if isinstance(self._swa_epoch_start, float): + self._swa_epoch_start = int(trainer.max_epochs * self._swa_epoch_start) + + self._model_contains_batch_norm = self.pl_module_contains_batch_norm(pl_module) + + self._max_epochs = trainer.max_epochs + if self._model_contains_batch_norm: + # virtually increase max_epochs to perform batch norm update on latest epoch. 
+ assert trainer.fit_loop.max_epochs is not None + trainer.fit_loop.max_epochs += 1 + + if self._scheduler_state is not None: + self._clear_schedulers(trainer) + + def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + if (not self._initialized) and (self.swa_start <= trainer.current_epoch <= self.swa_end): + self._initialized = True + + # move average model to request device. + assert self._average_model is not None + self._average_model = self._average_model.to(self._device or pl_module.device) + + optimizer = trainer.optimizers[0] + if isinstance(self._swa_lrs, float): + self._swa_lrs = [self._swa_lrs] * len(optimizer.param_groups) + + for lr, group in zip(self._swa_lrs, optimizer.param_groups): + group["initial_lr"] = lr + + assert trainer.max_epochs is not None + self._swa_scheduler = cast( + LRScheduler, + SWALR( + optimizer, + swa_lr=self._swa_lrs, # type: ignore[arg-type] + anneal_epochs=self._annealing_epochs, + anneal_strategy=self._annealing_strategy, + last_epoch=trainer.max_epochs if self._annealing_strategy == "cos" else -1, + ), + ) + if self._scheduler_state is not None: + # Restore scheduler state from checkpoint + self._swa_scheduler.load_state_dict(self._scheduler_state) + elif trainer.current_epoch != self.swa_start: + # Log a warning if we're initializing after start without any checkpoint data, + # as behaviour will be different compared to having checkpoint data. + rank_zero_warn( + "SWA is initializing after swa_start without any checkpoint data. " + "This may be caused by loading a checkpoint from an older version of PyTorch Lightning." + ) + + # We assert that there is only one optimizer on fit start + default_scheduler_cfg = LRSchedulerConfig(self._swa_scheduler) + assert default_scheduler_cfg.interval == "epoch" + assert default_scheduler_cfg.frequency == 1 + + if trainer.lr_scheduler_configs: + scheduler_cfg = trainer.lr_scheduler_configs[0] + if scheduler_cfg.interval != "epoch" or scheduler_cfg.frequency != 1: + rank_zero_warn(f"SWA is currently only supported every epoch. Found {scheduler_cfg}") + rank_zero_info( + f"Swapping scheduler `{scheduler_cfg.scheduler.__class__.__name__}`" + f" for `{self._swa_scheduler.__class__.__name__}`" + ) + trainer.lr_scheduler_configs[0] = default_scheduler_cfg + else: + trainer.lr_scheduler_configs.append(default_scheduler_cfg) + + if self.n_averaged is None: + self.n_averaged = torch.tensor(self._init_n_averaged, dtype=torch.long, device=pl_module.device) + + if (self.swa_start <= trainer.current_epoch <= self.swa_end) and ( + trainer.current_epoch > self._latest_update_epoch + ): + assert self.n_averaged is not None + assert self._average_model is not None + self.update_parameters(self._average_model, pl_module, self.n_averaged, self._avg_fn) + self._latest_update_epoch = trainer.current_epoch + + # Note: No > here in case the callback is saved with the model and training continues + if trainer.current_epoch == self.swa_end + 1: + # Transfer weights from average model to pl_module + assert self._average_model is not None + self.transfer_weights(self._average_model, pl_module) + + # Reset BatchNorm for update + self.reset_batch_norm_and_save_state(pl_module) + + # There is no need to perform either backward or optimizer.step as we are + # performing only one pass over the train data-loader to compute activation statistics + # Therefore, we will virtually increase the number of training batches by 1 and skip backward. 
+            trainer.fit_loop.max_batches += 1
+            trainer.fit_loop._skip_backward = True
+            self._accumulate_grad_batches = trainer.accumulate_grad_batches
+            trainer.accumulate_grad_batches = self._train_batches
+
+    def on_train_epoch_end(self, trainer: "pl.Trainer", *args: Any) -> None:
+        if trainer.current_epoch==0:
+            self._train_batches = trainer.global_step
+        trainer.fit_loop._skip_backward = False
+
+    def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        # the trainer increases the current epoch before this hook is called
+        if self._model_contains_batch_norm and trainer.current_epoch - 1 == self.swa_end + 1:
+            # BatchNorm epoch update. Reset state
+            trainer.accumulate_grad_batches = self._accumulate_grad_batches
+            trainer.fit_loop.max_batches -= 1
+            assert trainer.fit_loop.max_epochs is not None
+            trainer.fit_loop.max_epochs -= 1
+            self.reset_momenta()
+        elif trainer.current_epoch - 1 == self.swa_end:
+            # Last SWA epoch. Transfer weights from average model to pl_module
+            assert self._average_model is not None
+            self.transfer_weights(self._average_model, pl_module)
+
+    @staticmethod
+    def transfer_weights(src_pl_module: "pl.LightningModule", dst_pl_module: "pl.LightningModule") -> None:
+        for src_param, dst_param in zip(src_pl_module.parameters(), dst_pl_module.parameters()):
+            dst_param.detach().copy_(src_param.to(dst_param.device))
+
+    def reset_batch_norm_and_save_state(self, pl_module: "pl.LightningModule") -> None:
+        """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L140-L154."""
+        self.momenta = {}
+        for module in pl_module.modules():
+            if not isinstance(module, nn.modules.batchnorm._BatchNorm):
+                continue
+            assert module.running_mean is not None
+            module.running_mean = torch.zeros_like(
+                module.running_mean,
+                device=pl_module.device,
+                dtype=module.running_mean.dtype,
+            )
+            assert module.running_var is not None
+            module.running_var = torch.ones_like(
+                module.running_var,
+                device=pl_module.device,
+                dtype=module.running_var.dtype,
+            )
+            self.momenta[module] = module.momentum
+            module.momentum = None  # type: ignore[assignment]
+            assert module.num_batches_tracked is not None
+            module.num_batches_tracked *= 0
+
+    def reset_momenta(self) -> None:
+        """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L164-L165."""
+        for bn_module in self.momenta:
+            bn_module.momentum = self.momenta[bn_module]  # type: ignore[assignment]
+
+    @staticmethod
+    def update_parameters(
+        average_model: "pl.LightningModule", model: "pl.LightningModule", n_averaged: Tensor, avg_fn: _AVG_FN
+    ) -> None:
+        """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L104-L112."""
+        for p_swa, p_model in zip(average_model.parameters(), model.parameters()):
+            device = p_swa.device
+            p_swa_ = p_swa.detach()
+            p_model_ = p_model.detach().to(device)
+            src = p_model_ if n_averaged == 0 else avg_fn(p_swa_, p_model_, n_averaged.to(device))
+            p_swa_.copy_(src)
+        n_averaged += 1
+
+    @staticmethod
+    def avg_fn(averaged_model_parameter: Tensor, model_parameter: Tensor, num_averaged: Tensor) -> Tensor:
+        """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L95-L97."""
+        return averaged_model_parameter + (model_parameter - averaged_model_parameter) / (num_averaged + 1)
+
+    def state_dict(self) -> Dict[str, Any]:
+        return {
+            "n_averaged": 0 if self.n_averaged is None else self.n_averaged.item(),
+
"latest_update_epoch": self._latest_update_epoch, + "scheduler_state": None if self._swa_scheduler is None else self._swa_scheduler.state_dict(), + "average_model_state": None if self._average_model is None else self._average_model.state_dict(), + } + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + self._init_n_averaged = state_dict["n_averaged"] + self._latest_update_epoch = state_dict["latest_update_epoch"] + self._scheduler_state = state_dict["scheduler_state"] + self._load_average_model_state(state_dict["average_model_state"]) + + @staticmethod + def _clear_schedulers(trainer: "pl.Trainer") -> None: + # If we have scheduler state saved, clear the scheduler configs so that we don't try to + # load state into the wrong type of schedulers when restoring scheduler checkpoint state. + # We'll configure the scheduler and re-load its state in on_train_epoch_start. + # Note that this relies on the callback state being restored before the scheduler state is + # restored, and doesn't work if restore_checkpoint_after_setup is True, but at the time of + # writing that is only True for deepspeed which is already not supported by SWA. + # See https://github.com/Lightning-AI/lightning/issues/11665 for background. + if trainer.lr_scheduler_configs: + assert len(trainer.lr_scheduler_configs) == 1 + trainer.lr_scheduler_configs.clear() + + def _load_average_model_state(self, model_state: Any) -> None: + if self._average_model is None: + return + self._average_model.load_state_dict(model_state) \ No newline at end of file diff --git a/pvnet_summation/models/model.py b/pvnet_summation/models/model.py index 635dff4..ca355f1 100644 --- a/pvnet_summation/models/model.py +++ b/pvnet_summation/models/model.py @@ -28,7 +28,7 @@ def __init__( output_quantiles: Optional[list[float]] = None, output_network: AbstractLinearNetwork = DefaultFCNet, output_network_kwargs: Optional[dict] = None, - scale_pvnet_outputs: bool = False, + relative_scale_pvnet_outputs: bool = False, predict_difference_from_sum: bool = False, optimizer: AbstractOptimizer = _default_optimizer, ): @@ -42,7 +42,8 @@ def __init__( output_network: Pytorch Module class used to combine the 1D features to produce the forecast. output_network_kwargs: Dictionary of optional kwargs for the `output_network` module. - scale_pvnet_outputs: If true, the PVNet predictions are scaled by the capacities. + relative_scale_pvnet_outputs: If true, the PVNet predictions are scaled by a factor + which is proportional to their capacities. predict_difference_from_sum: Whether to use the sum of GSPs as an estimate for the national sum and train the model to correct this estimate. Otherwise the model tries to learn the national sum from the PVNet outputs directly. @@ -51,7 +52,7 @@ def __init__( super().__init__(model_name, model_version, optimizer, output_quantiles) - self.scale_pvnet_outputs = scale_pvnet_outputs + self.relative_scale_pvnet_outputs = relative_scale_pvnet_outputs self.predict_difference_from_sum = predict_difference_from_sum if output_network_kwargs is None: @@ -78,12 +79,15 @@ def forward(self, x): if "pvnet_outputs" not in x: x["pvnet_outputs"] = self.predict_pvnet_batch(x["pvnet_inputs"]) - if self.scale_pvnet_outputs: + if self.relative_scale_pvnet_outputs: if self.pvnet_model.use_quantile_regression: eff_cap = x["effective_capacity"].unsqueeze(-1) else: eff_cap = x["effective_capacity"] - x_in = x["pvnet_outputs"] * eff_cap + + # Multiply by (effective capacity / 100) since the capacities are roughly of magnitude + # of 100 MW. 
We still want the inputs to the network to be order of magnitude 1. + x_in = x["pvnet_outputs"] * (eff_cap/100) else: x_in = x["pvnet_outputs"] diff --git a/pvnet_summation/training.py b/pvnet_summation/training.py index d6d313c..315d84c 100644 --- a/pvnet_summation/training.py +++ b/pvnet_summation/training.py @@ -166,6 +166,9 @@ def train(config: DictConfig) -> Optional[float]: # Train the model completely trainer.fit(model=model, datamodule=datamodule) + + # Validate after end - useful if using stochastic weight averaging + trainer.validate(model=model, datamodule=datamodule) # Make sure everything closed properly log.info("Finalizing!") From e99bb9d0e8ff87607b913006c07423a723c678ac Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 21 Sep 2023 11:39:44 +0000 Subject: [PATCH 2/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- configs/callbacks/default.yaml | 2 +- pvnet_summation/callbacks.py | 70 ++++++++++++++++++++++----------- pvnet_summation/models/model.py | 6 +-- pvnet_summation/training.py | 2 +- 4 files changed, 52 insertions(+), 28 deletions(-) diff --git a/configs/callbacks/default.yaml b/configs/callbacks/default.yaml index 2a5abeb..3f147ec 100644 --- a/configs/callbacks/default.yaml +++ b/configs/callbacks/default.yaml @@ -24,4 +24,4 @@ stochastic_weight_averaging: _target_: pvnet_summation.callbacks.StochasticWeightAveraging swa_lrs: 0.0000001 swa_epoch_start: 0.8 - annealing_epochs: 5 \ No newline at end of file + annealing_epochs: 5 diff --git a/pvnet_summation/callbacks.py b/pvnet_summation/callbacks.py index dc34461..e9ed5b7 100644 --- a/pvnet_summation/callbacks.py +++ b/pvnet_summation/callbacks.py @@ -13,21 +13,19 @@ # limitations under the License. r"""Stochastic Weight Averaging Callback ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^""" from copy import deepcopy -from typing import Any, Callable, cast, Dict, List, Optional, Union - -import torch -from torch import nn, Tensor -from torch.optim.swa_utils import SWALR +from typing import Any, Callable, Dict, List, Optional, Union, cast import lightning.pytorch as pl +import torch from lightning.fabric.utilities.types import LRScheduler -from lightning.pytorch.callbacks.callback import Callback +from lightning.pytorch.callbacks import StochasticWeightAveraging from lightning.pytorch.strategies import DeepSpeedStrategy from lightning.pytorch.strategies.fsdp import FSDPStrategy from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn from lightning.pytorch.utilities.types import LRSchedulerConfig -from lightning.pytorch.callbacks import StochasticWeightAveraging +from torch import Tensor, nn +from torch.optim.swa_utils import SWALR _AVG_FN = Callable[[Tensor, Tensor, Tensor], Tensor] @@ -64,7 +62,6 @@ def __init__( See also how to :ref:`enable it directly on the Trainer ` Arguments: - swa_lrs: The SWA learning rate to use: - ``float``. Use this value for all parameter groups of the optimizer. 
@@ -101,15 +98,21 @@ def __init__( wrong_type = not isinstance(swa_lrs, (float, list)) wrong_float = isinstance(swa_lrs, float) and swa_lrs <= 0 - wrong_list = isinstance(swa_lrs, list) and not all(lr > 0 and isinstance(lr, float) for lr in swa_lrs) + wrong_list = isinstance(swa_lrs, list) and not all( + lr > 0 and isinstance(lr, float) for lr in swa_lrs + ) if wrong_type or wrong_float or wrong_list: - raise MisconfigurationException("The `swa_lrs` should a positive float, or a list of positive floats") + raise MisconfigurationException( + "The `swa_lrs` should a positive float, or a list of positive floats" + ) if avg_fn is not None and not callable(avg_fn): raise MisconfigurationException("The `avg_fn` should be callable.") if device is not None and not isinstance(device, (torch.device, str)): - raise MisconfigurationException(f"device is expected to be a torch.device or a str. Found {device}") + raise MisconfigurationException( + f"device is expected to be a torch.device or a str. Found {device}" + ) self.n_averaged: Optional[Tensor] = None self._swa_epoch_start = swa_epoch_start @@ -140,7 +143,9 @@ def swa_end(self) -> int: @staticmethod def pl_module_contains_batch_norm(pl_module: "pl.LightningModule") -> bool: - return any(isinstance(module, nn.modules.batchnorm._BatchNorm) for module in pl_module.modules()) + return any( + isinstance(module, nn.modules.batchnorm._BatchNorm) for module in pl_module.modules() + ) def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None: if isinstance(trainer.strategy, (FSDPStrategy, DeepSpeedStrategy)): @@ -154,7 +159,9 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") - raise MisconfigurationException("SWA currently works with 1 `optimizer`.") if len(trainer.lr_scheduler_configs) > 1: - raise MisconfigurationException("SWA currently not supported for more than 1 `lr_scheduler`.") + raise MisconfigurationException( + "SWA currently not supported for more than 1 `lr_scheduler`." + ) assert trainer.max_epochs is not None if isinstance(self._swa_epoch_start, float): @@ -216,7 +223,9 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo if trainer.lr_scheduler_configs: scheduler_cfg = trainer.lr_scheduler_configs[0] if scheduler_cfg.interval != "epoch" or scheduler_cfg.frequency != 1: - rank_zero_warn(f"SWA is currently only supported every epoch. Found {scheduler_cfg}") + rank_zero_warn( + f"SWA is currently only supported every epoch. 
Found {scheduler_cfg}" + ) rank_zero_info( f"Swapping scheduler `{scheduler_cfg.scheduler.__class__.__name__}`" f" for `{self._swa_scheduler.__class__.__name__}`" @@ -226,7 +235,9 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo trainer.lr_scheduler_configs.append(default_scheduler_cfg) if self.n_averaged is None: - self.n_averaged = torch.tensor(self._init_n_averaged, dtype=torch.long, device=pl_module.device) + self.n_averaged = torch.tensor( + self._init_n_averaged, dtype=torch.long, device=pl_module.device + ) if (self.swa_start <= trainer.current_epoch <= self.swa_end) and ( trainer.current_epoch > self._latest_update_epoch @@ -254,7 +265,7 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo trainer.accumulate_grad_batches = self._train_batches def on_train_epoch_end(self, trainer: "pl.Trainer", *args: Any) -> None: - if trainer.current_epoch==0: + if trainer.current_epoch == 0: self._train_batches = trainer.global_step trainer.fit_loop._skip_backward = False @@ -273,7 +284,9 @@ def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") - self.transfer_weights(self._average_model, pl_module) @staticmethod - def transfer_weights(src_pl_module: "pl.LightningModule", dst_pl_module: "pl.LightningModule") -> None: + def transfer_weights( + src_pl_module: "pl.LightningModule", dst_pl_module: "pl.LightningModule" + ) -> None: for src_param, dst_param in zip(src_pl_module.parameters(), dst_pl_module.parameters()): dst_param.detach().copy_(src_param.to(dst_param.device)) @@ -307,7 +320,10 @@ def reset_momenta(self) -> None: @staticmethod def update_parameters( - average_model: "pl.LightningModule", model: "pl.LightningModule", n_averaged: Tensor, avg_fn: _AVG_FN + average_model: "pl.LightningModule", + model: "pl.LightningModule", + n_averaged: Tensor, + avg_fn: _AVG_FN, ) -> None: """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L104-L112.""" for p_swa, p_model in zip(average_model.parameters(), model.parameters()): @@ -319,16 +335,24 @@ def update_parameters( n_averaged += 1 @staticmethod - def avg_fn(averaged_model_parameter: Tensor, model_parameter: Tensor, num_averaged: Tensor) -> Tensor: + def avg_fn( + averaged_model_parameter: Tensor, model_parameter: Tensor, num_averaged: Tensor + ) -> Tensor: """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L95-L97.""" - return averaged_model_parameter + (model_parameter - averaged_model_parameter) / (num_averaged + 1) + return averaged_model_parameter + (model_parameter - averaged_model_parameter) / ( + num_averaged + 1 + ) def state_dict(self) -> Dict[str, Any]: return { "n_averaged": 0 if self.n_averaged is None else self.n_averaged.item(), "latest_update_epoch": self._latest_update_epoch, - "scheduler_state": None if self._swa_scheduler is None else self._swa_scheduler.state_dict(), - "average_model_state": None if self._average_model is None else self._average_model.state_dict(), + "scheduler_state": None + if self._swa_scheduler is None + else self._swa_scheduler.state_dict(), + "average_model_state": None + if self._average_model is None + else self._average_model.state_dict(), } def load_state_dict(self, state_dict: Dict[str, Any]) -> None: @@ -353,4 +377,4 @@ def _clear_schedulers(trainer: "pl.Trainer") -> None: def _load_average_model_state(self, model_state: Any) -> None: if self._average_model is None: return - self._average_model.load_state_dict(model_state) \ No newline at 
end of file + self._average_model.load_state_dict(model_state) diff --git a/pvnet_summation/models/model.py b/pvnet_summation/models/model.py index ca355f1..fb0ad79 100644 --- a/pvnet_summation/models/model.py +++ b/pvnet_summation/models/model.py @@ -84,10 +84,10 @@ def forward(self, x): eff_cap = x["effective_capacity"].unsqueeze(-1) else: eff_cap = x["effective_capacity"] - - # Multiply by (effective capacity / 100) since the capacities are roughly of magnitude + + # Multiply by (effective capacity / 100) since the capacities are roughly of magnitude # of 100 MW. We still want the inputs to the network to be order of magnitude 1. - x_in = x["pvnet_outputs"] * (eff_cap/100) + x_in = x["pvnet_outputs"] * (eff_cap / 100) else: x_in = x["pvnet_outputs"] diff --git a/pvnet_summation/training.py b/pvnet_summation/training.py index 315d84c..73bcd2a 100644 --- a/pvnet_summation/training.py +++ b/pvnet_summation/training.py @@ -166,7 +166,7 @@ def train(config: DictConfig) -> Optional[float]: # Train the model completely trainer.fit(model=model, datamodule=datamodule) - + # Validate after end - useful if using stochastic weight averaging trainer.validate(model=model, datamodule=datamodule) From 891cec232cee2b17830c77601e3533ea2aed4c87 Mon Sep 17 00:00:00 2001 From: James Fulton Date: Thu, 21 Sep 2023 12:20:12 +0000 Subject: [PATCH 3/4] tidy SWA callback --- pvnet_summation/callbacks.py | 251 ++++++----------------------------- 1 file changed, 38 insertions(+), 213 deletions(-) diff --git a/pvnet_summation/callbacks.py b/pvnet_summation/callbacks.py index e9ed5b7..31a672c 100644 --- a/pvnet_summation/callbacks.py +++ b/pvnet_summation/callbacks.py @@ -11,26 +11,33 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-r"""Stochastic Weight Averaging Callback ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^""" -from copy import deepcopy -from typing import Any, Callable, Dict, List, Optional, Union, cast +"""Stochastic weight averaging callback + +Modified from: + lightning.pytorch.callbacks.StochasticWeightAveraging +""" + +from typing import Any, Callable, cast, List, Optional, Union -import lightning.pytorch as pl import torch +from torch import Tensor +from torch.optim.swa_utils import SWALR + +import lightning.pytorch as pl from lightning.fabric.utilities.types import LRScheduler -from lightning.pytorch.callbacks import StochasticWeightAveraging -from lightning.pytorch.strategies import DeepSpeedStrategy -from lightning.pytorch.strategies.fsdp import FSDPStrategy -from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn from lightning.pytorch.utilities.types import LRSchedulerConfig -from torch import Tensor, nn -from torch.optim.swa_utils import SWALR +from lightning.pytorch.callbacks import StochasticWeightAveraging as _StochasticWeightAveraging _AVG_FN = Callable[[Tensor, Tensor, Tensor], Tensor] +_DEFAULT_DEVICE = torch.device("cpu") +class StochasticWeightAveraging(_StochasticWeightAveraging): + """Stochastic weight averaging callback -class StochasticWeightAveraging(StochasticWeightAveraging): + Modified from: + lightning.pytorch.callbacks.StochasticWeightAveraging + """ def __init__( self, swa_lrs: Union[float, List[float]], @@ -38,7 +45,7 @@ def __init__( annealing_epochs: int = 10, annealing_strategy: str = "cos", avg_fn: Optional[_AVG_FN] = None, - device: Optional[Union[torch.device, str]] = torch.device("cpu"), + device: Optional[Union[torch.device, str]] = _DEFAULT_DEVICE, ): r"""Implements the Stochastic Weight Averaging (SWA) Callback to average a model. @@ -55,13 +62,13 @@ def __init__( .. warning:: This is an :ref:`experimental ` feature. - .. warning:: ``StochasticWeightAveraging`` is currently not supported for multiple optimizers/schedulers. + .. warning:: ``StochasticWeightAveraging`` is currently not supported for multiple + optimizers/schedulers. .. warning:: ``StochasticWeightAveraging`` is currently only supported on every epoch. - See also how to :ref:`enable it directly on the Trainer ` - Arguments: + swa_lrs: The SWA learning rate to use: - ``float``. Use this value for all parameter groups of the optimizer. @@ -89,96 +96,16 @@ def __init__( (default: ``"cpu"``) """ - - err_msg = "swa_epoch_start should be a >0 integer or a float between 0 and 1." - if isinstance(swa_epoch_start, int) and swa_epoch_start < 1: - raise MisconfigurationException(err_msg) - if isinstance(swa_epoch_start, float) and not (0 <= swa_epoch_start <= 1): - raise MisconfigurationException(err_msg) - - wrong_type = not isinstance(swa_lrs, (float, list)) - wrong_float = isinstance(swa_lrs, float) and swa_lrs <= 0 - wrong_list = isinstance(swa_lrs, list) and not all( - lr > 0 and isinstance(lr, float) for lr in swa_lrs - ) - if wrong_type or wrong_float or wrong_list: - raise MisconfigurationException( - "The `swa_lrs` should a positive float, or a list of positive floats" - ) - - if avg_fn is not None and not callable(avg_fn): - raise MisconfigurationException("The `avg_fn` should be callable.") - - if device is not None and not isinstance(device, (torch.device, str)): - raise MisconfigurationException( - f"device is expected to be a torch.device or a str. 
Found {device}" - ) - - self.n_averaged: Optional[Tensor] = None - self._swa_epoch_start = swa_epoch_start - self._swa_lrs = swa_lrs - self._annealing_epochs = annealing_epochs - self._annealing_strategy = annealing_strategy - self._avg_fn = avg_fn or self.avg_fn - self._device = device - self._model_contains_batch_norm: Optional[bool] = None - self._average_model: Optional["pl.LightningModule"] = None - self._initialized = False - self._swa_scheduler: Optional[LRScheduler] = None - self._scheduler_state: Optional[Dict] = None - self._init_n_averaged = 0 - self._latest_update_epoch = -1 - self.momenta: Dict[nn.modules.batchnorm._BatchNorm, Optional[float]] = {} - self._max_epochs: int + # Add this so we can use iterative datapipe self._train_batches = 0 - - @property - def swa_start(self) -> int: - assert isinstance(self._swa_epoch_start, int) - return max(self._swa_epoch_start - 1, 0) # 0-based - - @property - def swa_end(self) -> int: - return self._max_epochs - 1 # 0-based - - @staticmethod - def pl_module_contains_batch_norm(pl_module: "pl.LightningModule") -> bool: - return any( - isinstance(module, nn.modules.batchnorm._BatchNorm) for module in pl_module.modules() + + super()._init_( + swa_lrs, swa_epoch_start, annealing_epochs, + annealing_strategy, avg_fn, device, ) - def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None: - if isinstance(trainer.strategy, (FSDPStrategy, DeepSpeedStrategy)): - raise MisconfigurationException("SWA does not currently support sharded models.") - - # copy the model before moving it to accelerator device. - self._average_model = deepcopy(pl_module) - - def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - if len(trainer.optimizers) != 1: - raise MisconfigurationException("SWA currently works with 1 `optimizer`.") - - if len(trainer.lr_scheduler_configs) > 1: - raise MisconfigurationException( - "SWA currently not supported for more than 1 `lr_scheduler`." - ) - - assert trainer.max_epochs is not None - if isinstance(self._swa_epoch_start, float): - self._swa_epoch_start = int(trainer.max_epochs * self._swa_epoch_start) - - self._model_contains_batch_norm = self.pl_module_contains_batch_norm(pl_module) - - self._max_epochs = trainer.max_epochs - if self._model_contains_batch_norm: - # virtually increase max_epochs to perform batch norm update on latest epoch. - assert trainer.fit_loop.max_epochs is not None - trainer.fit_loop.max_epochs += 1 - - if self._scheduler_state is not None: - self._clear_schedulers(trainer) - def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + """Run at start of each train epoch""" if (not self._initialized) and (self.swa_start <= trainer.current_epoch <= self.swa_end): self._initialized = True @@ -190,7 +117,7 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo if isinstance(self._swa_lrs, float): self._swa_lrs = [self._swa_lrs] * len(optimizer.param_groups) - for lr, group in zip(self._swa_lrs, optimizer.param_groups): + for lr, group in zip(self._swa_lrs, optimizer.param_groups, strict=True): group["initial_lr"] = lr assert trainer.max_epochs is not None @@ -212,7 +139,8 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo # as behaviour will be different compared to having checkpoint data. rank_zero_warn( "SWA is initializing after swa_start without any checkpoint data. 
" - "This may be caused by loading a checkpoint from an older version of PyTorch Lightning." + "This may be caused by loading a checkpoint from an older version of PyTorch" + " Lightning." ) # We assert that there is only one optimizer on fit start @@ -236,7 +164,9 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo if self.n_averaged is None: self.n_averaged = torch.tensor( - self._init_n_averaged, dtype=torch.long, device=pl_module.device + self._init_n_averaged, + dtype=torch.long, + device=pl_module.device ) if (self.swa_start <= trainer.current_epoch <= self.swa_end) and ( @@ -258,123 +188,18 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo # There is no need to perform either backward or optimizer.step as we are # performing only one pass over the train data-loader to compute activation statistics - # Therefore, we will virtually increase the number of training batches by 1 and skip backward. + # Therefore, we will virtually increase the number of training batches by 1 and + # skip backward. trainer.fit_loop.max_batches += 1 trainer.fit_loop._skip_backward = True self._accumulate_grad_batches = trainer.accumulate_grad_batches trainer.accumulate_grad_batches = self._train_batches def on_train_epoch_end(self, trainer: "pl.Trainer", *args: Any) -> None: - if trainer.current_epoch == 0: + """Run at end of each train epoch""" + if trainer.current_epoch==0: self._train_batches = trainer.global_step trainer.fit_loop._skip_backward = False - def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - # the trainer increases the current epoch before this hook is called - if self._model_contains_batch_norm and trainer.current_epoch - 1 == self.swa_end + 1: - # BatchNorm epoch update. Resetmax_batches, int), "Iterable-style datasets are not state - trainer.accumulate_grad_batches = self._accumulate_grad_batches - trainer.fit_loop.max_batches -= 1 - assert trainer.fit_loop.max_epochs is not None - trainer.fit_loop.max_epochs -= 1 - self.reset_momenta() - elif trainer.current_epoch - 1 == self.swa_end: - # Last SWA epoch. 
Transfer weights from average model to pl_module - assert self._average_model is not None - self.transfer_weights(self._average_model, pl_module) - @staticmethod - def transfer_weights( - src_pl_module: "pl.LightningModule", dst_pl_module: "pl.LightningModule" - ) -> None: - for src_param, dst_param in zip(src_pl_module.parameters(), dst_pl_module.parameters()): - dst_param.detach().copy_(src_param.to(dst_param.device)) - - def reset_batch_norm_and_save_state(self, pl_module: "pl.LightningModule") -> None: - """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L140-L154.""" - self.momenta = {} - for module in pl_module.modules(): - if not isinstance(module, nn.modules.batchnorm._BatchNorm): - continue - assert module.running_mean is not None - module.running_mean = torch.zeros_like( - module.running_mean, - device=pl_module.device, - dtype=module.running_mean.dtype, - ) - assert module.running_var is not None - module.running_var = torch.ones_like( - module.running_var, - device=pl_module.device, - dtype=module.running_var.dtype, - ) - self.momenta[module] = module.momentum - module.momentum = None # type: ignore[assignment] - assert module.num_batches_tracked is not None - module.num_batches_tracked *= 0 - - def reset_momenta(self) -> None: - """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L164-L165.""" - for bn_module in self.momenta: - bn_module.momentum = self.momenta[bn_module] # type: ignore[assignment] - - @staticmethod - def update_parameters( - average_model: "pl.LightningModule", - model: "pl.LightningModule", - n_averaged: Tensor, - avg_fn: _AVG_FN, - ) -> None: - """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L104-L112.""" - for p_swa, p_model in zip(average_model.parameters(), model.parameters()): - device = p_swa.device - p_swa_ = p_swa.detach() - p_model_ = p_model.detach().to(device) - src = p_model_ if n_averaged == 0 else avg_fn(p_swa_, p_model_, n_averaged.to(device)) - p_swa_.copy_(src) - n_averaged += 1 - - @staticmethod - def avg_fn( - averaged_model_parameter: Tensor, model_parameter: Tensor, num_averaged: Tensor - ) -> Tensor: - """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L95-L97.""" - return averaged_model_parameter + (model_parameter - averaged_model_parameter) / ( - num_averaged + 1 - ) - def state_dict(self) -> Dict[str, Any]: - return { - "n_averaged": 0 if self.n_averaged is None else self.n_averaged.item(), - "latest_update_epoch": self._latest_update_epoch, - "scheduler_state": None - if self._swa_scheduler is None - else self._swa_scheduler.state_dict(), - "average_model_state": None - if self._average_model is None - else self._average_model.state_dict(), - } - - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: - self._init_n_averaged = state_dict["n_averaged"] - self._latest_update_epoch = state_dict["latest_update_epoch"] - self._scheduler_state = state_dict["scheduler_state"] - self._load_average_model_state(state_dict["average_model_state"]) - - @staticmethod - def _clear_schedulers(trainer: "pl.Trainer") -> None: - # If we have scheduler state saved, clear the scheduler configs so that we don't try to - # load state into the wrong type of schedulers when restoring scheduler checkpoint state. - # We'll configure the scheduler and re-load its state in on_train_epoch_start. 
- # Note that this relies on the callback state being restored before the scheduler state is - # restored, and doesn't work if restore_checkpoint_after_setup is True, but at the time of - # writing that is only True for deepspeed which is already not supported by SWA. - # See https://github.com/Lightning-AI/lightning/issues/11665 for background. - if trainer.lr_scheduler_configs: - assert len(trainer.lr_scheduler_configs) == 1 - trainer.lr_scheduler_configs.clear() - - def _load_average_model_state(self, model_state: Any) -> None: - if self._average_model is None: - return - self._average_model.load_state_dict(model_state) From f21aa03140f3cac0a5c1df5deaec27fb26a5fc20 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 21 Sep 2023 12:20:30 +0000 Subject: [PATCH 4/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pvnet_summation/callbacks.py | 37 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 19 deletions(-) diff --git a/pvnet_summation/callbacks.py b/pvnet_summation/callbacks.py index 31a672c..71c0906 100644 --- a/pvnet_summation/callbacks.py +++ b/pvnet_summation/callbacks.py @@ -17,27 +17,28 @@ lightning.pytorch.callbacks.StochasticWeightAveraging """ -from typing import Any, Callable, cast, List, Optional, Union - -import torch -from torch import Tensor -from torch.optim.swa_utils import SWALR +from typing import Any, Callable, List, Optional, Union, cast import lightning.pytorch as pl +import torch from lightning.fabric.utilities.types import LRScheduler +from lightning.pytorch.callbacks import StochasticWeightAveraging as _StochasticWeightAveraging from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn from lightning.pytorch.utilities.types import LRSchedulerConfig -from lightning.pytorch.callbacks import StochasticWeightAveraging as _StochasticWeightAveraging +from torch import Tensor +from torch.optim.swa_utils import SWALR _AVG_FN = Callable[[Tensor, Tensor, Tensor], Tensor] _DEFAULT_DEVICE = torch.device("cpu") + class StochasticWeightAveraging(_StochasticWeightAveraging): """Stochastic weight averaging callback Modified from: lightning.pytorch.callbacks.StochasticWeightAveraging """ + def __init__( self, swa_lrs: Union[float, List[float]], @@ -62,13 +63,12 @@ def __init__( .. warning:: This is an :ref:`experimental ` feature. - .. warning:: ``StochasticWeightAveraging`` is currently not supported for multiple + .. warning:: ``StochasticWeightAveraging`` is currently not supported for multiple optimizers/schedulers. .. warning:: ``StochasticWeightAveraging`` is currently only supported on every epoch. Arguments: - swa_lrs: The SWA learning rate to use: - ``float``. Use this value for all parameter groups of the optimizer. 
@@ -98,10 +98,14 @@ def __init__( """ # Add this so we can use iterative datapipe self._train_batches = 0 - + super()._init_( - swa_lrs, swa_epoch_start, annealing_epochs, - annealing_strategy, avg_fn, device, + swa_lrs, + swa_epoch_start, + annealing_epochs, + annealing_strategy, + avg_fn, + device, ) def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: @@ -164,9 +168,7 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo if self.n_averaged is None: self.n_averaged = torch.tensor( - self._init_n_averaged, - dtype=torch.long, - device=pl_module.device + self._init_n_averaged, dtype=torch.long, device=pl_module.device ) if (self.swa_start <= trainer.current_epoch <= self.swa_end) and ( @@ -188,7 +190,7 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo # There is no need to perform either backward or optimizer.step as we are # performing only one pass over the train data-loader to compute activation statistics - # Therefore, we will virtually increase the number of training batches by 1 and + # Therefore, we will virtually increase the number of training batches by 1 and # skip backward. trainer.fit_loop.max_batches += 1 trainer.fit_loop._skip_backward = True @@ -197,9 +199,6 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo def on_train_epoch_end(self, trainer: "pl.Trainer", *args: Any) -> None: """Run at end of each train epoch""" - if trainer.current_epoch==0: + if trainer.current_epoch == 0: self._train_batches = trainer.global_step trainer.fit_loop._skip_backward = False - - -
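
A few illustrative notes on the behaviour these patches configure. With the shipped configs (trainer `max_epochs: 100`, callback `swa_epoch_start: 0.8`, `annealing_epochs: 5`, `swa_lrs: 0.0000001`), the float `swa_epoch_start` is converted to an absolute epoch inside `on_fit_start`, so the averaging window can be worked out by hand. The sketch below is only a back-of-envelope check of that arithmetic, not code from the repository:

```python
# Back-of-envelope check of when SWA kicks in under the shipped configs.
max_epochs = 100          # configs/trainer/default.yaml
swa_epoch_start = 0.8     # configs/callbacks/default.yaml
annealing_epochs = 5

start_epoch = int(max_epochs * swa_epoch_start)  # 80, as computed in on_fit_start
swa_start = max(start_epoch - 1, 0)              # 79 in 0-based epochs (the `swa_start` property)
swa_end = max_epochs - 1                         # 99 (the `swa_end` property)

# Weights are folded into the running average once per epoch in this window,
# while SWALR anneals the learning rate towards 1e-7 over the first 5 of them.
n_snapshots = swa_end - swa_start + 1
print(start_epoch, swa_start, swa_end, n_snapshots)  # 80 79 99 21
```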
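The `stochastic_weight_averaging` block in `configs/callbacks/default.yaml` is resolved into the callback object through Hydra's `_target_` mechanism, like the other callbacks in that file. A minimal sketch of the equivalent construction by hand; the import path is the one declared in the config, and the wiring into the Trainer is assumed rather than shown in these patches:

```python
from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "_target_": "pvnet_summation.callbacks.StochasticWeightAveraging",
        "swa_lrs": 1.0e-7,
        "swa_epoch_start": 0.8,
        "annealing_epochs": 5,
    }
)

# Equivalent to constructing the callback directly; the training entry point is
# assumed to pass it to lightning.pytorch.Trainer(callbacks=[...]).
swa_callback = instantiate(cfg)
```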
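The `avg_fn` kept from the upstream callback is an incremental arithmetic mean: once `n` snapshots have been averaged, folding in the next one moves the average by `(new - avg) / (n + 1)`. A tiny self-contained sketch with made-up two-element "parameters" (not real model weights) showing that this reproduces the plain mean of the snapshots:

```python
import torch


def avg_fn(averaged: torch.Tensor, current: torch.Tensor, num_averaged: torch.Tensor) -> torch.Tensor:
    # Same update rule as the callback's avg_fn.
    return averaged + (current - averaged) / (num_averaged + 1)


# Pretend these are the weights captured at the start of three SWA epochs.
snapshots = [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 6.0]), torch.tensor([5.0, 4.0])]

averaged = torch.zeros(2)
n_averaged = torch.tensor(0)
for param in snapshots:
    # Mirrors update_parameters: copy on the first snapshot, running mean afterwards.
    averaged = param.clone() if n_averaged == 0 else avg_fn(averaged, param, n_averaged)
    n_averaged += 1

print(averaged)                            # tensor([3., 4.])
print(torch.stack(snapshots).mean(dim=0))  # tensor([3., 4.]), identical
```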
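Finally, the renamed `relative_scale_pvnet_outputs` option multiplies each GSP-level PVNet output by `effective_capacity / 100`, so larger GSPs carry proportionally more weight while the network inputs stay around order of magnitude 1. A shape-level sketch of the quantile-regression branch; the tensor shapes and values here are assumptions for illustration, not taken from the repository:

```python
import torch

batch, n_gsps, horizon, n_quantiles = 2, 317, 16, 3

# Dummy stand-ins: per-GSP PVNet quantile forecasts and effective capacities (~O(100) MW).
pvnet_outputs = torch.rand(batch, n_gsps, horizon, n_quantiles)
effective_capacity = 100 * torch.rand(batch, n_gsps, horizon)

# As in the model's forward pass: broadcast capacity over the quantile dimension,
# then divide by 100 to keep the inputs roughly order of magnitude 1.
eff_cap = effective_capacity.unsqueeze(-1)
x_in = pvnet_outputs * (eff_cap / 100)

print(x_in.shape)  # torch.Size([2, 317, 16, 3])
```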