Commit
fix(optim): add missing variable in Graft
ClashLuke committed Apr 23, 2023
1 parent d9b50a8 commit 7a12055
Showing 2 changed files with 11 additions and 4 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -10,7 +10,7 @@
name='truegrad',
license='BSD',
description='PyTorch interface for TrueGrad-AdamW',
version='4.0.0',
version='4.0.1',
long_description=README,
url='https://github.com/clashluke/truegrad',
packages=setuptools.find_packages(),
13 changes: 10 additions & 3 deletions truegrad/optim.py
@@ -1,5 +1,4 @@
import functools
import warnings
from typing import Tuple, Union, List, Dict, Any, Optional

import torch
@@ -30,6 +29,7 @@ def __call__(self, mod: torch.optim.Optimizer):

class LpWeightDecay(WeightDecayBase):
def __init__(self, power: float):
super().__init__()
self.power = power

def __call__(self, mod: torch.optim.Optimizer, p: Tensor, idx: int):
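
This hunk adds a super().__init__() call to LpWeightDecay.__init__, which previously initialized only its own power attribute. A minimal sketch of why the call matters, assuming WeightDecayBase.__init__ sets up state that later methods rely on (its body is not shown in this diff):

# Hypothetical illustration, not part of the commit: a base class whose
# __init__ creates state that subclasses are expected to inherit.
class Base:
    def __init__(self):
        self.optimizer = None  # placeholder for state attached later

class WithoutSuper(Base):
    def __init__(self, power: float):
        self.power = power  # Base.__init__ never runs, so self.optimizer is missing

class WithSuper(Base):
    def __init__(self, power: float):
        super().__init__()  # base state is initialized before subclass state
        self.power = power

print(hasattr(WithSuper(2.0), "optimizer"))     # True
print(hasattr(WithoutSuper(2.0), "optimizer"))  # False
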
@@ -46,12 +46,17 @@ def __init__(self):
super().__init__(1)


def _param_iterator(mod: torch.optim.Optimizer):
yield from (p.detach().clone() for group in mod.param_groups for p in group["params"])
def _detach(x: Tensor) -> Tensor:
return x.detach().clone()


def _param_iterator(mod: torch.optim.Optimizer, fn=_detach):
yield from (fn(p) for group in mod.param_groups for p in group["params"])


class WeightDecayToValue(WeightDecayBase):
def __init__(self):
super().__init__()
self.target_values: List[Tensor] = ...
self.global_step = 0
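
The old _param_iterator always yielded detached clones; the refactor moves that behaviour into _detach and keeps it as the default, while letting callers pass a different callable when they need the live parameter tensors (as Graft.step does below). A minimal sketch of the two modes, using a hypothetical stand-in object for the optimizer and the two helpers restated without type hints:

import torch

def _detach(x):
    return x.detach().clone()

def _param_iterator(mod, fn=_detach):
    yield from (fn(p) for group in mod.param_groups for p in group["params"])

# Hypothetical stand-in: anything exposing .param_groups in the optimizer layout works.
class FakeOptimizer:
    def __init__(self, params):
        self.param_groups = [{"params": list(params)}]

weight = torch.nn.Parameter(torch.ones(3))
opt = FakeOptimizer([weight])

snapshot = list(_param_iterator(opt))            # detached copies, frozen at this point
live = list(_param_iterator(opt, lambda x: x))   # the actual Parameter objects

weight.data.add_(1.0)
print(snapshot[0])  # still all ones: the clone ignores later updates
print(live[0])      # all twos: same storage as the optimizer's parameter
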

@@ -261,6 +266,8 @@ def step(self, closure=None):
original_params = list(_param_iterator(self))

self.magnitude.step()
params_flat = list(_param_iterator(self, lambda x: x))

magnitudes_flat = []
for o, p in zip(original_params, params_flat):
magnitudes_flat.append(torch.norm(o.double() - p.double()))
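
This last hunk is the fix named in the commit title: the zip below consumes params_flat, which was previously never assigned in Graft.step, and the added line binds it to the live parameters (note lambda x: x rather than the detaching default) right after the magnitude optimizer has stepped. Each entry of magnitudes_flat is then the norm of the difference between the pre-step snapshot and the updated parameter, i.e. the size of the magnitude optimizer's update. A minimal sketch of that measurement in isolation, with plain SGD standing in for self.magnitude (hypothetical example, not the Graft class itself):

import torch

# Hypothetical setup: one parameter and a stand-in "magnitude" optimizer.
param = torch.nn.Parameter(torch.ones(4))
magnitude = torch.optim.SGD([param], lr=0.1)
param.sum().backward()  # gradient of ones

original = [p.detach().clone() for p in magnitude.param_groups[0]["params"]]
magnitude.step()        # the parameter moves; the clones in `original` do not
updated = [p for p in magnitude.param_groups[0]["params"]]  # live tensors, as in the fix

magnitudes = [torch.norm(o.double() - p.double()) for o, p in zip(original, updated)]
print(magnitudes[0])    # ||0.1 * grad|| = 0.1 * sqrt(4) = 0.2

Presumably Graft goes on to rescale a second (direction) optimizer's update to these magnitudes, but that part of step lies outside the lines shown here.
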
