[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Sep 21, 2023
1 parent 891cec2 commit f21aa03
Showing 1 changed file with 18 additions and 19 deletions.
37 changes: 18 additions & 19 deletions pvnet_summation/callbacks.py
@@ -17,27 +17,28 @@
lightning.pytorch.callbacks.StochasticWeightAveraging
"""

-from typing import Any, Callable, cast, List, Optional, Union
-
-import torch
-from torch import Tensor
-from torch.optim.swa_utils import SWALR
+from typing import Any, Callable, List, Optional, Union, cast

import lightning.pytorch as pl
+import torch
from lightning.fabric.utilities.types import LRScheduler
+from lightning.pytorch.callbacks import StochasticWeightAveraging as _StochasticWeightAveraging
from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn
from lightning.pytorch.utilities.types import LRSchedulerConfig
-from lightning.pytorch.callbacks import StochasticWeightAveraging as _StochasticWeightAveraging
+from torch import Tensor
+from torch.optim.swa_utils import SWALR

_AVG_FN = Callable[[Tensor, Tensor, Tensor], Tensor]
_DEFAULT_DEVICE = torch.device("cpu")


class StochasticWeightAveraging(_StochasticWeightAveraging):
"""Stochastic weight averaging callback
Modified from:
lightning.pytorch.callbacks.StochasticWeightAveraging
"""

def __init__(
self,
swa_lrs: Union[float, List[float]],
@@ -62,13 +63,12 @@ def __init__(
.. warning:: This is an :ref:`experimental <versioning:Experimental API>` feature.
-.. warning:: ``StochasticWeightAveraging`` is currently not supported for multiple
+.. warning:: ``StochasticWeightAveraging`` is currently not supported for multiple
optimizers/schedulers.
.. warning:: ``StochasticWeightAveraging`` is currently only supported on every epoch.
Arguments:
swa_lrs: The SWA learning rate to use:
- ``float``. Use this value for all parameter groups of the optimizer.
@@ -98,10 +98,14 @@ def __init__(
"""
# Add this so we can use iterative datapipe
self._train_batches = 0

super().__init__(
-    swa_lrs, swa_epoch_start, annealing_epochs,
-    annealing_strategy, avg_fn, device,
+    swa_lrs,
+    swa_epoch_start,
+    annealing_epochs,
+    annealing_strategy,
+    avg_fn,
+    device,
)

def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
@@ -164,9 +168,7 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo

if self.n_averaged is None:
self.n_averaged = torch.tensor(
-    self._init_n_averaged,
-    dtype=torch.long,
-    device=pl_module.device
+    self._init_n_averaged, dtype=torch.long, device=pl_module.device
)

if (self.swa_start <= trainer.current_epoch <= self.swa_end) and (
@@ -188,7 +190,7 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo

# There is no need to perform either backward or optimizer.step as we are
# performing only one pass over the train data-loader to compute activation statistics
-# Therefore, we will virtually increase the number of training batches by 1 and
+# Therefore, we will virtually increase the number of training batches by 1 and
# skip backward.
trainer.fit_loop.max_batches += 1
trainer.fit_loop._skip_backward = True
@@ -197,9 +199,6 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo

def on_train_epoch_end(self, trainer: "pl.Trainer", *args: Any) -> None:
"""Run at end of each train epoch"""
-if trainer.current_epoch==0:
+if trainer.current_epoch == 0:
self._train_batches = trainer.global_step
trainer.fit_loop._skip_backward = False
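
For orientation, here is a minimal usage sketch of the callback touched by this commit; it is not part of the diff. It assumes the subclass keeps the parent lightning callback's keyword arguments (swa_epoch_start, avg_fn, and so on), and the running_avg_fn helper plus the model and train_dataloader objects are hypothetical placeholders.

import lightning.pytorch as pl
from torch import Tensor

from pvnet_summation.callbacks import StochasticWeightAveraging


def running_avg_fn(averaged: Tensor, current: Tensor, num_averaged: Tensor) -> Tensor:
    # Plain running mean; matches the _AVG_FN signature (Tensor, Tensor, Tensor) -> Tensor
    return averaged + (current - averaged) / (num_averaged + 1)


swa_callback = StochasticWeightAveraging(
    swa_lrs=1e-4,            # a single float is applied to all optimizer parameter groups
    swa_epoch_start=0.8,     # assumed keyword, as in the parent lightning callback
    avg_fn=running_avg_fn,   # assumed keyword; optional custom averaging function
)

trainer = pl.Trainer(max_epochs=10, callbacks=[swa_callback])
# trainer.fit(model, train_dataloader)  # model and dataloader defined elsewhere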


