Commit

tidy SWA callback
dfulu committed Sep 21, 2023
1 parent e99bb9d commit 891cec2
Showing 1 changed file with 38 additions and 213 deletions.
251 changes: 38 additions & 213 deletions pvnet_summation/callbacks.py
@@ -11,34 +11,41 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Stochastic Weight Averaging Callback ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^"""
from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional, Union, cast
"""Stochastic weight averaging callback
Modified from:
lightning.pytorch.callbacks.StochasticWeightAveraging
"""

from typing import Any, Callable, cast, List, Optional, Union

import lightning.pytorch as pl
import torch
from torch import Tensor
from torch.optim.swa_utils import SWALR

import lightning.pytorch as pl
from lightning.fabric.utilities.types import LRScheduler
from lightning.pytorch.callbacks import StochasticWeightAveraging
from lightning.pytorch.strategies import DeepSpeedStrategy
from lightning.pytorch.strategies.fsdp import FSDPStrategy
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn
from lightning.pytorch.utilities.types import LRSchedulerConfig
from torch import Tensor, nn
from torch.optim.swa_utils import SWALR
from lightning.pytorch.callbacks import StochasticWeightAveraging as _StochasticWeightAveraging

_AVG_FN = Callable[[Tensor, Tensor, Tensor], Tensor]
_DEFAULT_DEVICE = torch.device("cpu")

class StochasticWeightAveraging(_StochasticWeightAveraging):
"""Stochastic weight averaging callback
class StochasticWeightAveraging(StochasticWeightAveraging):
Modified from:
lightning.pytorch.callbacks.StochasticWeightAveraging
"""
def __init__(
self,
swa_lrs: Union[float, List[float]],
swa_epoch_start: Union[int, float] = 0.8,
annealing_epochs: int = 10,
annealing_strategy: str = "cos",
avg_fn: Optional[_AVG_FN] = None,
device: Optional[Union[torch.device, str]] = torch.device("cpu"),
device: Optional[Union[torch.device, str]] = _DEFAULT_DEVICE,
):
r"""Implements the Stochastic Weight Averaging (SWA) Callback to average a model.
@@ -55,13 +62,13 @@ def __init__(
.. warning:: This is an :ref:`experimental <versioning:Experimental API>` feature.
.. warning:: ``StochasticWeightAveraging`` is currently not supported for multiple optimizers/schedulers.
.. warning:: ``StochasticWeightAveraging`` is currently not supported for multiple
optimizers/schedulers.
.. warning:: ``StochasticWeightAveraging`` is currently only supported on every epoch.
See also how to :ref:`enable it directly on the Trainer <advanced/training_tricks:Stochastic Weight Averaging>`
Arguments:
swa_lrs: The SWA learning rate to use:
- ``float``. Use this value for all parameter groups of the optimizer.
@@ -89,96 +96,16 @@ def __init__(
(default: ``"cpu"``)
"""

err_msg = "swa_epoch_start should be a >0 integer or a float between 0 and 1."
if isinstance(swa_epoch_start, int) and swa_epoch_start < 1:
raise MisconfigurationException(err_msg)
if isinstance(swa_epoch_start, float) and not (0 <= swa_epoch_start <= 1):
raise MisconfigurationException(err_msg)

wrong_type = not isinstance(swa_lrs, (float, list))
wrong_float = isinstance(swa_lrs, float) and swa_lrs <= 0
wrong_list = isinstance(swa_lrs, list) and not all(
lr > 0 and isinstance(lr, float) for lr in swa_lrs
)
if wrong_type or wrong_float or wrong_list:
raise MisconfigurationException(
"The `swa_lrs` should a positive float, or a list of positive floats"
)

if avg_fn is not None and not callable(avg_fn):
raise MisconfigurationException("The `avg_fn` should be callable.")

if device is not None and not isinstance(device, (torch.device, str)):
raise MisconfigurationException(
f"device is expected to be a torch.device or a str. Found {device}"
)

self.n_averaged: Optional[Tensor] = None
self._swa_epoch_start = swa_epoch_start
self._swa_lrs = swa_lrs
self._annealing_epochs = annealing_epochs
self._annealing_strategy = annealing_strategy
self._avg_fn = avg_fn or self.avg_fn
self._device = device
self._model_contains_batch_norm: Optional[bool] = None
self._average_model: Optional["pl.LightningModule"] = None
self._initialized = False
self._swa_scheduler: Optional[LRScheduler] = None
self._scheduler_state: Optional[Dict] = None
self._init_n_averaged = 0
self._latest_update_epoch = -1
self.momenta: Dict[nn.modules.batchnorm._BatchNorm, Optional[float]] = {}
self._max_epochs: int
# Add this so we can use iterative datapipe
self._train_batches = 0

@property
def swa_start(self) -> int:
assert isinstance(self._swa_epoch_start, int)
return max(self._swa_epoch_start - 1, 0) # 0-based

@property
def swa_end(self) -> int:
return self._max_epochs - 1 # 0-based

@staticmethod
def pl_module_contains_batch_norm(pl_module: "pl.LightningModule") -> bool:
return any(
isinstance(module, nn.modules.batchnorm._BatchNorm) for module in pl_module.modules()

super().__init__(
swa_lrs, swa_epoch_start, annealing_epochs,
annealing_strategy, avg_fn, device,
)

def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
if isinstance(trainer.strategy, (FSDPStrategy, DeepSpeedStrategy)):
raise MisconfigurationException("SWA does not currently support sharded models.")

# copy the model before moving it to accelerator device.
self._average_model = deepcopy(pl_module)

def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if len(trainer.optimizers) != 1:
raise MisconfigurationException("SWA currently works with 1 `optimizer`.")

if len(trainer.lr_scheduler_configs) > 1:
raise MisconfigurationException(
"SWA currently not supported for more than 1 `lr_scheduler`."
)

assert trainer.max_epochs is not None
if isinstance(self._swa_epoch_start, float):
self._swa_epoch_start = int(trainer.max_epochs * self._swa_epoch_start)

self._model_contains_batch_norm = self.pl_module_contains_batch_norm(pl_module)

self._max_epochs = trainer.max_epochs
if self._model_contains_batch_norm:
# virtually increase max_epochs to perform batch norm update on latest epoch.
assert trainer.fit_loop.max_epochs is not None
trainer.fit_loop.max_epochs += 1

if self._scheduler_state is not None:
self._clear_schedulers(trainer)

def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Run at start of each train epoch"""
if (not self._initialized) and (self.swa_start <= trainer.current_epoch <= self.swa_end):
self._initialized = True

@@ -190,7 +117,7 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
if isinstance(self._swa_lrs, float):
self._swa_lrs = [self._swa_lrs] * len(optimizer.param_groups)

for lr, group in zip(self._swa_lrs, optimizer.param_groups):
for lr, group in zip(self._swa_lrs, optimizer.param_groups, strict=True):
group["initial_lr"] = lr

assert trainer.max_epochs is not None
@@ -212,7 +139,8 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
# as behaviour will be different compared to having checkpoint data.
rank_zero_warn(
"SWA is initializing after swa_start without any checkpoint data. "
"This may be caused by loading a checkpoint from an older version of PyTorch Lightning."
"This may be caused by loading a checkpoint from an older version of PyTorch"
" Lightning."
)

# We assert that there is only one optimizer on fit start
@@ -236,7 +164,9 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo

if self.n_averaged is None:
self.n_averaged = torch.tensor(
self._init_n_averaged, dtype=torch.long, device=pl_module.device
self._init_n_averaged,
dtype=torch.long,
device=pl_module.device
)

if (self.swa_start <= trainer.current_epoch <= self.swa_end) and (
@@ -258,123 +188,18 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo

# There is no need to perform either backward or optimizer.step as we are
# performing only one pass over the train data-loader to compute activation statistics
# Therefore, we will virtually increase the number of training batches by 1 and skip backward.
# Therefore, we will virtually increase the number of training batches by 1 and
# skip backward.
trainer.fit_loop.max_batches += 1
trainer.fit_loop._skip_backward = True
self._accumulate_grad_batches = trainer.accumulate_grad_batches
trainer.accumulate_grad_batches = self._train_batches

def on_train_epoch_end(self, trainer: "pl.Trainer", *args: Any) -> None:
if trainer.current_epoch == 0:
"""Run at end of each train epoch"""
if trainer.current_epoch==0:
self._train_batches = trainer.global_step
trainer.fit_loop._skip_backward = False

def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
# the trainer increases the current epoch before this hook is called
if self._model_contains_batch_norm and trainer.current_epoch - 1 == self.swa_end + 1:
# BatchNorm epoch update. Reset state
trainer.accumulate_grad_batches = self._accumulate_grad_batches
trainer.fit_loop.max_batches -= 1
assert trainer.fit_loop.max_epochs is not None
trainer.fit_loop.max_epochs -= 1
self.reset_momenta()
elif trainer.current_epoch - 1 == self.swa_end:
# Last SWA epoch. Transfer weights from average model to pl_module
assert self._average_model is not None
self.transfer_weights(self._average_model, pl_module)

@staticmethod
def transfer_weights(
src_pl_module: "pl.LightningModule", dst_pl_module: "pl.LightningModule"
) -> None:
for src_param, dst_param in zip(src_pl_module.parameters(), dst_pl_module.parameters()):
dst_param.detach().copy_(src_param.to(dst_param.device))

def reset_batch_norm_and_save_state(self, pl_module: "pl.LightningModule") -> None:
"""Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L140-L154."""
self.momenta = {}
for module in pl_module.modules():
if not isinstance(module, nn.modules.batchnorm._BatchNorm):
continue
assert module.running_mean is not None
module.running_mean = torch.zeros_like(
module.running_mean,
device=pl_module.device,
dtype=module.running_mean.dtype,
)
assert module.running_var is not None
module.running_var = torch.ones_like(
module.running_var,
device=pl_module.device,
dtype=module.running_var.dtype,
)
self.momenta[module] = module.momentum
module.momentum = None # type: ignore[assignment]
assert module.num_batches_tracked is not None
module.num_batches_tracked *= 0

def reset_momenta(self) -> None:
"""Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L164-L165."""
for bn_module in self.momenta:
bn_module.momentum = self.momenta[bn_module] # type: ignore[assignment]

@staticmethod
def update_parameters(
average_model: "pl.LightningModule",
model: "pl.LightningModule",
n_averaged: Tensor,
avg_fn: _AVG_FN,
) -> None:
"""Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L104-L112."""
for p_swa, p_model in zip(average_model.parameters(), model.parameters()):
device = p_swa.device
p_swa_ = p_swa.detach()
p_model_ = p_model.detach().to(device)
src = p_model_ if n_averaged == 0 else avg_fn(p_swa_, p_model_, n_averaged.to(device))
p_swa_.copy_(src)
n_averaged += 1

@staticmethod
def avg_fn(
averaged_model_parameter: Tensor, model_parameter: Tensor, num_averaged: Tensor
) -> Tensor:
"""Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L95-L97."""
return averaged_model_parameter + (model_parameter - averaged_model_parameter) / (
num_averaged + 1
)

def state_dict(self) -> Dict[str, Any]:
return {
"n_averaged": 0 if self.n_averaged is None else self.n_averaged.item(),
"latest_update_epoch": self._latest_update_epoch,
"scheduler_state": None
if self._swa_scheduler is None
else self._swa_scheduler.state_dict(),
"average_model_state": None
if self._average_model is None
else self._average_model.state_dict(),
}

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
self._init_n_averaged = state_dict["n_averaged"]
self._latest_update_epoch = state_dict["latest_update_epoch"]
self._scheduler_state = state_dict["scheduler_state"]
self._load_average_model_state(state_dict["average_model_state"])

@staticmethod
def _clear_schedulers(trainer: "pl.Trainer") -> None:
# If we have scheduler state saved, clear the scheduler configs so that we don't try to
# load state into the wrong type of schedulers when restoring scheduler checkpoint state.
# We'll configure the scheduler and re-load its state in on_train_epoch_start.
# Note that this relies on the callback state being restored before the scheduler state is
# restored, and doesn't work if restore_checkpoint_after_setup is True, but at the time of
# writing that is only True for deepspeed which is already not supported by SWA.
# See https://github.com/Lightning-AI/lightning/issues/11665 for background.
if trainer.lr_scheduler_configs:
assert len(trainer.lr_scheduler_configs) == 1
trainer.lr_scheduler_configs.clear()

def _load_average_model_state(self, model_state: Any) -> None:
if self._average_model is None:
return
self._average_model.load_state_dict(model_state)
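
For context, a minimal usage sketch of the tidied callback. The import path follows the file shown in this diff (pvnet_summation/callbacks.py) and the constructor arguments match the signature above; the model, datamodule, and the specific learning-rate and epoch values are placeholders, not part of the commit.

# Minimal usage sketch (hypothetical model/datamodule and hyperparameter values;
# constructor arguments follow the __init__ signature shown in the diff above).
import lightning.pytorch as pl

from pvnet_summation.callbacks import StochasticWeightAveraging

swa_callback = StochasticWeightAveraging(
    swa_lrs=1e-4,           # SWA learning rate applied to all parameter groups
    swa_epoch_start=0.8,    # start averaging after 80% of max_epochs
    annealing_epochs=10,
    annealing_strategy="cos",
)

trainer = pl.Trainer(max_epochs=50, callbacks=[swa_callback])
# trainer.fit(model, datamodule=datamodule)  # placeholders, not defined here

Per the "Add this so we can use iterative datapipe" comment in the diff, the `_train_batches` counter recorded in `on_train_epoch_end` (from `trainer.global_step` after the first epoch) is what lets the final batch-norm update pass run when the number of batches per epoch is not known up front.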
