From 75d676e7ea2b3640c9ce90778ab12e1f8ab7382b Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Tue, 12 Nov 2024 16:45:01 -0800
Subject: [PATCH] Try to fix documentation build, add better docstrings to public optimizer api

---
 hfdocs/source/reference/optimizers.mdx |  10 +-
 timm/optim/_optim_factory.py           | 146 ++++++++++++++++++++++++-
 2 files changed, 149 insertions(+), 7 deletions(-)

diff --git a/hfdocs/source/reference/optimizers.mdx b/hfdocs/source/reference/optimizers.mdx
index 637e7f0a74..212152fb95 100644
--- a/hfdocs/source/reference/optimizers.mdx
+++ b/hfdocs/source/reference/optimizers.mdx
@@ -6,22 +6,28 @@ This page contains the API reference documentation for learning rate optimizers
 
 ### Factory functions
 
-[[autodoc]] timm.optim.optim_factory.create_optimizer
-[[autodoc]] timm.optim.optim_factory.create_optimizer_v2
+[[autodoc]] timm.optim.create_optimizer_v2
+[[autodoc]] timm.optim.list_optimizers
+[[autodoc]] timm.optim.get_optimizer_class
 
 ### Optimizer Classes
 
 [[autodoc]] timm.optim.adabelief.AdaBelief
 [[autodoc]] timm.optim.adafactor.Adafactor
+[[autodoc]] timm.optim.adafactor_bv.AdafactorBigVision
 [[autodoc]] timm.optim.adahessian.Adahessian
 [[autodoc]] timm.optim.adamp.AdamP
 [[autodoc]] timm.optim.adamw.AdamW
+[[autodoc]] timm.optim.adopt.Adopt
 [[autodoc]] timm.optim.lamb.Lamb
 [[autodoc]] timm.optim.lars.Lars
+[[autodoc]] timm.optim.lion.Lion
 [[autodoc]] timm.optim.lookahead.Lookahead
 [[autodoc]] timm.optim.madgrad.MADGRAD
 [[autodoc]] timm.optim.nadam.Nadam
+[[autodoc]] timm.optim.nadamw.NadamW
 [[autodoc]] timm.optim.nvnovograd.NvNovoGrad
 [[autodoc]] timm.optim.radam.RAdam
 [[autodoc]] timm.optim.rmsprop_tf.RMSpropTF
 [[autodoc]] timm.optim.sgdp.SGDP
+[[autodoc]] timm.optim.sgdw.SGDW
\ No newline at end of file
diff --git a/timm/optim/_optim_factory.py b/timm/optim/_optim_factory.py
index c38363f892..0ea20b6e69 100644
--- a/timm/optim/_optim_factory.py
+++ b/timm/optim/_optim_factory.py
@@ -124,7 +124,7 @@ def register_foreach_default(self, name: str) -> None:
 
     def list_optimizers(
             self,
-            filter: str = '',
+            filter: Union[str, List[str]] = '',
             exclude_filters: Optional[List[str]] = None,
             with_description: bool = False
     ) -> List[Union[str, Tuple[str, str]]]:
@@ -141,7 +141,14 @@ def list_optimizers(
         names = sorted(self._optimizers.keys())
 
         if filter:
-            names = [n for n in names if fnmatch(n, filter)]
+            if isinstance(filter, str):
+                filters = [filter]
+            else:
+                filters = filter
+            filtered_names = set()
+            for f in filters:
+                filtered_names.update(n for n in names if fnmatch(n, f))
+            names = sorted(filtered_names)
 
         if exclude_filters:
             for exclude_filter in exclude_filters:
@@ -149,6 +156,7 @@ def list_optimizers(
 
         if with_description:
             return [(name, self._optimizers[name].description) for name in names]
+
         return names
 
     def get_optimizer_info(self, name: str) -> OptimInfo:
@@ -718,11 +726,46 @@ def _register_default_optimizers() -> None:
 # Public API
 
 def list_optimizers(
-        filter: str = '',
+        filter: Union[str, List[str]] = '',
         exclude_filters: Optional[List[str]] = None,
         with_description: bool = False,
 ) -> List[Union[str, Tuple[str, str]]]:
     """List available optimizer names, optionally filtered.
+
+    List all registered optimizers, with optional filtering using wildcard patterns.
+    Optimizers can be filtered using include and exclude patterns, and can optionally
+    return descriptions with each optimizer name.
+
+    Args:
+        filter: Wildcard style filter string or list of filter strings
+            (e.g., 'adam*' for all Adam variants, or ['adam*', '*8bit'] for
+            Adam variants and 8-bit optimizers). Empty string means no filtering.
+        exclude_filters: Optional list of wildcard patterns to exclude. For example,
+            ['*8bit', 'fused*'] would exclude 8-bit and fused implementations.
+        with_description: If True, returns tuples of (name, description) instead of
+            just names. Descriptions provide brief explanations of optimizer characteristics.
+
+    Returns:
+        If with_description is False:
+            List of optimizer names as strings (e.g., ['adam', 'adamw', ...])
+        If with_description is True:
+            List of tuples of (name, description) (e.g., [('adam', 'Adaptive Moment...'), ...])
+
+    Examples:
+        >>> list_optimizers()
+        ['adam', 'adamw', 'sgd', ...]
+
+        >>> list_optimizers(['la*', 'nla*'])  # List lamb & lars
+        ['lamb', 'lambc', 'larc', 'lars', 'nlarc', 'nlars']
+
+        >>> list_optimizers('*adam*', exclude_filters=['bnb*', 'fused*'])  # Exclude bnb & apex adam optimizers
+        ['adam', 'adamax', 'adamp', 'adamw', 'nadam', 'nadamw', 'radam']
+
+        >>> list_optimizers(with_description=True)  # Get descriptions
+        [('adabelief', 'Adapts learning rate based on gradient prediction error'),
+        ('adadelta', 'torch.optim Adadelta, Adapts learning rates based on running windows of gradients'),
+        ('adafactor', 'Memory-efficient implementation of Adam with factored gradients'),
+        ...]
     """
     return default_registry.list_optimizers(filter, exclude_filters, with_description)
 
@@ -731,7 +774,38 @@ def get_optimizer_class(
         name: str,
         bind_defaults: bool = False,
 ) -> Union[Type[optim.Optimizer], OptimizerCallable]:
-    """Get optimizer class by name with any defaults applied.
+    """Get optimizer class by name with option to bind default arguments.
+
+    Retrieves the optimizer class or a partial function with default arguments bound.
+    This allows direct instantiation of optimizers with their default configurations
+    without going through the full factory.
+
+    Args:
+        name: Name of the optimizer to retrieve (e.g., 'adam', 'sgd')
+        bind_defaults: If True, returns a partial function with default arguments from OptimInfo bound.
+            If False, returns the raw optimizer class.
+
+    Returns:
+        If bind_defaults is False:
+            The optimizer class (e.g., torch.optim.Adam)
+        If bind_defaults is True:
+            A partial function with default arguments bound
+
+    Raises:
+        ValueError: If optimizer name is not found in registry
+
+    Examples:
+        >>> # Get raw optimizer class
+        >>> Adam = get_optimizer_class('adam')
+        >>> opt = Adam(model.parameters(), lr=1e-3)
+
+        >>> # Get optimizer with defaults bound
+        >>> AdamWithDefaults = get_optimizer_class('adam', bind_defaults=True)
+        >>> opt = AdamWithDefaults(model.parameters(), lr=1e-3)
+
+        >>> # Get SGD with nesterov momentum default
+        >>> SGD = get_optimizer_class('sgd', bind_defaults=True)  # nesterov=True bound
+        >>> opt = SGD(model.parameters(), lr=0.1, momentum=0.9)
     """
     return default_registry.get_optimizer_class(name, bind_defaults=bind_defaults)
 
@@ -748,7 +822,69 @@ def create_optimizer_v2(
         param_group_fn: Optional[Callable[[nn.Module], Params]] = None,
         **kwargs: Any,
 ) -> optim.Optimizer:
-    """Create an optimizer instance using the default registry."""
+    """Create an optimizer instance via timm registry.
+
+    Creates and configures an optimizer with appropriate parameter groups and settings.
+    Supports automatic parameter group creation for weight decay and layer-wise learning
+    rates, as well as custom parameter grouping.
+
+    Args:
+        model_or_params: A PyTorch model or an iterable of parameters/parameter groups.
+            If a model is provided, parameters will be automatically extracted and grouped
+            based on the other arguments.
+        opt: Name of the optimizer to create (e.g., 'adam', 'adamw', 'sgd').
+            Use list_optimizers() to see available options.
+        lr: Learning rate. If None, will use the optimizer's default.
+        weight_decay: Weight decay factor. Will be used to create param groups if model_or_params is a model.
+        momentum: Momentum factor for optimizers that support it. Only used if the
+            chosen optimizer accepts a momentum parameter.
+        foreach: Enable/disable foreach (multi-tensor) implementation if available.
+            If None, will use optimizer-specific defaults.
+        filter_bias_and_bn: If True, bias and norm layer parameters (all 1d params) will not have
+            weight decay applied. Only used when model_or_params is a model and
+            weight_decay > 0.
+        layer_decay: Optional layer-wise learning rate decay factor. If provided,
+            learning rates will be scaled by layer_decay^(max_depth - layer_depth).
+            Only used when model_or_params is a model.
+        param_group_fn: Optional function to create custom parameter groups.
+            If provided, other parameter grouping options will be ignored.
+        **kwargs: Additional optimizer-specific arguments (e.g., betas for Adam).
+
+    Returns:
+        Configured optimizer instance.
+
+    Examples:
+        >>> # Basic usage with a model
+        >>> optimizer = create_optimizer_v2(model, 'adamw', lr=1e-3)
+
+        >>> # SGD with momentum and weight decay
+        >>> optimizer = create_optimizer_v2(
+        ...     model, 'sgd', lr=0.1, momentum=0.9, weight_decay=1e-4
+        ... )
+
+        >>> # Adam with layer-wise learning rate decay
+        >>> optimizer = create_optimizer_v2(
+        ...     model, 'adam', lr=1e-3, layer_decay=0.7
+        ... )
+
+        >>> # Custom parameter groups
+        >>> def group_fn(model):
+        ...     return [
+        ...         {'params': model.backbone.parameters(), 'lr': 1e-4},
+        ...         {'params': model.head.parameters(), 'lr': 1e-3}
+        ...     ]
+        >>> optimizer = create_optimizer_v2(
+        ...     model, 'sgd', param_group_fn=group_fn
+        ... )
+
+    Note:
+        Parameter group handling precedence:
+        1. If param_group_fn is provided, it will be used exclusively
+        2. If layer_decay is provided, layer-wise groups will be created
+        3. If weight_decay > 0 and filter_bias_and_bn is True, weight decay groups will be created
+        4. Otherwise, all parameters will be in a single group
+    """
+
     return default_registry.create_optimizer(
         model_or_params,
         opt=opt,
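
Taken together, the documented entry points compose as in the short sketch below. It assumes an environment where timm.optim exports these functions (as the mdx changes above indicate); the two-layer nn.Sequential model is only an illustrative stand-in, not part of the patch.

import torch.nn as nn
from timm.optim import create_optimizer_v2, get_optimizer_class, list_optimizers

# Hypothetical toy model; any nn.Module works the same way.
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))

# Discover registered optimizers, optionally with short descriptions.
print(list_optimizers('adam*', with_description=True))

# Grab a class directly, with registry defaults bound (e.g. nesterov for 'sgd').
SGD = get_optimizer_class('sgd', bind_defaults=True)
opt_a = SGD(model.parameters(), lr=0.1, momentum=0.9)

# Or let the factory build parameter groups (no weight decay on 1d params).
opt_b = create_optimizer_v2(model, 'adamw', lr=1e-3, weight_decay=0.05)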