Skip to content

Commit

Permalink
fix(optim): self-graft correctly
Browse files Browse the repository at this point in the history
  • Loading branch information
ClashLuke committed Jul 16, 2023
1 parent 9a43be0 commit d7c458b
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
name='truegrad',
license='BSD',
description='PyTorch interface for TrueGrad-AdamW',
version='4.0.2',
version='4.0.3',
long_description=README,
url='https://github.com/clashluke/truegrad',
packages=setuptools.find_packages(),
Expand All @@ -26,6 +26,7 @@
'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8',
'Programming Language :: Python :: 3.9',
'Programming Language :: Python :: 3.10',
'Topic :: Software Development :: Libraries',
'Topic :: Software Development :: Libraries :: Python Modules',
'Intended Audience :: Developers',
Expand Down
5 changes: 3 additions & 2 deletions truegrad/optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,9 +196,10 @@ def step(self, closure=None):
update = p.double() - o.double()
p.set_(o)
scale = group["lr"]
sign_update = torch.sign(update)
if group["graft_to_self"]:
scale = scale * torch.norm(update)
p.add_(torch.sign(update), alpha=scale)
scale = scale * update.norm() / sign_update.norm().clamp(min=group["eps"])
p.add_(sign_update, alpha=scale)

return loss

Expand Down

0 comments on commit d7c458b

Please sign in to comment.