More fixes for new factory & tests, add back adahessian
rwightman committed Nov 13, 2024
1 parent 5dae918 commit 0e6da65
Showing 6 changed files with 69 additions and 31 deletions.
3 changes: 2 additions & 1 deletion hfdocs/source/reference/optimizers.mdx
@@ -18,10 +18,11 @@ This page contains the API reference documentation for learning rate optimizers
[[autodoc]] timm.optim.adahessian.Adahessian
[[autodoc]] timm.optim.adamp.AdamP
[[autodoc]] timm.optim.adamw.AdamW
[[autodoc]] timm.optim.adan.Adan
[[autodoc]] timm.optim.adopt.Adopt
[[autodoc]] timm.optim.lamb.Lamb
[[autodoc]] timm.optim.lars.Lars
[[autodoc]] timm.optim.lion,Lion
[[autodoc]] timm.optim.lion.Lion
[[autodoc]] timm.optim.lookahead.Lookahead
[[autodoc]] timm.optim.madgrad.MADGRAD
[[autodoc]] timm.optim.nadam.Nadam
50 changes: 27 additions & 23 deletions tests/test_optim.py
@@ -12,7 +12,7 @@
from torch.testing._internal.common_utils import TestCase
from torch.nn import Parameter

from timm.optim import create_optimizer_v2, list_optimizers, get_optimizer_class
from timm.optim import create_optimizer_v2, list_optimizers, get_optimizer_class, get_optimizer_info, OptimInfo
from timm.optim import param_groups_layer_decay, param_groups_weight_decay
from timm.scheduler import PlateauLRScheduler

@@ -294,28 +294,32 @@ def _build_params_dict_single(weight, bias, **kwargs):

@pytest.mark.parametrize('optimizer', list_optimizers(exclude_filters=('fused*', 'bnb*')))
def test_optim_factory(optimizer):
get_optimizer_class(optimizer)

# test basic cases that don't need specific tuning via factory test
_test_basic_cases(
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict(weight, bias, lr=1e-2),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=1e-2),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=1e-2), optimizer)
)
assert issubclass(get_optimizer_class(optimizer), torch.optim.Optimizer)

opt_info = get_optimizer_info(optimizer)
assert isinstance(opt_info, OptimInfo)

if not opt_info.second_order: # basic tests don't support second order right now
# test basic cases that don't need specific tuning via factory test
_test_basic_cases(
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict(weight, bias, lr=1e-2),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=1e-2),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=1e-2), optimizer)
)


#@pytest.mark.parametrize('optimizer', ['sgd', 'momentum'])
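Usage note (not part of the commit): the reworked test gates the shared first-order checks on the new OptimInfo.second_order flag. A condensed sketch of the same pattern outside the test harness, assuming the public factory API shown above:

import torch
from timm.optim import OptimInfo, create_optimizer_v2, get_optimizer_info, list_optimizers

model = torch.nn.Linear(10, 2)
for name in list_optimizers(exclude_filters=('fused*', 'bnb*')):
    info = get_optimizer_info(name)
    assert isinstance(info, OptimInfo)
    if info.second_order:
        continue  # e.g. adahessian; the shared helpers don't do a create_graph backward
    opt = create_optimizer_v2(model, name, lr=1e-3)
    assert isinstance(opt, torch.optim.Optimizer)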
6 changes: 3 additions & 3 deletions timm/optim/__init__.py
@@ -17,6 +17,6 @@
from .rmsprop_tf import RMSpropTF
from .sgdp import SGDP

from ._optim_factory import list_optimizers, get_optimizer_class, create_optimizer_v2, \
create_optimizer, optimizer_kwargs, OptimInfo, OptimizerRegistry
from ._param_groups import param_groups_layer_decay, param_groups_weight_decay
from ._optim_factory import list_optimizers, get_optimizer_class, get_optimizer_info, OptimInfo, OptimizerRegistry, \
create_optimizer_v2, create_optimizer, optimizer_kwargs
from ._param_groups import param_groups_layer_decay, param_groups_weight_decay, auto_group_layers
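Usage note (not part of the commit): get_optimizer_info and auto_group_layers are the two names newly exported at the package root, so downstream code can import them directly; an illustrative check:

from timm.optim import OptimInfo, auto_group_layers, get_optimizer_info

info = get_optimizer_info('adahessian')   # registered again by this commit
assert isinstance(info, OptimInfo) and info.second_order
# auto_group_layers(model) maps parameter names to layer-group indices; see _param_groups.py below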
26 changes: 25 additions & 1 deletion timm/optim/_optim_factory.py
@@ -13,10 +13,11 @@
import torch.nn as nn
import torch.optim as optim

from ._param_groups import param_groups_layer_decay, param_groups_weight_decay, group_parameters
from ._param_groups import param_groups_layer_decay, param_groups_weight_decay
from .adabelief import AdaBelief
from .adafactor import Adafactor
from .adafactor_bv import AdafactorBigVision
from .adahessian import Adahessian
from .adamp import AdamP
from .adan import Adan
from .adopt import Adopt
@@ -78,6 +79,7 @@ class OptimInfo:
has_momentum: bool = False
has_betas: bool = False
num_betas: int = 2
second_order: bool = False
defaults: Optional[Dict[str, Any]] = None


@@ -540,6 +542,13 @@ def _register_other_optimizers(registry: OptimizerRegistry) -> None:
has_betas=True,
num_betas=3
),
OptimInfo(
name='adahessian',
opt_class=Adahessian,
description='An Adaptive Second Order Optimizer',
has_betas=True,
second_order=True,
),
OptimInfo(
name='lion',
opt_class=Lion,
@@ -770,6 +779,21 @@ def list_optimizers(
return default_registry.list_optimizers(filter, exclude_filters, with_description)


def get_optimizer_info(name: str) -> OptimInfo:
"""Get the OptimInfo for an optimizer.
Args:
name: Name of the optimizer
Returns:
OptimInfo configuration
Raises:
ValueError: If optimizer is not found
"""
return default_registry.get_optimizer_info(name)


def get_optimizer_class(
name: str,
bind_defaults: bool = False,
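Usage note (not part of the commit): adahessian is registered with second_order=True because its step() builds Hutchinson Hessian estimates from the autograd graph, so the backward pass needs create_graph=True (a requirement of Adahessian itself, not something introduced by this diff). A minimal sketch:

import torch
import torch.nn.functional as F
from timm.optim import create_optimizer_v2, get_optimizer_info

assert get_optimizer_info('adahessian').second_order

model = torch.nn.Linear(8, 2)
optimizer = create_optimizer_v2(model, 'adahessian', lr=0.1)

x, y = torch.randn(4, 8), torch.randn(4, 2)
loss = F.mse_loss(model(x), y)
loss.backward(create_graph=True)  # keep the graph so step() can take Hessian-vector products
optimizer.step()
optimizer.zero_grad()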
8 changes: 5 additions & 3 deletions timm/optim/_param_groups.py
@@ -1,6 +1,6 @@
import logging
from itertools import islice
from typing import Collection, Optional, Tuple
from typing import Collection, Optional

from torch import nn as nn

@@ -37,7 +37,7 @@ def _group(it, size):
return iter(lambda: tuple(islice(it, size)), ())


def _layer_map(model, layers_per_group=12, num_groups=None):
def auto_group_layers(model, layers_per_group=12, num_groups=None):
def _in_head(n, hp):
if not hp:
return True
@@ -63,6 +63,8 @@ def _in_head(n, hp):
layer_map.update({n: num_trunk_groups for n in names_head})
return layer_map

_layer_map = auto_group_layers # backward compat


def param_groups_layer_decay(
model: nn.Module,
@@ -86,7 +88,7 @@ def param_groups_layer_decay(
layer_map = group_parameters(model, model.group_matcher(coarse=False), reverse=True)
else:
# fallback
layer_map = _layer_map(model)
layer_map = auto_group_layers(model)
num_layers = max(layer_map.values()) + 1
layer_max = num_layers - 1
layer_scales = list(layer_decay ** (layer_max - i) for i in range(num_layers))
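Usage note (not part of the commit): auto_group_layers is the renamed public entry point for the automatic grouping that param_groups_layer_decay falls back to when a model has no group_matcher(). A hedged sketch using a standard timm model; the weight_decay/layer_decay values are arbitrary examples:

import timm
from timm.optim import auto_group_layers, create_optimizer_v2, param_groups_layer_decay

model = timm.create_model('resnet50', pretrained=False)

# inspect the automatic grouping: {parameter name: layer-group index}
layer_map = auto_group_layers(model, layers_per_group=12)

# build lr-scaled parameter groups explicitly ...
param_groups = param_groups_layer_decay(model, weight_decay=0.05, layer_decay=0.75)

# ... or let the factory handle layer decay in one call
optimizer = create_optimizer_v2(model, 'adamw', lr=1e-3, weight_decay=0.05, layer_decay=0.75)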
7 changes: 7 additions & 0 deletions timm/optim/optim_factory.py
@@ -0,0 +1,7 @@
# lots of uses of these functions directly, ala 'import timm.optim.optim_factory as optim_factory', fun :/

from ._optim_factory import create_optimizer, create_optimizer_v2, optimizer_kwargs
from ._param_groups import param_groups_layer_decay, param_groups_weight_decay, group_parameters, _layer_map, _group

import warnings
warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.optim", FutureWarning)
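
Usage note (not part of the commit): the shim keeps old-style imports working while pointing users at the new location; in a fresh interpreter the FutureWarning fires on first import of the deprecated module. An illustrative check:

import warnings
import timm.optim

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    import timm.optim.optim_factory as optim_factory  # deprecated path, still functional

assert any(issubclass(w.category, FutureWarning) for w in caught)
assert optim_factory.create_optimizer_v2 is timm.optim.create_optimizer_v2  # same function, new home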
