diff --git a/tests/test_optim.py b/tests/test_optim.py index dabedc1027..c4fe25e187 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -13,7 +13,7 @@ from torch.nn import Parameter from timm.scheduler import PlateauLRScheduler -from timm.optim import create_optimizer_v2 +from timm.optim import create_optimizer_v2, list_optimizers, get_optimizer_class import importlib import os @@ -291,10 +291,11 @@ def _build_params_dict_single(weight, bias, **kwargs): return [dict(params=bias, **kwargs)] -#@pytest.mark.parametrize('optimizer', ['sgd', 'momentum']) -# FIXME momentum variant frequently fails in GitHub runner, but never local after many attempts -@pytest.mark.parametrize('optimizer', ['sgd']) -def test_sgd(optimizer): +@pytest.mark.parametrize('optimizer', list_optimizers(exclude_filters=('fused*', 'bnb*'))) +def test_optim_factory(optimizer): + get_optimizer_class(optimizer) + + # test basic cases that don't need specific tuning via factory test _test_basic_cases( lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3) ) @@ -314,6 +315,12 @@ def test_sgd(optimizer): lambda weight, bias: create_optimizer_v2( _build_params_dict_single(weight, bias, lr=1e-2), optimizer) ) + + +#@pytest.mark.parametrize('optimizer', ['sgd', 'momentum']) +# FIXME momentum variant frequently fails in GitHub runner, but never local after many attempts +@pytest.mark.parametrize('optimizer', ['sgd']) +def test_sgd(optimizer): # _test_basic_cases( # lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3), # [lambda opt: StepLR(opt, gamma=0.9, step_size=10)] @@ -356,21 +363,6 @@ def test_sgd(optimizer): @pytest.mark.parametrize('optimizer', ['adamw', 'adam', 'nadam', 'adamax', 'nadamw']) def test_adam(optimizer): - _test_basic_cases( - lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict(weight, bias, lr=3e-3), - optimizer, - lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict_single(weight, bias, lr=3e-3), - optimizer, - lr=1e-3) - ) _test_rosenbrock( lambda params: create_optimizer_v2(params, optimizer, lr=5e-2) ) @@ -379,21 +371,6 @@ def test_adam(optimizer): @pytest.mark.parametrize('optimizer', ['adopt', 'adoptw']) def test_adopt(optimizer): - _test_basic_cases( - lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict(weight, bias, lr=3e-3), - optimizer, - lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict_single(weight, bias, lr=3e-3), - optimizer, - lr=1e-3) - ) # FIXME rosenbrock is not passing for ADOPT # _test_rosenbrock( # lambda params: create_optimizer_v2(params, optimizer, lr=1e-3) @@ -403,25 +380,6 @@ def test_adopt(optimizer): @pytest.mark.parametrize('optimizer', ['adabelief']) def test_adabelief(optimizer): - _test_basic_cases( - lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict(weight, bias, lr=3e-3), - optimizer, - lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict_single(weight, bias, lr=3e-3), - optimizer, - lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict_single(weight, bias, lr=3e-3), optimizer) - ) 
_test_basic_cases( lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3, weight_decay=1) ) @@ -433,21 +391,6 @@ def test_adabelief(optimizer): @pytest.mark.parametrize('optimizer', ['radam', 'radabelief']) def test_rectified(optimizer): - _test_basic_cases( - lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict(weight, bias, lr=3e-3), - optimizer, - lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict_single(weight, bias, lr=3e-3), - optimizer, - lr=1e-3) - ) _test_rosenbrock( lambda params: create_optimizer_v2(params, optimizer, lr=1e-3) ) @@ -456,25 +399,6 @@ def test_rectified(optimizer): @pytest.mark.parametrize('optimizer', ['adadelta', 'adagrad']) def test_adaother(optimizer): - _test_basic_cases( - lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict(weight, bias, lr=3e-3), - optimizer, - lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict_single(weight, bias, lr=3e-3), - optimizer, - lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict_single(weight, bias, lr=3e-3), optimizer) - ) _test_basic_cases( lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3, weight_decay=1) ) @@ -486,24 +410,6 @@ def test_adaother(optimizer): @pytest.mark.parametrize('optimizer', ['adafactor', 'adafactorbv']) def test_adafactor(optimizer): - _test_basic_cases( - lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict(weight, bias, lr=3e-3), - optimizer, - lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict_single(weight, bias, lr=3e-3), - optimizer, - lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2(_build_params_dict_single(weight, bias), optimizer) - ) _test_basic_cases( lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3, weight_decay=1) ) @@ -515,25 +421,6 @@ def test_adafactor(optimizer): @pytest.mark.parametrize('optimizer', ['lamb', 'lambc']) def test_lamb(optimizer): - _test_basic_cases( - lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict(weight, bias, lr=1e-3), - optimizer, - lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict_single(weight, bias, lr=1e-3), - optimizer, - lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict_single(weight, bias, lr=1e-3), optimizer) - ) _test_rosenbrock( lambda params: create_optimizer_v2(params, optimizer, lr=1e-3) ) @@ -542,25 +429,6 @@ def test_lamb(optimizer): @pytest.mark.parametrize('optimizer', ['lars', 'larc', 'nlars', 'nlarc']) def test_lars(optimizer): - _test_basic_cases( - lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict(weight, bias, lr=1e-3), - optimizer, - lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict_single(weight, bias, lr=1e-3), - optimizer, - lr=1e-3) - ) - 
_test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict_single(weight, bias, lr=1e-3), optimizer) - ) _test_rosenbrock( lambda params: create_optimizer_v2(params, optimizer, lr=1e-3) ) @@ -569,25 +437,6 @@ def test_lars(optimizer): @pytest.mark.parametrize('optimizer', ['madgrad', 'madgradw']) def test_madgrad(optimizer): - _test_basic_cases( - lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict(weight, bias, lr=3e-3), - optimizer, - lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict_single(weight, bias, lr=3e-3), - optimizer, - lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict_single(weight, bias, lr=3e-3), optimizer) - ) _test_rosenbrock( lambda params: create_optimizer_v2(params, optimizer, lr=1e-2) ) @@ -596,25 +445,6 @@ def test_madgrad(optimizer): @pytest.mark.parametrize('optimizer', ['novograd']) def test_novograd(optimizer): - _test_basic_cases( - lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict(weight, bias, lr=3e-3), - optimizer, - lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict_single(weight, bias, lr=3e-3), - optimizer, - lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict_single(weight, bias, lr=3e-3), optimizer) - ) _test_rosenbrock( lambda params: create_optimizer_v2(params, optimizer, lr=1e-3) ) @@ -623,25 +453,6 @@ def test_novograd(optimizer): @pytest.mark.parametrize('optimizer', ['rmsprop', 'rmsproptf']) def test_rmsprop(optimizer): - _test_basic_cases( - lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict(weight, bias, lr=3e-3), - optimizer, - lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict_single(weight, bias, lr=3e-3), - optimizer, - lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict_single(weight, bias, lr=3e-3), optimizer) - ) _test_rosenbrock( lambda params: create_optimizer_v2(params, optimizer, lr=1e-2) ) @@ -650,25 +461,6 @@ def test_rmsprop(optimizer): @pytest.mark.parametrize('optimizer', ['adamp']) def test_adamp(optimizer): - _test_basic_cases( - lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict(weight, bias, lr=3e-3), - optimizer, - lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict_single(weight, bias, lr=3e-3), - optimizer, - lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict_single(weight, bias, lr=3e-3), optimizer) - ) _test_rosenbrock( lambda params: create_optimizer_v2(params, optimizer, lr=5e-2) ) @@ -677,25 +469,6 @@ def test_adamp(optimizer): @pytest.mark.parametrize('optimizer', ['sgdp']) def test_sgdp(optimizer): - _test_basic_cases( - lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict(weight, bias, lr=3e-3), - optimizer, - lr=1e-3) - ) - _test_basic_cases( - 
lambda weight, bias: create_optimizer_v2( - _build_params_dict_single(weight, bias, lr=3e-3), - optimizer, - lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict_single(weight, bias, lr=3e-3), optimizer) - ) _test_rosenbrock( lambda params: create_optimizer_v2(params, optimizer, lr=1e-3) ) @@ -704,25 +477,6 @@ def test_sgdp(optimizer): @pytest.mark.parametrize('optimizer', ['lookahead_sgd', 'lookahead_momentum']) def test_lookahead_sgd(optimizer): - _test_basic_cases( - lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict(weight, bias, lr=3e-3), - optimizer, - lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict_single(weight, bias, lr=3e-3), - optimizer, - lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict_single(weight, bias, lr=3e-3), optimizer) - ) _test_rosenbrock( lambda params: create_optimizer_v2(params, optimizer, lr=1e-3) ) @@ -730,25 +484,6 @@ def test_lookahead_sgd(optimizer): @pytest.mark.parametrize('optimizer', ['lookahead_adamw', 'lookahead_adam']) def test_lookahead_adam(optimizer): - _test_basic_cases( - lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict(weight, bias, lr=3e-3), - optimizer, - lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict_single(weight, bias, lr=3e-3), - optimizer, - lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict_single(weight, bias, lr=3e-3), optimizer) - ) _test_rosenbrock( lambda params: create_optimizer_v2(params, optimizer, lr=5e-2) ) @@ -756,25 +491,6 @@ def test_lookahead_adam(optimizer): @pytest.mark.parametrize('optimizer', ['lookahead_radam']) def test_lookahead_radam(optimizer): - _test_basic_cases( - lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict(weight, bias, lr=3e-3), - optimizer, - lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict_single(weight, bias, lr=3e-3), - optimizer, - lr=1e-3) - ) - _test_basic_cases( - lambda weight, bias: create_optimizer_v2( - _build_params_dict_single(weight, bias, lr=3e-3), optimizer) - ) _test_rosenbrock( lambda params: create_optimizer_v2(params, optimizer, lr=1e-4) ) diff --git a/timm/optim/__init__.py b/timm/optim/__init__.py index 85b2ca2aa5..c3c533cc7f 100644 --- a/timm/optim/__init__.py +++ b/timm/optim/__init__.py @@ -17,4 +17,5 @@ from .rmsprop_tf import RMSpropTF from .sgdp import SGDP -from .optim_factory import create_optimizer, create_optimizer_v2, optimizer_kwargs \ No newline at end of file +from ._optim_factory import list_optimizers, get_optimizer_class, create_optimizer_v2, \ + create_optimizer, optimizer_kwargs, OptimInfo, OptimizerRegistry \ No newline at end of file diff --git a/timm/optim/_optim_factory.py b/timm/optim/_optim_factory.py new file mode 100644 index 0000000000..c38363f892 --- /dev/null +++ b/timm/optim/_optim_factory.py @@ -0,0 +1,798 @@ +""" Optimizer Factory w/ custom Weight Decay & Layer Decay support + +Hacked together by / Copyright 2021 Ross Wightman +""" +import logging +from dataclasses import dataclass +from functools import partial 
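# --- Hedged usage sketch (not part of the diff) of the public surface the updated
# --- __init__.py exports: list_optimizers() to enumerate registered names,
# --- get_optimizer_class() to fetch a constructor, create_optimizer_v2() to build a
# --- configured instance. Model/layer sizes below are arbitrary illustration values.
import torch.nn as nn
from timm.optim import create_optimizer_v2, get_optimizer_class, list_optimizers

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))

# Same filtering the updated test uses: skip optimizers that need apex / bitsandbytes.
names = list_optimizers(exclude_filters=('fused*', 'bnb*'))
print(len(names), names[:5])

# Resolve a class without instantiating it (module-level bind_defaults is False).
adamw_cls = get_optimizer_class('adamw')
opt = adamw_cls(model.parameters(), lr=1e-3)

# Or go through the factory, which also handles weight-decay param grouping.
opt = create_optimizer_v2(model, 'adamw', lr=1e-3, weight_decay=0.05)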
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union, Protocol, Iterator +from fnmatch import fnmatch +import importlib + +import torch +import torch.nn as nn +import torch.optim as optim + +from ._param_groups import param_groups_layer_decay, param_groups_weight_decay, group_parameters +from .adabelief import AdaBelief +from .adafactor import Adafactor +from .adafactor_bv import AdafactorBigVision +from .adamp import AdamP +from .adan import Adan +from .adopt import Adopt +from .lamb import Lamb +from .lars import Lars +from .lion import Lion +from .lookahead import Lookahead +from .madgrad import MADGRAD +from .nadam import Nadam +from .nadamw import NAdamW +from .nvnovograd import NvNovoGrad +from .radam import RAdam +from .rmsprop_tf import RMSpropTF +from .sgdp import SGDP +from .sgdw import SGDW + +_logger = logging.getLogger(__name__) + +# Type variables +T = TypeVar('T') +Params = Union[Iterator[nn.Parameter], Iterator[Dict[str, Any]]] +OptimType = TypeVar('OptimType', bound='optim.Optimizer') + + +def _import_class(class_string: str) -> Type: + """Dynamically import a class from a string.""" + try: + module_name, class_name = class_string.rsplit(".", 1) + module = importlib.import_module(module_name) + return getattr(module, class_name) + except (ImportError, AttributeError) as e: + raise ImportError(f"Could not import {class_string}: {e}") + + +class OptimizerCallable(Protocol): + """Protocol for optimizer constructor signatures.""" + + def __call__(self, params: Params, **kwargs) -> optim.Optimizer: ... + + +@dataclass(frozen=True) +class OptimInfo: + """Immutable configuration for an optimizer. + + Attributes: + name: Unique identifier for the optimizer + opt_class: The optimizer class + description: Brief description of the optimizer's characteristics and behavior + has_eps: Whether the optimizer accepts epsilon parameter + has_momentum: Whether the optimizer accepts momentum parameter + has_betas: Whether the optimizer accepts a tuple of beta parameters + num_betas: number of betas in tuple (valid IFF has_betas = True) + defaults: Optional default parameters for the optimizer + """ + name: str + opt_class: Union[str, Type[optim.Optimizer]] + description: str = '' + has_eps: bool = True + has_momentum: bool = False + has_betas: bool = False + num_betas: int = 2 + defaults: Optional[Dict[str, Any]] = None + + +class OptimizerRegistry: + """Registry managing optimizer configurations and instantiation. + + This class provides a central registry for optimizer configurations and handles + their instantiation with appropriate parameter groups and settings. + """ + + def __init__(self) -> None: + self._optimizers: Dict[str, OptimInfo] = {} + self._foreach_defaults: Set[str] = {'lion'} + + def register(self, info: OptimInfo) -> None: + """Register an optimizer configuration. + + Args: + info: The OptimInfo configuration containing name, type and description + """ + name = info.name.lower() + if name in self._optimizers: + _logger.warning(f'Optimizer {name} already registered, overwriting') + self._optimizers[name] = info + + def register_alias(self, alias: str, target: str) -> None: + """Register an alias for an existing optimizer. 
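# --- Hedged sketch of the registration pieces defined here: OptimInfo is a frozen config
# --- record, OptimizerRegistry.register() stores it under its lowercased name, and
# --- register_alias() points a second name at the same entry. 'plain_sgd' and
# --- 'vanilla_sgd' are made-up names for illustration; the listing/instantiation methods
# --- used at the end appear further down in this file.
import torch.optim as optim
from timm.optim import OptimInfo, OptimizerRegistry

registry = OptimizerRegistry()
registry.register(OptimInfo(
    name='plain_sgd',
    opt_class=optim.SGD,
    description='classical SGD registered for illustration',
    has_eps=False,
    has_momentum=True,
    defaults={'nesterov': False},
))
registry.register_alias('vanilla_sgd', 'plain_sgd')

info = registry.get_optimizer_info('vanilla_sgd')
print(info.name, info.defaults)   # plain_sgd {'nesterov': False}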
+ + Args: + alias: The alias name + target: The target optimizer name + + Raises: + KeyError: If target optimizer doesn't exist + """ + target = target.lower() + if target not in self._optimizers: + raise KeyError(f'Cannot create alias for non-existent optimizer {target}') + self._optimizers[alias.lower()] = self._optimizers[target] + + def register_foreach_default(self, name: str) -> None: + """Register an optimizer as defaulting to foreach=True.""" + self._foreach_defaults.add(name.lower()) + + def list_optimizers( + self, + filter: str = '', + exclude_filters: Optional[List[str]] = None, + with_description: bool = False + ) -> List[Union[str, Tuple[str, str]]]: + """List available optimizer names, optionally filtered. + + Args: + filter: Wildcard style filter string (e.g., 'adam*') + exclude_filters: Optional list of wildcard patterns to exclude + with_description: If True, return tuples of (name, description) + + Returns: + List of either optimizer names or (name, description) tuples + """ + names = sorted(self._optimizers.keys()) + + if filter: + names = [n for n in names if fnmatch(n, filter)] + + if exclude_filters: + for exclude_filter in exclude_filters: + names = [n for n in names if not fnmatch(n, exclude_filter)] + + if with_description: + return [(name, self._optimizers[name].description) for name in names] + return names + + def get_optimizer_info(self, name: str) -> OptimInfo: + """Get the OptimInfo for an optimizer. + + Args: + name: Name of the optimizer + + Returns: + OptimInfo configuration + + Raises: + ValueError: If optimizer is not found + """ + name = name.lower() + if name not in self._optimizers: + raise ValueError(f'Optimizer {name} not found in registry') + return self._optimizers[name] + + def get_optimizer_class( + self, + name_or_info: Union[str, OptimInfo], + bind_defaults: bool = True, + ) -> Union[Type[optim.Optimizer], OptimizerCallable]: + """Get the optimizer class with any default arguments applied. + + This allows direct instantiation of optimizers with their default configs + without going through the full factory. 
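# --- Sketch of the fnmatch-style filtering implemented in list_optimizers() above, via
# --- the module-level wrapper exported from timm.optim (it delegates to the default
# --- registry). Exact output depends on the registered set, so comments are indicative.
from timm.optim import list_optimizers

print(list_optimizers('adam*'))                              # names starting with 'adam'
print(list_optimizers(exclude_filters=('fused*', 'bnb*')))   # drop apex / bitsandbytes entries

# with_description=True returns (name, description) tuples instead of bare names.
for name, desc in list_optimizers('lamb*', with_description=True):
    print(f'{name:8s} {desc}')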
+ + Args: + name_or_info: Name of the optimizer + bind_defaults: Bind default arguments to optimizer class via `partial` before returning + + Returns: + Optimizer class or partial with defaults applied + + Raises: + ValueError: If optimizer not found + """ + if isinstance(name_or_info, str): + opt_info = self.get_optimizer_info(name_or_info) + else: + assert isinstance(name_or_info, OptimInfo) + opt_info = name_or_info + + if isinstance(opt_info.opt_class, str): + # Special handling for APEX and BNB optimizers + if opt_info.opt_class.startswith('apex.'): + assert torch.cuda.is_available(), 'CUDA required for APEX optimizers' + try: + opt_class = _import_class(opt_info.opt_class) + except ImportError as e: + raise ImportError('APEX optimizers require apex to be installed') from e + elif opt_info.opt_class.startswith('bitsandbytes.'): + assert torch.cuda.is_available(), 'CUDA required for bitsandbytes optimizers' + try: + opt_class = _import_class(opt_info.opt_class) + except ImportError as e: + raise ImportError('bitsandbytes optimizers require bitsandbytes to be installed') from e + else: + opt_class = _import_class(opt_info.opt_class) + else: + opt_class = opt_info.opt_class + + # Return class or partial with defaults + if bind_defaults and opt_info.defaults: + opt_class = partial(opt_class, **opt_info.defaults) + + return opt_class + + def create_optimizer( + self, + model_or_params: Union[nn.Module, Params], + opt: str, + lr: Optional[float] = None, + weight_decay: float = 0., + momentum: float = 0.9, + foreach: Optional[bool] = None, + weight_decay_exclude_1d: bool = True, + layer_decay: Optional[float] = None, + param_group_fn: Optional[Callable[[nn.Module], Params]] = None, + **kwargs: Any, + ) -> optim.Optimizer: + """Create an optimizer instance. + + Args: + model_or_params: Model or parameters to optimize + opt: Name of optimizer to create + lr: Learning rate + weight_decay: Weight decay factor + momentum: Momentum factor for applicable optimizers + foreach: Enable/disable foreach operation + weight_decay_exclude_1d: Whether to skip weight decay for 1d params (biases and norm affine) + layer_decay: Layer-wise learning rate decay + param_group_fn: Optional custom parameter grouping function + **kwargs: Additional optimizer-specific arguments + + Returns: + Configured optimizer instance + + Raises: + ValueError: If optimizer not found or configuration invalid + """ + + # Get parameters to optimize + if isinstance(model_or_params, nn.Module): + # Extract parameters from a nn.Module, build param groups w/ weight-decay and/or layer-decay applied + no_weight_decay = getattr(model_or_params, 'no_weight_decay', lambda: set())() + + if param_group_fn: + # run custom fn to generate param groups from nn.Module + parameters = param_group_fn(model_or_params) + elif layer_decay is not None: + parameters = param_groups_layer_decay( + model_or_params, + weight_decay=weight_decay, + layer_decay=layer_decay, + no_weight_decay_list=no_weight_decay, + weight_decay_exclude_1d=weight_decay_exclude_1d, + ) + weight_decay = 0. + elif weight_decay and weight_decay_exclude_1d: + parameters = param_groups_weight_decay( + model_or_params, + weight_decay=weight_decay, + no_weight_decay_list=no_weight_decay, + ) + weight_decay = 0. 
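# --- Sketch of get_optimizer_class() above: entries registered with a defaults dict
# --- (e.g. 'sgd' -> {'nesterov': True}) come back as a functools.partial when
# --- bind_defaults=True; dotted-string entries (apex / bitsandbytes) are only imported
# --- at this point. Values below are arbitrary.
import torch.nn as nn
from timm.optim import get_optimizer_class

model = nn.Linear(4, 2)

sgd_cls = get_optimizer_class('sgd')                         # torch.optim.SGD itself
sgd_bound = get_optimizer_class('sgd', bind_defaults=True)   # partial(SGD, nesterov=True)

opt_plain = sgd_cls(model.parameters(), lr=0.1)
opt_bound = sgd_bound(model.parameters(), lr=0.1, momentum=0.9)
print(opt_plain.defaults['nesterov'], opt_bound.defaults['nesterov'])   # False True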
+ else: + parameters = model_or_params.parameters() + else: + # pass parameters / parameter groups through to optimizer + parameters = model_or_params + + # Parse optimizer name + opt_split = opt.lower().split('_') + opt_name = opt_split[-1] + use_lookahead = opt_split[0] == 'lookahead' if len(opt_split) > 1 else False + + opt_info = self.get_optimizer_info(opt_name) + + # Build optimizer arguments + opt_args: Dict[str, Any] = {'weight_decay': weight_decay, **kwargs} + + # Add LR to args, if None optimizer default is used, some optimizers manage LR internally if None. + if lr is not None: + opt_args['lr'] = lr + + # Apply optimizer-specific settings + if opt_info.defaults: + for k, v in opt_info.defaults.items(): + opt_args.setdefault(k, v) + + # timm has always defaulted momentum to 0.9 if optimizer supports momentum, keep for backward compat. + if opt_info.has_momentum: + opt_args.setdefault('momentum', momentum) + + # Remove commonly used kwargs that aren't always supported + if not opt_info.has_eps: + opt_args.pop('eps', None) + if not opt_info.has_betas: + opt_args.pop('betas', None) + + if foreach is not None: + # Explicitly activate or deactivate multi-tensor foreach impl. + # Not all optimizers support this, and those that do usually default to using + # multi-tensor impl if foreach is left as default 'None' and can be enabled. + opt_args.setdefault('foreach', foreach) + + # Create optimizer + opt_class = self.get_optimizer_class(opt_info, bind_defaults=False) + optimizer = opt_class(parameters, **opt_args) + + # Apply Lookahead if requested + if use_lookahead: + optimizer = Lookahead(optimizer) + + return optimizer + + +def _register_sgd_variants(registry: OptimizerRegistry) -> None: + """Register SGD-based optimizers""" + sgd_optimizers = [ + OptimInfo( + name='sgd', + opt_class=optim.SGD, + description='Stochastic Gradient Descent with Nesterov momentum (default)', + has_eps=False, + has_momentum=True, + defaults={'nesterov': True} + ), + OptimInfo( + name='momentum', + opt_class=optim.SGD, + description='Stochastic Gradient Descent with classical momentum', + has_eps=False, + has_momentum=True, + defaults={'nesterov': False} + ), + OptimInfo( + name='sgdp', + opt_class=SGDP, + description='SGD with built-in projection to unit norm sphere', + has_momentum=True, + defaults={'nesterov': True} + ), + OptimInfo( + name='sgdw', + opt_class=SGDW, + description='SGD with decoupled weight decay and Nesterov momentum', + has_eps=False, + has_momentum=True, + defaults={'nesterov': True} + ), + ] + for opt in sgd_optimizers: + registry.register(opt) + + +def _register_adam_variants(registry: OptimizerRegistry) -> None: + """Register Adam-based optimizers""" + adam_optimizers = [ + OptimInfo( + name='adam', + opt_class=optim.Adam, + description='torch.optim Adam (Adaptive Moment Estimation)', + has_betas=True + ), + OptimInfo( + name='adamw', + opt_class=optim.AdamW, + description='torch.optim Adam with decoupled weight decay regularization', + has_betas=True + ), + OptimInfo( + name='adamp', + opt_class=AdamP, + description='Adam with built-in projection to unit norm sphere', + has_betas=True, + defaults={'wd_ratio': 0.01, 'nesterov': True} + ), + OptimInfo( + name='nadam', + opt_class=Nadam, + description='Adam with Nesterov momentum', + has_betas=True + ), + OptimInfo( + name='nadamw', + opt_class=NAdamW, + description='Adam with Nesterov momentum and decoupled weight decay', + has_betas=True + ), + OptimInfo( + name='radam', + opt_class=RAdam, + description='Rectified Adam with variance 
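# --- Sketch of the name handling in create_optimizer() above: a 'lookahead_' prefix
# --- wraps the inner optimizer in Lookahead, has_momentum entries pick up the momentum
# --- argument (default 0.9), and has_eps=False entries silently drop an eps kwarg.
import torch.nn as nn
from timm.optim import create_optimizer_v2

model = nn.Linear(8, 8)

opt = create_optimizer_v2(model, 'lookahead_sgd', lr=0.1)
print(type(opt).__name__)                                    # Lookahead (wrapping Nesterov SGD)

opt = create_optimizer_v2(model, 'momentum', lr=0.1, momentum=0.8)   # classical, non-Nesterov SGD
print(opt.defaults['momentum'], opt.defaults['nesterov'])            # 0.8 False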
adaptation', + has_betas=True + ), + OptimInfo( + name='adamax', + opt_class=optim.Adamax, + description='torch.optim Adamax, Adam with infinity norm for more stable updates', + has_betas=True + ), + OptimInfo( + name='adafactor', + opt_class=Adafactor, + description='Memory-efficient implementation of Adam with factored gradients', + ), + OptimInfo( + name='adafactorbv', + opt_class=AdafactorBigVision, + description='Big Vision variant of Adafactor with factored gradients, half precision momentum.', + ), + OptimInfo( + name='adopt', + opt_class=Adopt, + description='Memory-efficient implementation of Adam with factored gradients', + ), + OptimInfo( + name='adoptw', + opt_class=Adopt, + description='Memory-efficient implementation of Adam with factored gradients', + defaults={'decoupled': True} + ), + ] + for opt in adam_optimizers: + registry.register(opt) + + +def _register_lamb_lars(registry: OptimizerRegistry) -> None: + """Register LAMB and LARS variants""" + lamb_lars_optimizers = [ + OptimInfo( + name='lamb', + opt_class=Lamb, + description='Layer-wise Adaptive Moments for batch optimization', + has_betas=True + ), + OptimInfo( + name='lambc', + opt_class=Lamb, + description='LAMB with trust ratio clipping for stability', + has_betas=True, + defaults={'trust_clip': True} + ), + OptimInfo( + name='lars', + opt_class=Lars, + description='Layer-wise Adaptive Rate Scaling', + has_momentum=True + ), + OptimInfo( + name='larc', + opt_class=Lars, + description='LARS with trust ratio clipping for stability', + has_momentum=True, + defaults={'trust_clip': True} + ), + OptimInfo( + name='nlars', + opt_class=Lars, + description='LARS with Nesterov momentum', + has_momentum=True, + defaults={'nesterov': True} + ), + OptimInfo( + name='nlarc', + opt_class=Lars, + description='LARS with Nesterov momentum & trust ratio clipping', + has_momentum=True, + defaults={'nesterov': True, 'trust_clip': True} + ), + ] + for opt in lamb_lars_optimizers: + registry.register(opt) + + +def _register_other_optimizers(registry: OptimizerRegistry) -> None: + """Register miscellaneous optimizers""" + other_optimizers = [ + OptimInfo( + name='adabelief', + opt_class=AdaBelief, + description='Adapts learning rate based on gradient prediction error', + has_betas=True, + defaults={'rectify': False} + ), + OptimInfo( + name='radabelief', + opt_class=AdaBelief, + description='Rectified AdaBelief with variance adaptation', + has_betas=True, + defaults={'rectify': True} + ), + OptimInfo( + name='adadelta', + opt_class=optim.Adadelta, + description='torch.optim Adadelta, Adapts learning rates based on running windows of gradients' + ), + OptimInfo( + name='adagrad', + opt_class=optim.Adagrad, + description='torch.optim Adagrad, Adapts learning rates using cumulative squared gradients', + defaults={'eps': 1e-8} + ), + OptimInfo( + name='adan', + opt_class=Adan, + description='Adaptive Nesterov Momentum Algorithm', + defaults={'no_prox': False}, + has_betas=True, + num_betas=3 + ), + OptimInfo( + name='adanw', + opt_class=Adan, + description='Adaptive Nesterov Momentum with decoupled weight decay', + defaults={'no_prox': True}, + has_betas=True, + num_betas=3 + ), + OptimInfo( + name='lion', + opt_class=Lion, + description='Evolved Sign Momentum optimizer for improved convergence', + has_eps=False, + has_betas=True + ), + OptimInfo( + name='madgrad', + opt_class=MADGRAD, + description='Momentum-based Adaptive gradient method', + has_momentum=True + ), + OptimInfo( + name='madgradw', + opt_class=MADGRAD, + description='MADGRAD 
with decoupled weight decay', + has_momentum=True, + defaults={'decoupled_decay': True} + ), + OptimInfo( + name='novograd', + opt_class=NvNovoGrad, + description='Normalized Adam with L2 norm gradient normalization', + has_betas=True + ), + OptimInfo( + name='rmsprop', + opt_class=optim.RMSprop, + description='torch.optim RMSprop, Root Mean Square Propagation', + has_momentum=True, + defaults={'alpha': 0.9} + ), + OptimInfo( + name='rmsproptf', + opt_class=RMSpropTF, + description='TensorFlow-style RMSprop implementation, Root Mean Square Propagation', + has_momentum=True, + defaults={'alpha': 0.9} + ), + ] + for opt in other_optimizers: + registry.register(opt) + registry.register_foreach_default('lion') + + +def _register_apex_optimizers(registry: OptimizerRegistry) -> None: + """Register APEX optimizers (lazy import)""" + apex_optimizers = [ + OptimInfo( + name='fusedsgd', + opt_class='apex.optimizers.FusedSGD', + description='NVIDIA APEX fused SGD implementation for faster training', + has_eps=False, + has_momentum=True, + defaults={'nesterov': True} + ), + OptimInfo( + name='fusedadam', + opt_class='apex.optimizers.FusedAdam', + description='NVIDIA APEX fused Adam implementation', + has_betas=True, + defaults={'adam_w_mode': False} + ), + OptimInfo( + name='fusedadamw', + opt_class='apex.optimizers.FusedAdam', + description='NVIDIA APEX fused AdamW implementation', + has_betas=True, + defaults={'adam_w_mode': True} + ), + OptimInfo( + name='fusedlamb', + opt_class='apex.optimizers.FusedLAMB', + description='NVIDIA APEX fused LAMB implementation', + has_betas=True + ), + OptimInfo( + name='fusednovograd', + opt_class='apex.optimizers.FusedNovoGrad', + description='NVIDIA APEX fused NovoGrad implementation', + has_betas=True, + defaults={'betas': (0.95, 0.98)} + ), + ] + for opt in apex_optimizers: + registry.register(opt) + + +def _register_bnb_optimizers(registry: OptimizerRegistry) -> None: + """Register bitsandbytes optimizers (lazy import)""" + bnb_optimizers = [ + OptimInfo( + name='bnbsgd', + opt_class='bitsandbytes.optim.SGD', + description='bitsandbytes SGD', + has_eps=False, + has_momentum=True, + defaults={'nesterov': True} + ), + OptimInfo( + name='bnbsgd8bit', + opt_class='bitsandbytes.optim.SGD8bit', + description='bitsandbytes 8-bit SGD with dynamic quantization', + has_eps=False, + has_momentum=True, + defaults={'nesterov': True} + ), + OptimInfo( + name='bnbadam', + opt_class='bitsandbytes.optim.Adam', + description='bitsandbytes Adam', + has_betas=True + ), + OptimInfo( + name='bnbadam8bit', + opt_class='bitsandbytes.optim.Adam', + description='bitsandbytes 8-bit Adam with dynamic quantization', + has_betas=True + ), + OptimInfo( + name='bnbadamw', + opt_class='bitsandbytes.optim.AdamW', + description='bitsandbytes AdamW', + has_betas=True + ), + OptimInfo( + name='bnbadamw8bit', + opt_class='bitsandbytes.optim.AdamW', + description='bitsandbytes 8-bit AdamW with dynamic quantization', + has_betas=True + ), + OptimInfo( + 'bnblion', + 'bitsandbytes.optim.Lion', + description='bitsandbytes Lion', + has_betas=True + ), + OptimInfo( + 'bnblion8bit', + 'bitsandbytes.optim.Lion8bit', + description='bitsandbytes 8-bit Lion with dynamic quantization', + has_betas=True + ), + OptimInfo( + 'bnbademamix', + 'bitsandbytes.optim.AdEMAMix', + description='bitsandbytes AdEMAMix', + has_betas=True, + num_betas=3, + ), + OptimInfo( + 'bnbademamix8bit', + 'bitsandbytes.optim.AdEMAMix8bit', + description='bitsandbytes 8-bit AdEMAMix with dynamic quantization', + has_betas=True, + 
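# --- Sketch of the lazy resolution for the apex / bitsandbytes entries registered above:
# --- their opt_class is a dotted string, so the import (and the CUDA check) only happens
# --- when the class is requested. This is also why the updated tests exclude 'fused*'
# --- and 'bnb*' names from the parametrized factory test. Output depends on environment.
from timm.optim import get_optimizer_class, list_optimizers

print(list_optimizers('bnb*'))        # names are listed even if bitsandbytes is not installed
try:
    get_optimizer_class('bnbadamw8bit')
except (ImportError, AssertionError) as err:     # missing package or no CUDA
    print('bitsandbytes optimizer unavailable:', err)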
num_betas=3, + ), + ] + for opt in bnb_optimizers: + registry.register(opt) + + +default_registry = OptimizerRegistry() + +def _register_default_optimizers() -> None: + """Register all default optimizers to the global registry.""" + # Register all optimizer groups + _register_sgd_variants(default_registry) + _register_adam_variants(default_registry) + _register_lamb_lars(default_registry) + _register_other_optimizers(default_registry) + _register_apex_optimizers(default_registry) + _register_bnb_optimizers(default_registry) + + # Register aliases + default_registry.register_alias('nesterov', 'sgd') + default_registry.register_alias('nesterovw', 'sgdw') + + +# Initialize default registry +_register_default_optimizers() + +# Public API + +def list_optimizers( + filter: str = '', + exclude_filters: Optional[List[str]] = None, + with_description: bool = False, +) -> List[Union[str, Tuple[str, str]]]: + """List available optimizer names, optionally filtered. + """ + return default_registry.list_optimizers(filter, exclude_filters, with_description) + + +def get_optimizer_class( + name: str, + bind_defaults: bool = False, +) -> Union[Type[optim.Optimizer], OptimizerCallable]: + """Get optimizer class by name with any defaults applied. + """ + return default_registry.get_optimizer_class(name, bind_defaults=bind_defaults) + + +def create_optimizer_v2( + model_or_params: Union[nn.Module, Params], + opt: str = 'sgd', + lr: Optional[float] = None, + weight_decay: float = 0., + momentum: float = 0.9, + foreach: Optional[bool] = None, + filter_bias_and_bn: bool = True, + layer_decay: Optional[float] = None, + param_group_fn: Optional[Callable[[nn.Module], Params]] = None, + **kwargs: Any, +) -> optim.Optimizer: + """Create an optimizer instance using the default registry.""" + return default_registry.create_optimizer( + model_or_params, + opt=opt, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + foreach=foreach, + weight_decay_exclude_1d=filter_bias_and_bn, + layer_decay=layer_decay, + param_group_fn=param_group_fn, + **kwargs + ) + + +def optimizer_kwargs(cfg): + """ cfg/argparse to kwargs helper + Convert optimizer args in argparse args or cfg like object to keyword args for updated create fn. + """ + kwargs = dict( + opt=cfg.opt, + lr=cfg.lr, + weight_decay=cfg.weight_decay, + momentum=cfg.momentum, + ) + if getattr(cfg, 'opt_eps', None) is not None: + kwargs['eps'] = cfg.opt_eps + if getattr(cfg, 'opt_betas', None) is not None: + kwargs['betas'] = cfg.opt_betas + if getattr(cfg, 'layer_decay', None) is not None: + kwargs['layer_decay'] = cfg.layer_decay + if getattr(cfg, 'opt_args', None) is not None: + kwargs.update(cfg.opt_args) + if getattr(cfg, 'opt_foreach', None) is not None: + kwargs['foreach'] = cfg.opt_foreach + return kwargs + + +def create_optimizer(args, model, filter_bias_and_bn=True): + """ Legacy optimizer factory for backwards compatibility. + NOTE: Use create_optimizer_v2 for new code. 
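# --- Sketch of the cfg -> kwargs helper defined above, using a SimpleNamespace standing
# --- in for argparse args. The four required attributes (opt / lr / weight_decay /
# --- momentum) are always forwarded; optional ones are forwarded only when present and
# --- not None. Values are arbitrary illustration values.
from types import SimpleNamespace

import torch.nn as nn
from timm.optim import create_optimizer_v2, optimizer_kwargs

cfg = SimpleNamespace(
    opt='adamw',
    lr=5e-4,
    weight_decay=0.05,
    momentum=0.9,
    opt_betas=(0.9, 0.95),
)
kwargs = optimizer_kwargs(cfg)
print(kwargs)   # {'opt': 'adamw', 'lr': 0.0005, 'weight_decay': 0.05, 'momentum': 0.9, 'betas': (0.9, 0.95)}

model = nn.Linear(4, 4)
opt = create_optimizer_v2(model, **kwargs)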
+ """ + return create_optimizer_v2( + model, + **optimizer_kwargs(cfg=args), + filter_bias_and_bn=filter_bias_and_bn, + ) + diff --git a/timm/optim/_param_groups.py b/timm/optim/_param_groups.py new file mode 100644 index 0000000000..ef9faacf96 --- /dev/null +++ b/timm/optim/_param_groups.py @@ -0,0 +1,129 @@ +import logging +from itertools import islice +from typing import Collection, Optional, Tuple + +from torch import nn as nn + +from timm.models import group_parameters + + +_logger = logging.getLogger(__name__) + + +def param_groups_weight_decay( + model: nn.Module, + weight_decay: float = 1e-5, + no_weight_decay_list: Collection[str] = (), +): + no_weight_decay_list = set(no_weight_decay_list) + decay = [] + no_decay = [] + for name, param in model.named_parameters(): + if not param.requires_grad: + continue + + if param.ndim <= 1 or name.endswith(".bias") or name in no_weight_decay_list: + no_decay.append(param) + else: + decay.append(param) + + return [ + {'params': no_decay, 'weight_decay': 0.}, + {'params': decay, 'weight_decay': weight_decay}] + + +def _group(it, size): + it = iter(it) + return iter(lambda: tuple(islice(it, size)), ()) + + +def _layer_map(model, layers_per_group=12, num_groups=None): + def _in_head(n, hp): + if not hp: + return True + elif isinstance(hp, (tuple, list)): + return any([n.startswith(hpi) for hpi in hp]) + else: + return n.startswith(hp) + + head_prefix = getattr(model, 'pretrained_cfg', {}).get('classifier', None) + names_trunk = [] + names_head = [] + for n, _ in model.named_parameters(): + names_head.append(n) if _in_head(n, head_prefix) else names_trunk.append(n) + + # group non-head layers + num_trunk_layers = len(names_trunk) + if num_groups is not None: + layers_per_group = -(num_trunk_layers // -num_groups) + names_trunk = list(_group(names_trunk, layers_per_group)) + + num_trunk_groups = len(names_trunk) + layer_map = {n: i for i, l in enumerate(names_trunk) for n in l} + layer_map.update({n: num_trunk_groups for n in names_head}) + return layer_map + + +def param_groups_layer_decay( + model: nn.Module, + weight_decay: float = 0.05, + no_weight_decay_list: Collection[str] = (), + weight_decay_exclude_1d: bool = True, + layer_decay: float = .75, + end_layer_decay: Optional[float] = None, + verbose: bool = False, +): + """ + Parameter groups for layer-wise lr decay & weight decay + Based on BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58 + """ + no_weight_decay_list = set(no_weight_decay_list) + param_group_names = {} # NOTE for debugging + param_groups = {} + + if hasattr(model, 'group_matcher'): + # FIXME interface needs more work + layer_map = group_parameters(model, model.group_matcher(coarse=False), reverse=True) + else: + # fallback + layer_map = _layer_map(model) + num_layers = max(layer_map.values()) + 1 + layer_max = num_layers - 1 + layer_scales = list(layer_decay ** (layer_max - i) for i in range(num_layers)) + + for name, param in model.named_parameters(): + if not param.requires_grad: + continue + + # no decay: all 1D parameters and model specific ones + if (weight_decay_exclude_1d and param.ndim <= 1) or name in no_weight_decay_list: + g_decay = "no_decay" + this_decay = 0. 
+ else: + g_decay = "decay" + this_decay = weight_decay + + layer_id = layer_map.get(name, layer_max) + group_name = "layer_%d_%s" % (layer_id, g_decay) + + if group_name not in param_groups: + this_scale = layer_scales[layer_id] + param_group_names[group_name] = { + "lr_scale": this_scale, + "weight_decay": this_decay, + "param_names": [], + } + param_groups[group_name] = { + "lr_scale": this_scale, + "weight_decay": this_decay, + "params": [], + } + + param_group_names[group_name]["param_names"].append(name) + param_groups[group_name]["params"].append(param) + + if verbose: + import json + _logger.info("parameter groups: \n%s" % json.dumps(param_group_names, indent=2)) + + return list(param_groups.values()) diff --git a/timm/optim/adabelief.py b/timm/optim/adabelief.py index 951d715cc0..dd2abb6ba9 100644 --- a/timm/optim/adabelief.py +++ b/timm/optim/adabelief.py @@ -40,9 +40,18 @@ class AdaBelief(Optimizer): """ def __init__( - self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16, weight_decay=0, amsgrad=False, - decoupled_decay=True, fixed_decay=False, rectify=True, degenerated_to_sgd=True): - + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-16, + weight_decay=0, + amsgrad=False, + decoupled_decay=True, + fixed_decay=False, + rectify=True, + degenerated_to_sgd=True, + ): if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= eps: @@ -58,9 +67,17 @@ def __init__( param['buffer'] = [[None, None, None] for _ in range(10)] defaults = dict( - lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad, - degenerated_to_sgd=degenerated_to_sgd, decoupled_decay=decoupled_decay, rectify=rectify, - fixed_decay=fixed_decay, buffer=[[None, None, None] for _ in range(10)]) + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + amsgrad=amsgrad, + degenerated_to_sgd=degenerated_to_sgd, + decoupled_decay=decoupled_decay, + rectify=rectify, + fixed_decay=fixed_decay, + buffer=[[None, None, None] for _ in range(10)] + ) super(AdaBelief, self).__init__(params, defaults) def __setstate__(self, state): diff --git a/timm/optim/adafactor.py b/timm/optim/adafactor.py index 4cefad1815..01c25ff2fb 100644 --- a/timm/optim/adafactor.py +++ b/timm/optim/adafactor.py @@ -16,6 +16,7 @@ class Adafactor(torch.optim.Optimizer): """Implements Adafactor algorithm. + This implementation is based on: `Adafactor: Adaptive Learning Rates with Sublinear Memory Cost` (see https://arxiv.org/abs/1804.04235) diff --git a/timm/optim/adahessian.py b/timm/optim/adahessian.py index 985c67ca68..9067cc66cf 100644 --- a/timm/optim/adahessian.py +++ b/timm/optim/adahessian.py @@ -23,8 +23,18 @@ class Adahessian(torch.optim.Optimizer): n_samples (int, optional): how many times to sample `z` for the approximation of the hessian trace (default: 1) """ - def __init__(self, params, lr=0.1, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0, - hessian_power=1.0, update_each=1, n_samples=1, avg_conv_kernel=False): + def __init__( + self, + params, + lr=0.1, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0.0, + hessian_power=1.0, + update_each=1, + n_samples=1, + avg_conv_kernel=False, + ): if not 0.0 <= lr: raise ValueError(f"Invalid learning rate: {lr}") if not 0.0 <= eps: @@ -44,7 +54,13 @@ def __init__(self, params, lr=0.1, betas=(0.9, 0.999), eps=1e-8, weight_decay=0. 
self.seed = 2147483647 self.generator = torch.Generator().manual_seed(self.seed) - defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, hessian_power=hessian_power) + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + hessian_power=hessian_power, + ) super(Adahessian, self).__init__(params, defaults) for p in self.get_params(): diff --git a/timm/optim/adamp.py b/timm/optim/adamp.py index ee187633ab..5a9ac3395d 100644 --- a/timm/optim/adamp.py +++ b/timm/optim/adamp.py @@ -41,11 +41,26 @@ def projection(p, grad, perturb, delta: float, wd_ratio: float, eps: float): class AdamP(Optimizer): - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, - weight_decay=0, delta=0.1, wd_ratio=0.1, nesterov=False): + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + delta=0.1, + wd_ratio=0.1, + nesterov=False, + ): defaults = dict( - lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, - delta=delta, wd_ratio=wd_ratio, nesterov=nesterov) + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + delta=delta, + wd_ratio=wd_ratio, + nesterov=nesterov, + ) super(AdamP, self).__init__(params, defaults) @torch.no_grad() diff --git a/timm/optim/adamw.py b/timm/optim/adamw.py index 66478bc6ef..b755a57ca8 100644 --- a/timm/optim/adamw.py +++ b/timm/optim/adamw.py @@ -36,8 +36,16 @@ class AdamW(Optimizer): https://openreview.net/forum?id=ryQu7f-RZ """ - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, - weight_decay=1e-2, amsgrad=False): + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2, + amsgrad=False, + ): + # NOTE: deprecated in favour of builtin torch.optim.AdamW if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= eps: @@ -46,8 +54,13 @@ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) if not 0.0 <= betas[1] < 1.0: raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) - defaults = dict(lr=lr, betas=betas, eps=eps, - weight_decay=weight_decay, amsgrad=amsgrad) + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + amsgrad=amsgrad, + ) super(AdamW, self).__init__(params, defaults) def __setstate__(self, state): diff --git a/timm/optim/lion.py b/timm/optim/lion.py index 4d8086424d..3bcb273cac 100644 --- a/timm/optim/lion.py +++ b/timm/optim/lion.py @@ -137,7 +137,7 @@ def lion( """ if foreach is None: # Placeholder for more complex foreach logic to be added when value is not set - foreach = False + foreach = True if foreach and torch.jit.is_scripting(): raise RuntimeError('torch.jit.script not supported with foreach optimizers') diff --git a/timm/optim/madgrad.py b/timm/optim/madgrad.py index a76713bf27..8e449dce3d 100644 --- a/timm/optim/madgrad.py +++ b/timm/optim/madgrad.py @@ -71,7 +71,12 @@ def __init__( raise ValueError(f"Eps must be non-negative") defaults = dict( - lr=lr, eps=eps, momentum=momentum, weight_decay=weight_decay, decoupled_decay=decoupled_decay) + lr=lr, + eps=eps, + momentum=momentum, + weight_decay=weight_decay, + decoupled_decay=decoupled_decay, + ) super().__init__(params, defaults) @property diff --git a/timm/optim/nadam.py b/timm/optim/nadam.py index 4e911420ef..892262cfac 100644 --- a/timm/optim/nadam.py +++ b/timm/optim/nadam.py @@ -27,8 +27,15 @@ class Nadam(Optimizer): NOTE: Has potential issues but does work well on some 
problems. """ - def __init__(self, params, lr=2e-3, betas=(0.9, 0.999), eps=1e-8, - weight_decay=0, schedule_decay=4e-3): + def __init__( + self, + params, + lr=2e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + schedule_decay=4e-3, + ): if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) defaults = dict( diff --git a/timm/optim/nvnovograd.py b/timm/optim/nvnovograd.py index fda3f4a620..068e5aa2c1 100644 --- a/timm/optim/nvnovograd.py +++ b/timm/optim/nvnovograd.py @@ -29,8 +29,16 @@ class NvNovoGrad(Optimizer): (default: False) """ - def __init__(self, params, lr=1e-3, betas=(0.95, 0.98), eps=1e-8, - weight_decay=0, grad_averaging=False, amsgrad=False): + def __init__( + self, + params, + lr=1e-3, + betas=(0.95, 0.98), + eps=1e-8, + weight_decay=0, + grad_averaging=False, + amsgrad=False, + ): if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= eps: @@ -39,10 +47,14 @@ def __init__(self, params, lr=1e-3, betas=(0.95, 0.98), eps=1e-8, raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) if not 0.0 <= betas[1] < 1.0: raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) - defaults = dict(lr=lr, betas=betas, eps=eps, - weight_decay=weight_decay, - grad_averaging=grad_averaging, - amsgrad=amsgrad) + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + grad_averaging=grad_averaging, + amsgrad=amsgrad, + ) super(NvNovoGrad, self).__init__(params, defaults) diff --git a/timm/optim/optim_factory.py b/timm/optim/optim_factory.py deleted file mode 100644 index 31b40bf680..0000000000 --- a/timm/optim/optim_factory.py +++ /dev/null @@ -1,431 +0,0 @@ -""" Optimizer Factory w/ Custom Weight Decay -Hacked together by / Copyright 2021 Ross Wightman -""" -import logging -from itertools import islice -from typing import Optional, Callable, Tuple - -import torch -import torch.nn as nn -import torch.optim as optim - -from timm.models import group_parameters -from . 
import AdafactorBigVision - -from .adabelief import AdaBelief -from .adafactor import Adafactor -from .adahessian import Adahessian -from .adamp import AdamP -from .adan import Adan -from .adopt import Adopt -from .lamb import Lamb -from .lars import Lars -from .lion import Lion -from .lookahead import Lookahead -from .madgrad import MADGRAD -from .nadam import Nadam -from .nadamw import NAdamW -from .nvnovograd import NvNovoGrad -from .radam import RAdam -from .rmsprop_tf import RMSpropTF -from .sgdp import SGDP -from .sgdw import SGDW - - -_logger = logging.getLogger(__name__) - - -# optimizers to default to multi-tensor -_DEFAULT_FOREACH = { - 'lion', -} - - -def param_groups_weight_decay( - model: nn.Module, - weight_decay=1e-5, - no_weight_decay_list=() -): - no_weight_decay_list = set(no_weight_decay_list) - decay = [] - no_decay = [] - for name, param in model.named_parameters(): - if not param.requires_grad: - continue - - if param.ndim <= 1 or name.endswith(".bias") or name in no_weight_decay_list: - no_decay.append(param) - else: - decay.append(param) - - return [ - {'params': no_decay, 'weight_decay': 0.}, - {'params': decay, 'weight_decay': weight_decay}] - - -def _group(it, size): - it = iter(it) - return iter(lambda: tuple(islice(it, size)), ()) - - -def _layer_map(model, layers_per_group=12, num_groups=None): - def _in_head(n, hp): - if not hp: - return True - elif isinstance(hp, (tuple, list)): - return any([n.startswith(hpi) for hpi in hp]) - else: - return n.startswith(hp) - - head_prefix = getattr(model, 'pretrained_cfg', {}).get('classifier', None) - names_trunk = [] - names_head = [] - for n, _ in model.named_parameters(): - names_head.append(n) if _in_head(n, head_prefix) else names_trunk.append(n) - - # group non-head layers - num_trunk_layers = len(names_trunk) - if num_groups is not None: - layers_per_group = -(num_trunk_layers // -num_groups) - names_trunk = list(_group(names_trunk, layers_per_group)) - - num_trunk_groups = len(names_trunk) - layer_map = {n: i for i, l in enumerate(names_trunk) for n in l} - layer_map.update({n: num_trunk_groups for n in names_head}) - return layer_map - - -def param_groups_layer_decay( - model: nn.Module, - weight_decay: float = 0.05, - no_weight_decay_list: Tuple[str] = (), - layer_decay: float = .75, - end_layer_decay: Optional[float] = None, - verbose: bool = False, -): - """ - Parameter groups for layer-wise lr decay & weight decay - Based on BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58 - """ - no_weight_decay_list = set(no_weight_decay_list) - param_group_names = {} # NOTE for debugging - param_groups = {} - - if hasattr(model, 'group_matcher'): - # FIXME interface needs more work - layer_map = group_parameters(model, model.group_matcher(coarse=False), reverse=True) - else: - # fallback - layer_map = _layer_map(model) - num_layers = max(layer_map.values()) + 1 - layer_max = num_layers - 1 - layer_scales = list(layer_decay ** (layer_max - i) for i in range(num_layers)) - - for name, param in model.named_parameters(): - if not param.requires_grad: - continue - - # no decay: all 1D parameters and model specific ones - if param.ndim == 1 or name in no_weight_decay_list: - g_decay = "no_decay" - this_decay = 0. 
- else: - g_decay = "decay" - this_decay = weight_decay - - layer_id = layer_map.get(name, layer_max) - group_name = "layer_%d_%s" % (layer_id, g_decay) - - if group_name not in param_groups: - this_scale = layer_scales[layer_id] - param_group_names[group_name] = { - "lr_scale": this_scale, - "weight_decay": this_decay, - "param_names": [], - } - param_groups[group_name] = { - "lr_scale": this_scale, - "weight_decay": this_decay, - "params": [], - } - - param_group_names[group_name]["param_names"].append(name) - param_groups[group_name]["params"].append(param) - - if verbose: - import json - _logger.info("parameter groups: \n%s" % json.dumps(param_group_names, indent=2)) - - return list(param_groups.values()) - - -def optimizer_kwargs(cfg): - """ cfg/argparse to kwargs helper - Convert optimizer args in argparse args or cfg like object to keyword args for updated create fn. - """ - kwargs = dict( - opt=cfg.opt, - lr=cfg.lr, - weight_decay=cfg.weight_decay, - momentum=cfg.momentum, - ) - if getattr(cfg, 'opt_eps', None) is not None: - kwargs['eps'] = cfg.opt_eps - if getattr(cfg, 'opt_betas', None) is not None: - kwargs['betas'] = cfg.opt_betas - if getattr(cfg, 'layer_decay', None) is not None: - kwargs['layer_decay'] = cfg.layer_decay - if getattr(cfg, 'opt_args', None) is not None: - kwargs.update(cfg.opt_args) - if getattr(cfg, 'opt_foreach', None) is not None: - kwargs['foreach'] = cfg.opt_foreach - return kwargs - - -def create_optimizer(args, model, filter_bias_and_bn=True): - """ Legacy optimizer factory for backwards compatibility. - NOTE: Use create_optimizer_v2 for new code. - """ - return create_optimizer_v2( - model, - **optimizer_kwargs(cfg=args), - filter_bias_and_bn=filter_bias_and_bn, - ) - - -def create_optimizer_v2( - model_or_params, - opt: str = 'sgd', - lr: Optional[float] = None, - weight_decay: float = 0., - momentum: float = 0.9, - foreach: Optional[bool] = None, - filter_bias_and_bn: bool = True, - layer_decay: Optional[float] = None, - param_group_fn: Optional[Callable] = None, - **kwargs, -): - """ Create an optimizer. - - TODO currently the model is passed in and all parameters are selected for optimization. - For more general use an interface that allows selection of parameters to optimize and lr groups, one of: - * a filter fn interface that further breaks params into groups in a weight_decay compatible fashion - * expose the parameters interface and leave it up to caller - - Args: - model_or_params (nn.Module): model containing parameters to optimize - opt: name of optimizer to create - lr: initial learning rate - weight_decay: weight decay to apply in optimizer - momentum: momentum for momentum based optimizers (others may use betas via kwargs) - foreach: Enable / disable foreach (multi-tensor) operation if True / False. Choose safe default if None - filter_bias_and_bn: filter out bias, bn and other 1d params from weight decay - **kwargs: extra optimizer specific kwargs to pass through - - Returns: - Optimizer - """ - if isinstance(model_or_params, nn.Module): - # a model was passed in, extract parameters and add weight decays to appropriate layers - no_weight_decay = {} - if hasattr(model_or_params, 'no_weight_decay'): - no_weight_decay = model_or_params.no_weight_decay() - - if param_group_fn: - parameters = param_group_fn(model_or_params) - elif layer_decay is not None: - parameters = param_groups_layer_decay( - model_or_params, - weight_decay=weight_decay, - layer_decay=layer_decay, - no_weight_decay_list=no_weight_decay, - ) - weight_decay = 0. 
- elif weight_decay and filter_bias_and_bn: - parameters = param_groups_weight_decay(model_or_params, weight_decay, no_weight_decay) - weight_decay = 0. - else: - parameters = model_or_params.parameters() - else: - # iterable of parameters or param groups passed in - parameters = model_or_params - - opt_lower = opt.lower() - opt_split = opt_lower.split('_') - opt_lower = opt_split[-1] - - if opt_lower.startswith('fused'): - try: - from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD - has_apex = True - except ImportError: - has_apex = False - assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers' - - if opt_lower.startswith('bnb'): - try: - import bitsandbytes as bnb - has_bnb = True - except ImportError: - has_bnb = False - assert has_bnb and torch.cuda.is_available(), 'bitsandbytes and CUDA required for bnb optimizers' - - opt_args = dict(weight_decay=weight_decay, **kwargs) - - if lr is not None: - opt_args.setdefault('lr', lr) - - if foreach is None: - if opt in _DEFAULT_FOREACH: - opt_args.setdefault('foreach', True) - else: - opt_args['foreach'] = foreach - - # basic SGD & related - if opt_lower == 'sgd' or opt_lower == 'nesterov': - # NOTE 'sgd' refers to SGD + nesterov momentum for legacy / backwards compat reasons - opt_args.pop('eps', None) - optimizer = optim.SGD(parameters, momentum=momentum, nesterov=True, **opt_args) - elif opt_lower == 'momentum': - opt_args.pop('eps', None) - optimizer = optim.SGD(parameters, momentum=momentum, nesterov=False, **opt_args) - elif opt_lower == 'sgdp': - optimizer = SGDP(parameters, momentum=momentum, nesterov=True, **opt_args) - elif opt_lower == 'sgdw' or opt_lower == 'nesterovw': - # NOTE 'sgd' refers to SGD + nesterov momentum for legacy / backwards compat reasons - opt_args.pop('eps', None) - optimizer = SGDW(parameters, momentum=momentum, nesterov=True, **opt_args) - elif opt_lower == 'momentumw': - opt_args.pop('eps', None) - optimizer = SGDW(parameters, momentum=momentum, nesterov=False, **opt_args) - - # adaptive - elif opt_lower == 'adam': - optimizer = optim.Adam(parameters, **opt_args) - elif opt_lower == 'adamw': - optimizer = optim.AdamW(parameters, **opt_args) - elif opt_lower == 'adamp': - optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args) - elif opt_lower == 'nadam': - try: - # NOTE PyTorch >= 1.10 should have native NAdam - optimizer = optim.Nadam(parameters, **opt_args) - except AttributeError: - optimizer = Nadam(parameters, **opt_args) - elif opt_lower == 'nadamw': - optimizer = NAdamW(parameters, **opt_args) - elif opt_lower == 'radam': - optimizer = RAdam(parameters, **opt_args) - elif opt_lower == 'adamax': - optimizer = optim.Adamax(parameters, **opt_args) - elif opt_lower == 'adabelief': - optimizer = AdaBelief(parameters, rectify=False, **opt_args) - elif opt_lower == 'radabelief': - optimizer = AdaBelief(parameters, rectify=True, **opt_args) - elif opt_lower == 'adadelta': - optimizer = optim.Adadelta(parameters, **opt_args) - elif opt_lower == 'adagrad': - opt_args.setdefault('eps', 1e-8) - optimizer = optim.Adagrad(parameters, **opt_args) - elif opt_lower == 'adafactor': - optimizer = Adafactor(parameters, **opt_args) - elif opt_lower == 'adanp': - optimizer = Adan(parameters, no_prox=False, **opt_args) - elif opt_lower == 'adanw': - optimizer = Adan(parameters, no_prox=True, **opt_args) - elif opt_lower == 'lamb': - optimizer = Lamb(parameters, **opt_args) - elif opt_lower == 'lambc': - optimizer = Lamb(parameters, trust_clip=True, 
-    elif opt_lower == 'larc':
-        optimizer = Lars(parameters, momentum=momentum, trust_clip=True, **opt_args)
-    elif opt_lower == 'lars':
-        optimizer = Lars(parameters, momentum=momentum, **opt_args)
-    elif opt_lower == 'nlarc':
-        optimizer = Lars(parameters, momentum=momentum, trust_clip=True, nesterov=True, **opt_args)
-    elif opt_lower == 'nlars':
-        optimizer = Lars(parameters, momentum=momentum, nesterov=True, **opt_args)
-    elif opt_lower == 'madgrad':
-        optimizer = MADGRAD(parameters, momentum=momentum, **opt_args)
-    elif opt_lower == 'madgradw':
-        optimizer = MADGRAD(parameters, momentum=momentum, decoupled_decay=True, **opt_args)
-    elif opt_lower == 'novograd' or opt_lower == 'nvnovograd':
-        optimizer = NvNovoGrad(parameters, **opt_args)
-    elif opt_lower == 'rmsprop':
-        optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=momentum, **opt_args)
-    elif opt_lower == 'rmsproptf':
-        optimizer = RMSpropTF(parameters, alpha=0.9, momentum=momentum, **opt_args)
-    elif opt_lower == 'lion':
-        opt_args.pop('eps', None)
-        optimizer = Lion(parameters, **opt_args)
-    elif opt_lower == 'adafactorbv':
-        optimizer = AdafactorBigVision(parameters, **opt_args)
-    elif opt_lower == 'adopt':
-        optimizer = Adopt(parameters, **opt_args)
-    elif opt_lower == 'adoptw':
-        optimizer = Adopt(parameters, decoupled=True, **opt_args)
-
-    # second order
-    elif opt_lower == 'adahessian':
-        optimizer = Adahessian(parameters, **opt_args)
-
-    # NVIDIA fused optimizers, require APEX to be installed
-    elif opt_lower == 'fusedsgd':
-        opt_args.pop('eps', None)
-        optimizer = FusedSGD(parameters, momentum=momentum, nesterov=True, **opt_args)
-    elif opt_lower == 'fusedmomentum':
-        opt_args.pop('eps', None)
-        optimizer = FusedSGD(parameters, momentum=momentum, nesterov=False, **opt_args)
-    elif opt_lower == 'fusedadam':
-        optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
-    elif opt_lower == 'fusedadamw':
-        optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args)
-    elif opt_lower == 'fusedlamb':
-        optimizer = FusedLAMB(parameters, **opt_args)
-    elif opt_lower == 'fusednovograd':
-        opt_args.setdefault('betas', (0.95, 0.98))
-        optimizer = FusedNovoGrad(parameters, **opt_args)
-
-    # bitsandbytes optimizers, require bitsandbytes to be installed
-    elif opt_lower == 'bnbsgd':
-        opt_args.pop('eps', None)
-        optimizer = bnb.optim.SGD(parameters, momentum=momentum, nesterov=True, **opt_args)
-    elif opt_lower == 'bnbsgd8bit':
-        opt_args.pop('eps', None)
-        optimizer = bnb.optim.SGD8bit(parameters, momentum=momentum, nesterov=True, **opt_args)
-    elif opt_lower == 'bnbmomentum':
-        opt_args.pop('eps', None)
-        optimizer = bnb.optim.SGD(parameters, momentum=momentum, **opt_args)
-    elif opt_lower == 'bnbmomentum8bit':
-        opt_args.pop('eps', None)
-        optimizer = bnb.optim.SGD8bit(parameters, momentum=momentum, **opt_args)
-    elif opt_lower == 'bnbadam':
-        optimizer = bnb.optim.Adam(parameters, **opt_args)
-    elif opt_lower == 'bnbadam8bit':
-        optimizer = bnb.optim.Adam8bit(parameters, **opt_args)
-    elif opt_lower == 'bnbadamw':
-        optimizer = bnb.optim.AdamW(parameters, **opt_args)
-    elif opt_lower == 'bnbadamw8bit':
-        optimizer = bnb.optim.AdamW8bit(parameters, **opt_args)
-    elif opt_lower == 'bnblamb':
-        optimizer = bnb.optim.LAMB(parameters, **opt_args)
-    elif opt_lower == 'bnblamb8bit':
-        optimizer = bnb.optim.LAMB8bit(parameters, **opt_args)
-    elif opt_lower == 'bnblars':
-        optimizer = bnb.optim.LARS(parameters, **opt_args)
-    elif opt_lower == 'bnblarsb8bit':
-        optimizer = bnb.optim.LAMB8bit(parameters, **opt_args)
-    elif opt_lower == 'bnblion':
-        optimizer = bnb.optim.Lion(parameters, **opt_args)
-    elif opt_lower == 'bnblion8bit':
-        optimizer = bnb.optim.Lion8bit(parameters, **opt_args)
-
-    else:
-        assert False and "Invalid optimizer"
-        raise ValueError
-
-    if len(opt_split) > 1:
-        if opt_split[0] == 'lookahead':
-            optimizer = Lookahead(optimizer)
-
-    return optimizer
diff --git a/timm/optim/radam.py b/timm/optim/radam.py
index eb8d22e06c..d6f8d30a67 100644
--- a/timm/optim/radam.py
+++ b/timm/optim/radam.py
@@ -9,10 +9,21 @@
 class RAdam(Optimizer):
-    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
+    def __init__(
+            self,
+            params,
+            lr=1e-3,
+            betas=(0.9, 0.999),
+            eps=1e-8,
+            weight_decay=0,
+    ):
         defaults = dict(
-            lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
-            buffer=[[None, None, None] for _ in range(10)])
+            lr=lr,
+            betas=betas,
+            eps=eps,
+            weight_decay=weight_decay,
+            buffer=[[None, None, None] for _ in range(10)]
+        )
         super(RAdam, self).__init__(params, defaults)

     def __setstate__(self, state):
diff --git a/timm/optim/rmsprop_tf.py b/timm/optim/rmsprop_tf.py
index 0817887db3..8511b3b482 100644
--- a/timm/optim/rmsprop_tf.py
+++ b/timm/optim/rmsprop_tf.py
@@ -45,8 +45,18 @@ class RMSpropTF(Optimizer):
     """

-    def __init__(self, params, lr=1e-2, alpha=0.9, eps=1e-10, weight_decay=0, momentum=0., centered=False,
-                 decoupled_decay=False, lr_in_momentum=True):
+    def __init__(
+            self,
+            params,
+            lr=1e-2,
+            alpha=0.9,
+            eps=1e-10,
+            weight_decay=0,
+            momentum=0.,
+            centered=False,
+            decoupled_decay=False,
+            lr_in_momentum=True,
+    ):
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
         if not 0.0 <= eps:
@@ -59,8 +69,15 @@ def __init__(self, params, lr=1e-2, alpha=0.9, eps=1e-10, weight_decay=0, moment
             raise ValueError("Invalid alpha value: {}".format(alpha))

         defaults = dict(
-            lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay,
-            decoupled_decay=decoupled_decay, lr_in_momentum=lr_in_momentum)
+            lr=lr,
+            momentum=momentum,
+            alpha=alpha,
+            eps=eps,
+            centered=centered,
+            weight_decay=weight_decay,
+            decoupled_decay=decoupled_decay,
+            lr_in_momentum=lr_in_momentum,
+        )
         super(RMSpropTF, self).__init__(params, defaults)

     def __setstate__(self, state):
diff --git a/timm/optim/sgdp.py b/timm/optim/sgdp.py
index baf05fa55c..87b89f6f0b 100644
--- a/timm/optim/sgdp.py
+++ b/timm/optim/sgdp.py
@@ -17,11 +17,28 @@
 class SGDP(Optimizer):
-    def __init__(self, params, lr=required, momentum=0, dampening=0,
-                 weight_decay=0, nesterov=False, eps=1e-8, delta=0.1, wd_ratio=0.1):
+    def __init__(
+            self,
+            params,
+            lr=required,
+            momentum=0,
+            dampening=0,
+            weight_decay=0,
+            nesterov=False,
+            eps=1e-8,
+            delta=0.1,
+            wd_ratio=0.1
+    ):
         defaults = dict(
-            lr=lr, momentum=momentum, dampening=dampening, weight_decay=weight_decay,
-            nesterov=nesterov, eps=eps, delta=delta, wd_ratio=wd_ratio)
+            lr=lr,
+            momentum=momentum,
+            dampening=dampening,
+            weight_decay=weight_decay,
+            nesterov=nesterov,
+            eps=eps,
+            delta=delta,
+            wd_ratio=wd_ratio,
+        )
         super(SGDP, self).__init__(params, defaults)

     @torch.no_grad()
diff --git a/timm/optim/sgdw.py b/timm/optim/sgdw.py
index b3d2c12f03..c5b44063d6 100644
--- a/timm/optim/sgdw.py
+++ b/timm/optim/sgdw.py
@@ -35,10 +35,15 @@ def __init__(
             raise ValueError(f"Invalid weight_decay value: {weight_decay}")

         defaults = dict(
-            lr=lr, momentum=momentum, dampening=dampening,
-            weight_decay=weight_decay, nesterov=nesterov,
-            maximize=maximize, foreach=foreach,
-            differentiable=differentiable)
+            lr=lr,
+            momentum=momentum,
+            dampening=dampening,
+            weight_decay=weight_decay,
+            nesterov=nesterov,
+            maximize=maximize,
+            foreach=foreach,
+            differentiable=differentiable,
+        )
         if nesterov and (momentum <= 0 or dampening != 0):
             raise ValueError("Nesterov momentum requires a momentum and zero dampening")
         super().__init__(params, defaults)
diff --git a/train.py b/train.py
index a82fa0a8c3..4302654fb5 100755
--- a/train.py
+++ b/train.py
@@ -15,6 +15,7 @@ Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
 """
 import argparse
+import copy
 import importlib
 import json
 import logging
@@ -554,6 +555,13 @@ def main():
         **optimizer_kwargs(cfg=args),
         **args.opt_kwargs,
     )
+    if utils.is_primary(args):
+        defaults = copy.deepcopy(optimizer.defaults)
+        defaults['weight_decay'] = args.weight_decay  # this isn't stored in optimizer.defaults
+        defaults = ', '.join([f'{k}: {v}' for k, v in defaults.items()])
+        logging.info(
+            f'Created {type(optimizer).__name__} ({args.opt}) optimizer: {defaults}'
+        )

     # setup automatic mixed-precision (AMP) loss scaling and op casting
     amp_autocast = suppress  # do nothing
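
For illustration only (not part of the diff): a minimal sketch of the same factory path the new train.py logging exercises, outside of the training script. The SimpleNamespace fields are hypothetical stand-ins for the parsed argparse args, and the model/optimizer choices ('resnet18', 'adamw') are arbitrary examples.

import copy
import logging
from types import SimpleNamespace

import timm
from timm.optim import create_optimizer_v2, optimizer_kwargs

logging.basicConfig(level=logging.INFO)

# illustrative stand-in for the argparse namespace used by train.py
args = SimpleNamespace(opt='adamw', lr=1e-3, weight_decay=0.05, momentum=0.9)

model = timm.create_model('resnet18')
optimizer = create_optimizer_v2(model, **optimizer_kwargs(cfg=args))

# the factory moves weight decay into per-parameter groups, so optimizer.defaults
# does not reflect the requested value; merge it back in for a readable summary line
defaults = copy.deepcopy(optimizer.defaults)
defaults['weight_decay'] = args.weight_decay
defaults = ', '.join(f'{k}: {v}' for k, v in defaults.items())
logging.info(f'Created {type(optimizer).__name__} ({args.opt}) optimizer: {defaults}')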