More optimizer updates, add MARS, LaProp, add Adopt fix and more #2347

Merged · 10 commits · Nov 26, 2024
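
As a usage note (not part of the diff below): the optimizers this PR adds or reworks are reachable through timm's create_optimizer_v2 factory by their registry names. A minimal sketch, assuming a timm build that includes this PR; the model and hyperparameters are illustrative only.

import torch
import torch.nn as nn
from timm.optim import create_optimizer_v2

# Toy model; any nn.Module works with the factory.
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))

# Registry names taken from the diff below; lr/weight_decay values are arbitrary.
for opt_name in ('mars', 'laprop', 'adopt'):
    optimizer = create_optimizer_v2(model, opt_name, lr=1e-3, weight_decay=0.05)
    loss = model(torch.randn(8, 16)).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
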
5 changes: 2 additions & 3 deletions hfdocs/source/reference/optimizers.mdx
@@ -17,18 +17,17 @@ This page contains the API reference documentation for learning rate optimizers
[[autodoc]] timm.optim.adafactor_bv.AdafactorBigVision
[[autodoc]] timm.optim.adahessian.Adahessian
[[autodoc]] timm.optim.adamp.AdamP
[[autodoc]] timm.optim.adamw.AdamW
[[autodoc]] timm.optim.adan.Adan
[[autodoc]] timm.optim.adopt.Adopt
[[autodoc]] timm.optim.lamb.Lamb
[[autodoc]] timm.optim.laprop.LaProp
[[autodoc]] timm.optim.lars.Lars
[[autodoc]] timm.optim.lion.Lion
[[autodoc]] timm.optim.lookahead.Lookahead
[[autodoc]] timm.optim.madgrad.MADGRAD
[[autodoc]] timm.optim.nadam.Nadam
[[autodoc]] timm.optim.mars.Mars
[[autodoc]] timm.optim.nadamw.NAdamW
[[autodoc]] timm.optim.nvnovograd.NvNovoGrad
[[autodoc]] timm.optim.radam.RAdam
[[autodoc]] timm.optim.rmsprop_tf.RMSpropTF
[[autodoc]] timm.optim.sgdp.SGDP
[[autodoc]] timm.optim.sgdw.SGDW
16 changes: 8 additions & 8 deletions tests/test_models.py
@@ -146,21 +146,21 @@ def test_model_inference(model_name, batch_size):
rand_output = model(rand_tensors['input'])
rand_features = model.forward_features(rand_tensors['input'])
rand_pre_logits = model.forward_head(rand_features, pre_logits=True)
assert torch.allclose(rand_output, rand_tensors['output'], rtol=1e-3, atol=1e-4)
assert torch.allclose(rand_features, rand_tensors['features'], rtol=1e-3, atol=1e-4)
assert torch.allclose(rand_pre_logits, rand_tensors['pre_logits'], rtol=1e-3, atol=1e-4)
assert torch.allclose(rand_output, rand_tensors['output'], rtol=1e-3, atol=1e-4), 'rand output does not match'
assert torch.allclose(rand_features, rand_tensors['features'], rtol=1e-3, atol=1e-4), 'rand features do not match'
assert torch.allclose(rand_pre_logits, rand_tensors['pre_logits'], rtol=1e-3, atol=1e-4), 'rand pre_logits do not match'

def _test_owl(owl_input):
def _test_owl(owl_input, tol=(1e-3, 1e-4)):
owl_output = model(owl_input)
owl_features = model.forward_features(owl_input)
owl_pre_logits = model.forward_head(owl_features.clone(), pre_logits=True)
assert owl_output.softmax(1).argmax(1) == 24 # owl
assert torch.allclose(owl_output, owl_tensors['output'], rtol=1e-3, atol=1e-4)
assert torch.allclose(owl_features, owl_tensors['features'], rtol=1e-3, atol=1e-4)
assert torch.allclose(owl_pre_logits, owl_tensors['pre_logits'], rtol=1e-3, atol=1e-4)
assert torch.allclose(owl_output, owl_tensors['output'], rtol=tol[0], atol=tol[1]), 'owl output does not match'
assert torch.allclose(owl_features, owl_tensors['features'], rtol=tol[0], atol=tol[1]), 'owl features do not match'
assert torch.allclose(owl_pre_logits, owl_tensors['pre_logits'], rtol=tol[0], atol=tol[1]), 'owl pre_logits do not match'

_test_owl(owl_tensors['input']) # test with original pp owl tensor
_test_owl(pp(test_owl).unsqueeze(0)) # re-process from original jpg
_test_owl(pp(test_owl).unsqueeze(0), tol=(1e-1, 1e-1)) # re-process from original jpg, Pillow output can change a lot between versions


@pytest.mark.base
95 changes: 64 additions & 31 deletions tests/test_optim.py
@@ -3,22 +3,20 @@
These tests were adapted from PyTorch's optimizer tests.

"""
import math
import pytest
import functools
import importlib
import os
from copy import deepcopy

import pytest
import torch
from torch.testing._internal.common_utils import TestCase
from torch.nn import Parameter
from torch.testing._internal.common_utils import TestCase

from timm.optim import create_optimizer_v2, list_optimizers, get_optimizer_class, get_optimizer_info, OptimInfo
from timm.optim import param_groups_layer_decay, param_groups_weight_decay
from timm.scheduler import PlateauLRScheduler

import importlib
import os

torch_backend = os.environ.get('TORCH_BACKEND')
if torch_backend is not None:
importlib.import_module(torch_backend)
@@ -299,27 +297,39 @@ def test_optim_factory(optimizer):
opt_info = get_optimizer_info(optimizer)
assert isinstance(opt_info, OptimInfo)

if not opt_info.second_order: # basic tests don't support second order right now
# test basic cases that don't need specific tuning via factory test
_test_basic_cases(
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict(weight, bias, lr=1e-2),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=1e-2),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=1e-2), optimizer)
)
lr = (1e-2,) * 4
if optimizer in ('mars',):
lr = (1e-3,) * 4

try:
if not opt_info.second_order: # basic tests don't support second order right now
# test basic cases that don't need specific tuning via factory test
_test_basic_cases(
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=lr[0])
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict(weight, bias, lr=lr[1]),
optimizer,
lr=lr[1] / 10)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=lr[2]),
optimizer,
lr=lr[2] / 10)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=lr[3]),
optimizer)
)
except TypeError as e:
if 'radamw' in optimizer:
pytest.skip("Expected failure for 'radamw' (decoupled decay) on older PyTorch versions.")
else:
raise e



#@pytest.mark.parametrize('optimizer', ['sgd', 'momentum'])
@@ -376,10 +386,17 @@ def test_adam(optimizer):

@pytest.mark.parametrize('optimizer', ['adopt', 'adoptw'])
def test_adopt(optimizer):
# FIXME rosenbrock is not passing for ADOPT
# _test_rosenbrock(
# lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
# )
_test_rosenbrock(
lambda params: create_optimizer_v2(params, optimizer, lr=3e-3)
)
_test_model(optimizer, dict(lr=5e-2), after_step=1) # note no convergence in first step for ADOPT


@pytest.mark.parametrize('optimizer', ['adan', 'adanw'])
def test_adan(optimizer):
_test_rosenbrock(
lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
)
_test_model(optimizer, dict(lr=5e-2), after_step=1) # note no convergence in first step


@@ -432,6 +449,14 @@ def test_lamb(optimizer):
_test_model(optimizer, dict(lr=1e-3))


@pytest.mark.parametrize('optimizer', ['laprop'])
def test_laprop(optimizer):
_test_rosenbrock(
lambda params: create_optimizer_v2(params, optimizer, lr=1e-2)
)
_test_model(optimizer, dict(lr=1e-2))


@pytest.mark.parametrize('optimizer', ['lars', 'larc', 'nlars', 'nlarc'])
def test_lars(optimizer):
_test_rosenbrock(
@@ -448,6 +473,14 @@ def test_madgrad(optimizer):
_test_model(optimizer, dict(lr=1e-2))


@pytest.mark.parametrize('optimizer', ['mars'])
def test_mars(optimizer):
_test_rosenbrock(
lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
)
_test_model(optimizer, dict(lr=5e-2), after_step=1) # note no convergence in first step for MARS


@pytest.mark.parametrize('optimizer', ['novograd'])
def test_novograd(optimizer):
_test_rosenbrock(
16 changes: 13 additions & 3 deletions timm/optim/__init__.py
@@ -3,22 +3,32 @@
from .adafactor_bv import AdafactorBigVision
from .adahessian import Adahessian
from .adamp import AdamP
from .adamw import AdamW
from .adamw import AdamWLegacy
from .adan import Adan
from .adopt import Adopt
from .lamb import Lamb
from .laprop import LaProp
from .lars import Lars
from .lion import Lion
from .lookahead import Lookahead
from .madgrad import MADGRAD
from .nadam import Nadam
from .mars import Mars
from .nadam import NAdamLegacy
from .nadamw import NAdamW
from .nvnovograd import NvNovoGrad
from .radam import RAdam
from .radam import RAdamLegacy
from .rmsprop_tf import RMSpropTF
from .sgdp import SGDP
from .sgdw import SGDW

# bring common torch.optim Optimizers into timm.optim namespace for consistency
from torch.optim import Adadelta, Adagrad, Adamax, Adam, AdamW, RMSprop, SGD
try:
# in case any very old torch versions being used
from torch.optim import NAdam, RAdam
except ImportError:
pass

from ._optim_factory import list_optimizers, get_optimizer_class, get_optimizer_info, OptimInfo, OptimizerRegistry, \
create_optimizer_v2, create_optimizer, optimizer_kwargs
from ._param_groups import param_groups_layer_decay, param_groups_weight_decay, auto_group_layers
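
A brief sketch (not part of the diff) of what the reworked timm.optim namespace exposes after this change; the (params, lr=...) constructor form is the usual PyTorch convention and is assumed here rather than taken from the diff.

import torch.nn as nn
from timm.optim import Mars, LaProp, AdamWLegacy, AdamW  # AdamW is now the re-exported torch.optim.AdamW

model = nn.Linear(8, 2)
opt_mars = Mars(model.parameters(), lr=1e-3)            # new MARS optimizer
opt_laprop = LaProp(model.parameters(), lr=4e-4)        # new LaProp optimizer
opt_legacy = AdamWLegacy(model.parameters(), lr=1e-3)   # timm's older AdamW impl under its new name
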
66 changes: 53 additions & 13 deletions timm/optim/_optim_factory.py
@@ -19,17 +19,20 @@
from .adafactor_bv import AdafactorBigVision
from .adahessian import Adahessian
from .adamp import AdamP
from .adamw import AdamWLegacy
from .adan import Adan
from .adopt import Adopt
from .lamb import Lamb
from .laprop import LaProp
from .lars import Lars
from .lion import Lion
from .lookahead import Lookahead
from .madgrad import MADGRAD
from .nadam import Nadam
from .mars import Mars
from .nadam import NAdamLegacy
from .nadamw import NAdamW
from .nvnovograd import NvNovoGrad
from .radam import RAdam
from .radam import RAdamLegacy
from .rmsprop_tf import RMSpropTF
from .sgdp import SGDP
from .sgdw import SGDW
@@ -384,13 +387,19 @@ def _register_adam_variants(registry: OptimizerRegistry) -> None:
OptimInfo(
name='adam',
opt_class=optim.Adam,
description='torch.optim Adam (Adaptive Moment Estimation)',
description='torch.optim.Adam, Adaptive Moment Estimation',
has_betas=True
),
OptimInfo(
name='adamw',
opt_class=optim.AdamW,
description='torch.optim Adam with decoupled weight decay regularization',
description='torch.optim.AdamW, Adam with decoupled weight decay',
has_betas=True
),
OptimInfo(
name='adamwlegacy',
opt_class=AdamWLegacy,
description='legacy impl of AdamW that pre-dates inclusion in torch.optim',
has_betas=True
),
OptimInfo(
@@ -402,26 +411,45 @@ def _register_adam_variants(registry: OptimizerRegistry) -> None:
),
OptimInfo(
name='nadam',
opt_class=Nadam,
description='Adam with Nesterov momentum',
opt_class=torch.optim.NAdam,
description='torch.optim.NAdam, Adam with Nesterov momentum',
has_betas=True
),
OptimInfo(
name='nadamlegacy',
opt_class=NAdamLegacy,
description='legacy impl of NAdam that pre-dates inclusion in torch.optim',
has_betas=True
),
OptimInfo(
name='nadamw',
opt_class=NAdamW,
description='Adam with Nesterov momentum and decoupled weight decay',
description='Adam with Nesterov momentum and decoupled weight decay, mlcommons/algorithmic-efficiency impl',
has_betas=True
),
OptimInfo(
name='radam',
opt_class=RAdam,
description='Rectified Adam with variance adaptation',
opt_class=torch.optim.RAdam,
description='torch.optim.RAdam, Rectified Adam with variance adaptation',
has_betas=True
),
OptimInfo(
name='radamlegacy',
opt_class=RAdamLegacy,
description='legacy impl of RAdam that predates inclusion in torch.optim',
has_betas=True
),
OptimInfo(
name='radamw',
opt_class=torch.optim.RAdam,
description='torch.optim.RAdamW, Rectified Adam with variance adaptation and decoupled weight decay',
has_betas=True,
defaults={'decoupled_weight_decay': True}
),
OptimInfo(
name='adamax',
opt_class=optim.Adamax,
description='torch.optim Adamax, Adam with infinity norm for more stable updates',
description='torch.optim.Adamax, Adam with infinity norm for more stable updates',
has_betas=True
),
OptimInfo(
@@ -518,12 +546,12 @@ def _register_other_optimizers(registry: OptimizerRegistry) -> None:
OptimInfo(
name='adadelta',
opt_class=optim.Adadelta,
description='torch.optim Adadelta, Adapts learning rates based on running windows of gradients'
description='torch.optim.Adadelta, Adapts learning rates based on running windows of gradients'
),
OptimInfo(
name='adagrad',
opt_class=optim.Adagrad,
description='torch.optim Adagrad, Adapts learning rates using cumulative squared gradients',
description='torch.optim.Adagrad, Adapts learning rates using cumulative squared gradients',
defaults={'eps': 1e-8}
),
OptimInfo(
@@ -549,6 +577,12 @@ def _register_other_optimizers(registry: OptimizerRegistry) -> None:
has_betas=True,
second_order=True,
),
OptimInfo(
name='laprop',
opt_class=LaProp,
description='Separating Momentum and Adaptivity in Adam',
has_betas=True,
),
OptimInfo(
name='lion',
opt_class=Lion,
@@ -569,6 +603,12 @@ def _register_other_optimizers(registry: OptimizerRegistry) -> None:
has_momentum=True,
defaults={'decoupled_decay': True}
),
OptimInfo(
name='mars',
opt_class=Mars,
description='Unleashing the Power of Variance Reduction for Training Large Models',
has_betas=True,
),
OptimInfo(
name='novograd',
opt_class=NvNovoGrad,
@@ -578,7 +618,7 @@ def _register_other_optimizers(registry: OptimizerRegistry) -> None:
OptimInfo(
name='rmsprop',
opt_class=optim.RMSprop,
description='torch.optim RMSprop, Root Mean Square Propagation',
description='torch.optim.RMSprop, Root Mean Square Propagation',
has_momentum=True,
defaults={'alpha': 0.9}
),
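
Finally, a hedged sketch (not part of the diff) of inspecting the registry metadata added above; it uses only list_optimizers, get_optimizer_info, and the OptimInfo fields visible in this PR, and assumes list_optimizers() returns the registered name strings.

from timm.optim import list_optimizers, get_optimizer_info

# Names registered in _optim_factory.py by this PR.
for name in ('mars', 'laprop', 'adamwlegacy', 'radamw'):
    if name in list_optimizers():
        info = get_optimizer_info(name)
        print(f'{name}: {info.description} (has_betas={info.has_betas}, defaults={info.defaults})')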