Try to fix documentation build, add better docstrings to public optimizer api
rwightman committed Nov 13, 2024
1 parent c8b4511 commit 75d676e
Showing 2 changed files with 149 additions and 7 deletions.
10 changes: 8 additions & 2 deletions hfdocs/source/reference/optimizers.mdx
@@ -6,22 +6,28 @@ This page contains the API reference documentation for learning rate optimizers

### Factory functions

[[autodoc]] timm.optim.optim_factory.create_optimizer
[[autodoc]] timm.optim.optim_factory.create_optimizer_v2
[[autodoc]] timm.optim.create_optimizer_v2
[[autodoc]] timm.optim.list_optimizers
[[autodoc]] timm.optim.get_optimizer_class

### Optimizer Classes

[[autodoc]] timm.optim.adabelief.AdaBelief
[[autodoc]] timm.optim.adafactor.Adafactor
[[autodoc]] timm.optim.adafactor_bv.AdafactorBigVision
[[autodoc]] timm.optim.adahessian.Adahessian
[[autodoc]] timm.optim.adamp.AdamP
[[autodoc]] timm.optim.adamw.AdamW
[[autodoc]] timm.optim.adopt.Adopt
[[autodoc]] timm.optim.lamb.Lamb
[[autodoc]] timm.optim.lars.Lars
[[autodoc]] timm.optim.lion.Lion
[[autodoc]] timm.optim.lookahead.Lookahead
[[autodoc]] timm.optim.madgrad.MADGRAD
[[autodoc]] timm.optim.nadam.Nadam
[[autodoc]] timm.optim.nadamw.NadamW
[[autodoc]] timm.optim.nvnovograd.NvNovoGrad
[[autodoc]] timm.optim.radam.RAdam
[[autodoc]] timm.optim.rmsprop_tf.RMSpropTF
[[autodoc]] timm.optim.sgdp.SGDP
[[autodoc]] timm.optim.sgdw.SGDW
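For orientation, each `[[autodoc]]` entry above points at an importable dotted path. A minimal usage sketch of one of the listed classes (hedged: the import path simply mirrors the `timm.optim.adamp.AdamP` entry above; the model and hyperparameters are arbitrary and not part of this commit):

```python
import torch
import torch.nn as nn
from timm.optim.adamp import AdamP  # path matches the [[autodoc]] entry above

model = nn.Linear(16, 4)

# AdamP follows the usual torch.optim constructor convention:
# an iterable of parameters plus keyword hyperparameters.
optimizer = AdamP(model.parameters(), lr=1e-3, weight_decay=1e-2)

loss = model(torch.randn(8, 16)).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```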
146 changes: 141 additions & 5 deletions timm/optim/_optim_factory.py
Expand Up @@ -124,7 +124,7 @@ def register_foreach_default(self, name: str) -> None:

    def list_optimizers(
            self,
            filter: str = '',
            filter: Union[str, List[str]] = '',
            exclude_filters: Optional[List[str]] = None,
            with_description: bool = False
    ) -> List[Union[str, Tuple[str, str]]]:
@@ -141,14 +141,22 @@ def list_optimizers(
        names = sorted(self._optimizers.keys())

        if filter:
            names = [n for n in names if fnmatch(n, filter)]
            if isinstance(filter, str):
                filters = [filter]
            else:
                filters = filter
            filtered_names = set()
            for f in filters:
                filtered_names.update(n for n in names if fnmatch(n, f))
            names = sorted(filtered_names)

        if exclude_filters:
            for exclude_filter in exclude_filters:
                names = [n for n in names if not fnmatch(n, exclude_filter)]

        if with_description:
            return [(name, self._optimizers[name].description) for name in names]

        return names

    def get_optimizer_info(self, name: str) -> OptimInfo:
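The hunk above generalizes `filter` from a single wildcard string to a string or list of strings: matches from each pattern are unioned via `fnmatch`, then `exclude_filters` are applied. A standalone sketch of that matching logic (illustrative only, using a made-up name list rather than the real registry):

```python
from fnmatch import fnmatch
from typing import List, Optional, Union

def match_names(
        names: List[str],
        filter: Union[str, List[str]] = '',
        exclude_filters: Optional[List[str]] = None,
) -> List[str]:
    names = sorted(names)
    if filter:
        filters = [filter] if isinstance(filter, str) else filter
        matched = set()
        for f in filters:
            matched.update(n for n in names if fnmatch(n, f))  # union across include patterns
        names = sorted(matched)
    for ex in (exclude_filters or []):
        names = [n for n in names if not fnmatch(n, ex)]        # then drop excluded patterns
    return names

# Made-up names, for illustration only.
print(match_names(['adam', 'adamw', 'lamb', 'lars', 'sgd'], filter=['adam*', 'la*']))
# -> ['adam', 'adamw', 'lamb', 'lars']
print(match_names(['adam', 'adamw', 'lamb'], filter='adam*', exclude_filters=['adamw']))
# -> ['adam']
```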
@@ -718,11 +726,46 @@ def _register_default_optimizers() -> None:
# Public API

def list_optimizers(
        filter: str = '',
        filter: Union[str, List[str]] = '',
        exclude_filters: Optional[List[str]] = None,
        with_description: bool = False,
) -> List[Union[str, Tuple[str, str]]]:
"""List available optimizer names, optionally filtered.
List all registered optimizers, with optional filtering using wildcard patterns.
Optimizers can be filtered using include and exclude patterns, and can optionally
return descriptions with each optimizer name.
Args:
filter: Wildcard style filter string or list of filter strings
(e.g., 'adam*' for all Adam variants, or ['adam*', '*8bit'] for
Adam variants and 8-bit optimizers). Empty string means no filtering.
exclude_filters: Optional list of wildcard patterns to exclude. For example,
['*8bit', 'fused*'] would exclude 8-bit and fused implementations.
with_description: If True, returns tuples of (name, description) instead of
just names. Descriptions provide brief explanations of optimizer characteristics.
Returns:
If with_description is False:
List of optimizer names as strings (e.g., ['adam', 'adamw', ...])
If with_description is True:
List of tuples of (name, description) (e.g., [('adam', 'Adaptive Moment...'), ...])
Examples:
>>> list_optimizers()
['adam', 'adamw', 'sgd', ...]
>>> list_optimizers(['la*', 'nla*']) # List lamb & lars
['lamb', 'lambc', 'larc', 'lars', 'nlarc', 'nlars']
>>> list_optimizers('*adam*', exclude_filters=['bnb*', 'fused*']) # Exclude bnb & apex adam optimizers
['adam', 'adamax', 'adamp', 'adamw', 'nadam', 'nadamw', 'radam']
>>> list_optimizers(with_description=True) # Get descriptions
[('adabelief', 'Adapts learning rate based on gradient prediction error'),
('adadelta', 'torch.optim Adadelta, Adapts learning rates based on running windows of gradients'),
('adafactor', 'Memory-efficient implementation of Adam with factored gradients'),
...]
"""
return default_registry.list_optimizers(filter, exclude_filters, with_description)
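A short usage sketch of the public helper documented above (hedged: the registered names and descriptions depend on the installed timm version, so no output is assumed):

```python
from timm.optim import list_optimizers

# Plain names, filtered by one or more wildcard patterns.
adam_like = list_optimizers('adam*')
lamb_or_lars = list_optimizers(['la*', 'nla*'], exclude_filters=['*8bit'])

# (name, description) tuples for a quick overview of what is registered.
for name, description in list_optimizers(with_description=True):
    print(f'{name:20s} {description}')
```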

@@ -731,7 +774,38 @@ def get_optimizer_class(
        name: str,
        bind_defaults: bool = False,
) -> Union[Type[optim.Optimizer], OptimizerCallable]:
    """Get optimizer class by name with any defaults applied.
    """Get optimizer class by name with option to bind default arguments.

    Retrieves the optimizer class or a partial function with default arguments bound.
    This allows direct instantiation of optimizers with their default configurations
    without going through the full factory.

    Args:
        name: Name of the optimizer to retrieve (e.g., 'adam', 'sgd')
        bind_defaults: If True, returns a partial function with default arguments from OptimInfo bound.
            If False, returns the raw optimizer class.

    Returns:
        If bind_defaults is False:
            The optimizer class (e.g., torch.optim.Adam)
        If bind_defaults is True:
            A partial function with default arguments bound

    Raises:
        ValueError: If optimizer name is not found in registry

    Examples:
        >>> # Get raw optimizer class
        >>> Adam = get_optimizer_class('adam')
        >>> opt = Adam(model.parameters(), lr=1e-3)

        >>> # Get optimizer with defaults bound
        >>> AdamWithDefaults = get_optimizer_class('adam', bind_defaults=True)
        >>> opt = AdamWithDefaults(model.parameters(), lr=1e-3)

        >>> # Get SGD with nesterov momentum default
        >>> SGD = get_optimizer_class('sgd', bind_defaults=True)  # nesterov=True bound
        >>> opt = SGD(model.parameters(), lr=0.1, momentum=0.9)
    """
    return default_registry.get_optimizer_class(name, bind_defaults=bind_defaults)
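And a corresponding sketch for `get_optimizer_class`, contrasting the raw class with the defaults-bound partial (hedged: the `nesterov=True` default for `'sgd'` is the one cited in the docstring example above):

```python
import torch.nn as nn
from timm.optim import get_optimizer_class

model = nn.Linear(10, 2)

# Raw optimizer class, exactly as registered.
SGD = get_optimizer_class('sgd')
opt = SGD(model.parameters(), lr=0.1, momentum=0.9)

# Partial with registry defaults bound (e.g. nesterov=True for 'sgd').
SGDWithDefaults = get_optimizer_class('sgd', bind_defaults=True)
opt = SGDWithDefaults(model.parameters(), lr=0.1, momentum=0.9)
```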

@@ -748,7 +822,69 @@ def create_optimizer_v2(
        param_group_fn: Optional[Callable[[nn.Module], Params]] = None,
        **kwargs: Any,
) -> optim.Optimizer:
"""Create an optimizer instance using the default registry."""
"""Create an optimizer instance via timm registry.
Creates and configures an optimizer with appropriate parameter groups and settings.
Supports automatic parameter group creation for weight decay and layer-wise learning
rates, as well as custom parameter grouping.
Args:
model_or_params: A PyTorch model or an iterable of parameters/parameter groups.
If a model is provided, parameters will be automatically extracted and grouped
based on the other arguments.
opt: Name of the optimizer to create (e.g., 'adam', 'adamw', 'sgd').
Use list_optimizers() to see available options.
lr: Learning rate. If None, will use the optimizer's default.
weight_decay: Weight decay factor. Will be used to create param groups if model_or_params is a model.
momentum: Momentum factor for optimizers that support it. Only used if the
chosen optimizer accepts a momentum parameter.
foreach: Enable/disable foreach (multi-tensor) implementation if available.
If None, will use optimizer-specific defaults.
filter_bias_and_bn: If True, bias, norm layer parameters (all 1d params) will not have
weight decay applied. Only used when model_or_params is a model and
weight_decay > 0.
layer_decay: Optional layer-wise learning rate decay factor. If provided,
learning rates will be scaled by layer_decay^(max_depth - layer_depth).
Only used when model_or_params is a model.
param_group_fn: Optional function to create custom parameter groups.
If provided, other parameter grouping options will be ignored.
**kwargs: Additional optimizer-specific arguments (e.g., betas for Adam).
Returns:
Configured optimizer instance.
Examples:
>>> # Basic usage with a model
>>> optimizer = create_optimizer_v2(model, 'adamw', lr=1e-3)
>>> # SGD with momentum and weight decay
>>> optimizer = create_optimizer_v2(
... model, 'sgd', lr=0.1, momentum=0.9, weight_decay=1e-4
... )
>>> # Adam with layer-wise learning rate decay
>>> optimizer = create_optimizer_v2(
... model, 'adam', lr=1e-3, layer_decay=0.7
... )
>>> # Custom parameter groups
>>> def group_fn(model):
... return [
... {'params': model.backbone.parameters(), 'lr': 1e-4},
... {'params': model.head.parameters(), 'lr': 1e-3}
... ]
>>> optimizer = create_optimizer_v2(
... model, 'sgd', param_group_fn=group_fn
... )
Note:
Parameter group handling precedence:
1. If param_group_fn is provided, it will be used exclusively
2. If layer_decay is provided, layer-wise groups will be created
3. If weight_decay > 0 and filter_bias_and_bn is True, weight decay groups will be created
4. Otherwise, all parameters will be in a single group
"""

    return default_registry.create_optimizer(
        model_or_params,
        opt=opt,

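Finally, an end-to-end sketch tying the pieces of this commit together (hedged: `resnet18` and the hyperparameters are arbitrary choices for illustration, not part of the diff):

```python
import torch
import timm
from timm.optim import create_optimizer_v2

model = timm.create_model('resnet18', num_classes=10)

# AdamW with weight decay and layer-wise LR decay, per the
# create_optimizer_v2 docstring above.
optimizer = create_optimizer_v2(
    model,
    'adamw',
    lr=1e-3,
    weight_decay=0.05,
    layer_decay=0.75,
)

x = torch.randn(2, 3, 224, 224)
loss = model(x).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```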