From b09f81c8cbe08a1ea38da8a1de09cf09991b82cd Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Fri, 6 Dec 2024 08:58:02 -0800
Subject: [PATCH] More tweaks to docstrings for hub/builder

---
 timm/models/_builder.py | 40 ++++++++++++++++++++--------------------
 timm/models/_factory.py |  2 +-
 2 files changed, 21 insertions(+), 21 deletions(-)

diff --git a/timm/models/_builder.py b/timm/models/_builder.py
index 72b237faa..482d370a9 100644
--- a/timm/models/_builder.py
+++ b/timm/models/_builder.py
@@ -107,6 +107,7 @@ def load_custom_pretrained(
         pretrained_cfg: Default pretrained model cfg
         load_fn: An external standalone fn that loads weights into provided model, otherwise a fn named
             'load_pretrained' on the model will be called if it exists
+        cache_dir: Override model checkpoint cache dir for this load
     """
     pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None)
     if not pretrained_cfg:
@@ -148,12 +149,12 @@ def load_pretrained(

     Args:
         model: PyTorch module
-        pretrained_cfg: configuration for pretrained weights / target dataset
-        num_classes: number of classes for target model
-        in_chans: number of input chans for target model
+        pretrained_cfg: Configuration for pretrained weights / target dataset
+        num_classes: Number of classes for target model. Will adapt pretrained if different.
+        in_chans: Number of input chans for target model. Will adapt pretrained if different.
         filter_fn: state_dict filter fn for load (takes state_dict, model as args)
-        strict: strict load of checkpoint
-        cache_dir: override path to cache dir for this load
+        strict: Strict load of checkpoint
+        cache_dir: Override model checkpoint cache dir for this load
     """
     pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None)
     if not pretrained_cfg:
@@ -326,8 +327,8 @@ def _update_default_model_kwargs(pretrained_cfg, kwargs, kwargs_filter):

 def resolve_pretrained_cfg(
         variant: str,
-        pretrained_cfg=None,
-        pretrained_cfg_overlay=None,
+        pretrained_cfg: Optional[Union[str, Dict[str, Any]]] = None,
+        pretrained_cfg_overlay: Optional[Dict[str, Any]] = None,
 ) -> PretrainedCfg:
     model_with_tag = variant
     pretrained_tag = None
@@ -382,17 +383,18 @@ def build_model_with_cfg(
     * pruning config / model adaptation

     Args:
-        model_cls: model class
-        variant: model variant name
-        pretrained: load pretrained weights
-        pretrained_cfg: model's pretrained weight/task config
-        model_cfg: model's architecture config
-        feature_cfg: feature extraction adapter config
-        pretrained_strict: load pretrained weights strictly
-        pretrained_filter_fn: filter callable for pretrained weights
-        cache_dir: Override system cache dir for Hugging Face Hub and Torch checkpoint locations
-        kwargs_filter: kwargs to filter before passing to model
-        **kwargs: model args passed through to model __init__
+        model_cls: Model class
+        variant: Model variant name
+        pretrained: Load the pretrained weights
+        pretrained_cfg: Model's pretrained weight/task config
+        pretrained_cfg_overlay: Entries that will override those in pretrained_cfg
+        model_cfg: Model's architecture config
+        feature_cfg: Feature extraction adapter config
+        pretrained_strict: Load pretrained weights strictly
+        pretrained_filter_fn: Filter callable for pretrained weights
+        cache_dir: Override model cache dir for Hugging Face Hub and Torch checkpoints
+        kwargs_filter: Kwargs keys to filter (remove) before passing to model
+        **kwargs: Model args passed through to model __init__
     """
     pruned = kwargs.pop('pruned', False)
     features = False
@@ -404,8 +406,6 @@ def build_model_with_cfg(
         pretrained_cfg=pretrained_cfg,
         pretrained_cfg_overlay=pretrained_cfg_overlay
     )
-
-    # FIXME converting back to dict, PretrainedCfg use should be propagated further, but not into model
     pretrained_cfg = pretrained_cfg.to_dict()

     _update_default_model_kwargs(pretrained_cfg, kwargs, kwargs_filter)

diff --git a/timm/models/_factory.py b/timm/models/_factory.py
index 2a62ee411..8fe730ead 100644
--- a/timm/models/_factory.py
+++ b/timm/models/_factory.py
@@ -62,7 +62,7 @@ def create_model(
         pretrained_cfg: Pass in an external pretrained_cfg for model.
         pretrained_cfg_overlay: Replace key-values in base pretrained_cfg with these.
         checkpoint_path: Path of checkpoint to load _after_ the model is initialized.
-        cache_dir: Override system cache dir for Hugging Face Hub and Torch checkpoint locations
+        cache_dir: Override model cache dir for Hugging Face Hub and Torch checkpoints.
         scriptable: Set layer config so that model is jit scriptable (not working for all models yet).
         exportable: Set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet).
         no_jit: Set layer config so that model doesn't utilize jit scripted layers (so far activations only).
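
Note (not part of the patch): a minimal usage sketch of the two `create_model` arguments these docstrings describe, `cache_dir` and `pretrained_cfg_overlay`. The variant name, cache directory, and checkpoint path below are illustrative, not from the patch; the `file` key is a standard `PretrainedCfg` field.

```python
import timm

# Pull weights through the normal hub path, but cache the checkpoint under
# a project-local directory instead of the default hub cache location
# (cache_dir is the argument documented by this patch).
model = timm.create_model(
    'resnet50',                      # illustrative variant name
    pretrained=True,
    cache_dir='./hub_checkpoints',   # hypothetical path
)

# pretrained_cfg_overlay replaces key-values in the base pretrained_cfg;
# here it points the loader at a local checkpoint file instead of the hub.
model_local = timm.create_model(
    'resnet50',
    pretrained=True,
    pretrained_cfg_overlay=dict(file='./weights/resnet50_custom.pth'),  # hypothetical path
)
```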
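On the `_builder.py` side, a sketch of how a model entrypoint typically forwards through `build_model_with_cfg` with the arguments documented above, mirroring the pattern timm models use internally. `MyNet`, `_create_mynet`, and the minimal `pretrained_cfg` dict are hypothetical stand-ins, not part of the patch or the library.

```python
import torch.nn as nn
from timm.models import build_model_with_cfg


class MyNet(nn.Module):
    """Hypothetical architecture, stands in for a real timm model class."""
    def __init__(self, num_classes: int = 1000, in_chans: int = 3, **kwargs):
        super().__init__()
        self.stem = nn.Conv2d(in_chans, 32, kernel_size=3, stride=2, padding=1)
        self.head = nn.Linear(32, num_classes)

    def forward(self, x):
        return self.head(self.stem(x).mean(dim=(2, 3)))


def _create_mynet(variant: str, pretrained: bool = False, **kwargs):
    # The builder resolves pretrained_cfg for `variant`, constructs MyNet,
    # and (when pretrained=True) loads weights, adapting the classifier /
    # first conv if num_classes / in_chans differ from the pretrained cfg.
    return build_model_with_cfg(
        MyNet,          # model_cls: Model class
        variant,        # variant: Model variant name
        pretrained,     # pretrained: Load the pretrained weights
        pretrained_cfg=dict(num_classes=1000, input_size=(3, 224, 224)),  # assumed minimal cfg
        **kwargs,       # Model args passed through to model __init__
    )


model = _create_mynet('mynet_base', pretrained=False, num_classes=10)
```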