From 7a12055085fe465c903f3891298e499586ddd4a9 Mon Sep 17 00:00:00 2001
From: ClashLuke <39779310+ClashLuke@users.noreply.github.com>
Date: Sun, 23 Apr 2023 08:12:22 +0200
Subject: [PATCH] fix(optim): add missing variable in `Graft`

---
 setup.py          |  2 +-
 truegrad/optim.py | 13 ++++++++++---
 2 files changed, 11 insertions(+), 4 deletions(-)

diff --git a/setup.py b/setup.py
index d581e79..1383a0a 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@
     name='truegrad',
     license='BSD',
     description='PyTorch interface for TrueGrad-AdamW',
-    version='4.0.0',
+    version='4.0.1',
     long_description=README,
     url='https://github.com/clashluke/truegrad',
     packages=setuptools.find_packages(),
diff --git a/truegrad/optim.py b/truegrad/optim.py
index 1aedb87..a998618 100644
--- a/truegrad/optim.py
+++ b/truegrad/optim.py
@@ -1,5 +1,4 @@
 import functools
-import warnings
 from typing import Tuple, Union, List, Dict, Any, Optional
 
 import torch
@@ -30,6 +29,7 @@ def __call__(self, mod: torch.optim.Optimizer):
 
 class LpWeightDecay(WeightDecayBase):
     def __init__(self, power: float):
+        super().__init__()
         self.power = power
 
     def __call__(self, mod: torch.optim.Optimizer, p: Tensor, idx: int):
@@ -46,12 +46,17 @@ def __init__(self):
         super().__init__(1)
 
 
-def _param_iterator(mod: torch.optim.Optimizer):
-    yield from (p.detach().clone() for group in mod.param_groups for p in group["params"])
+def _detach(x: Tensor) -> Tensor:
+    return x.detach().clone()
+
+
+def _param_iterator(mod: torch.optim.Optimizer, fn=_detach):
+    yield from (fn(p) for group in mod.param_groups for p in group["params"])
 
 
 class WeightDecayToValue(WeightDecayBase):
     def __init__(self):
+        super().__init__()
         self.target_values: List[Tensor] = ...
         self.global_step = 0
 
@@ -261,6 +266,8 @@ def step(self, closure=None):
 
         original_params = list(_param_iterator(self))
         self.magnitude.step()
+        params_flat = list(_param_iterator(self, lambda x: x))
+        magnitudes_flat = []
         for o, p in zip(original_params, params_flat):
             magnitudes_flat.append(torch.norm(o.double() - p.double()))
 
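
Note: below is a minimal sketch of the grafting step this patch repairs, for
readers following along. The toy setup (`params`, `magnitude_opt`, plain SGD)
is illustrative, not truegrad's API; it only mirrors the old-vs-new parameter
comparison that `Graft.step` can perform once `params_flat` and
`magnitudes_flat` exist:

    import torch

    # Two toy parameters standing in for a model's weights (illustrative).
    params = [torch.nn.Parameter(torch.randn(4, 4)) for _ in range(2)]
    magnitude_opt = torch.optim.SGD(params, lr=0.1)  # stands in for self.magnitude

    for p in params:
        p.grad = torch.randn_like(p)  # fake gradients, just for the sketch

    # What _param_iterator(self) yields with the default _detach:
    # detached snapshots taken before the magnitude optimizer steps.
    original_params = [p.detach().clone() for p in params]
    magnitude_opt.step()

    # The variable the patch adds: live references to the now-updated
    # parameters (identity fn), zipped against the snapshots below.
    params_flat = [p for p in params]
    magnitudes_flat = []
    for o, p in zip(original_params, params_flat):
        # Per-parameter update magnitude, computed in float64 as in the patch.
        magnitudes_flat.append(torch.norm(o.double() - p.double()))
    print(magnitudes_flat)

Before this fix, `params_flat` was referenced in `step()` without ever being
assigned, so the zip above would raise a NameError at runtime.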