diff --git a/timm/optim/_optim_factory.py b/timm/optim/_optim_factory.py index 97cbfd22f..9e06e1a50 100644 --- a/timm/optim/_optim_factory.py +++ b/timm/optim/_optim_factory.py @@ -270,9 +270,9 @@ def create_optimizer( if param_group_fn: # run custom fn to generate param groups from nn.Module - parameters = param_group_fn(model_or_params) + params = param_group_fn(model_or_params) elif layer_decay is not None: - parameters = param_groups_layer_decay( + params = param_groups_layer_decay( model_or_params, weight_decay=weight_decay, layer_decay=layer_decay, @@ -281,17 +281,17 @@ def create_optimizer( ) weight_decay = 0. elif weight_decay and weight_decay_exclude_1d: - parameters = param_groups_weight_decay( + params = param_groups_weight_decay( model_or_params, weight_decay=weight_decay, no_weight_decay_list=no_weight_decay, ) weight_decay = 0. else: - parameters = model_or_params.parameters() + params = model_or_params.parameters() else: # pass parameters / parameter groups through to optimizer - parameters = model_or_params + params = model_or_params # Parse optimizer name opt_split = opt.lower().split('_') @@ -330,7 +330,7 @@ def create_optimizer( # Create optimizer opt_class = self.get_optimizer_class(opt_info, bind_defaults=False) - optimizer = opt_class(parameters, **opt_args) + optimizer = opt_class(params, **opt_args) # Apply Lookahead if requested if use_lookahead: @@ -685,12 +685,14 @@ def _register_bnb_optimizers(registry: OptimizerRegistry) -> None: 'bnblion', 'bitsandbytes.optim.Lion', description='bitsandbytes Lion', + has_eps=False, has_betas=True ), OptimInfo( 'bnblion8bit', 'bitsandbytes.optim.Lion8bit', description='bitsandbytes 8-bit Lion with dynamic quantization', + has_eps=False, has_betas=True ), OptimInfo(