Punch cache_dir through model factory / builder / pretrain helpers #2356

Merged: 2 commits, Dec 6, 2024
64 changes: 37 additions & 27 deletions timm/models/_builder.py
@@ -2,7 +2,8 @@
import logging
import os
from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional, Tuple
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from torch import nn as nn
from torch.hub import load_state_dict_from_url
@@ -90,6 +91,7 @@ def load_custom_pretrained(
model: nn.Module,
pretrained_cfg: Optional[Dict] = None,
load_fn: Optional[Callable] = None,
cache_dir: Optional[Union[str, Path]] = None,
):
r"""Loads a custom (read non .pth) weight file

@@ -102,9 +104,10 @@

Args:
model: The instantiated model to load weights into
pretrained_cfg (dict): Default pretrained model cfg
pretrained_cfg: Default pretrained model cfg
load_fn: An external standalone fn that loads weights into provided model, otherwise a fn named
'laod_pretrained' on the model will be called if it exists
'load_pretrained' on the model will be called if it exists
cache_dir: Override model checkpoint cache dir for this load
"""
pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None)
if not pretrained_cfg:
@@ -122,6 +125,7 @@
pretrained_loc,
check_hash=_CHECK_HASH,
progress=_DOWNLOAD_PROGRESS,
cache_dir=cache_dir,
)

if load_fn is not None:
@@ -139,17 +143,18 @@ def load_pretrained(
in_chans: int = 3,
filter_fn: Optional[Callable] = None,
strict: bool = True,
cache_dir: Optional[Union[str, Path]] = None,
):
""" Load pretrained checkpoint

Args:
model (nn.Module) : PyTorch model module
pretrained_cfg (Optional[Dict]): configuration for pretrained weights / target dataset
num_classes (int): num_classes for target model
in_chans (int): in_chans for target model
filter_fn (Optional[Callable]): state_dict filter fn for load (takes state_dict, model as args)
strict (bool): strict load of checkpoint

model: PyTorch module
pretrained_cfg: Configuration for pretrained weights / target dataset
num_classes: Number of classes for target model. Will adapt pretrained if different.
in_chans: Number of input chans for target model. Will adapt pretrained if different.
filter_fn: state_dict filter fn for load (takes state_dict, model as args)
strict: Strict load of checkpoint
cache_dir: Override model checkpoint cache dir for this load
"""
pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None)
if not pretrained_cfg:
@@ -173,6 +178,7 @@ def load_pretrained(
pretrained_loc,
progress=_DOWNLOAD_PROGRESS,
check_hash=_CHECK_HASH,
cache_dir=cache_dir,
)
model.load_pretrained(pretrained_loc)
return
@@ -184,25 +190,27 @@ def load_pretrained(
progress=_DOWNLOAD_PROGRESS,
check_hash=_CHECK_HASH,
weights_only=True,
model_dir=cache_dir,
)
except TypeError:
state_dict = load_state_dict_from_url(
pretrained_loc,
map_location='cpu',
progress=_DOWNLOAD_PROGRESS,
check_hash=_CHECK_HASH,
model_dir=cache_dir,
)
elif load_from == 'hf-hub':
_logger.info(f'Loading pretrained weights from Hugging Face hub ({pretrained_loc})')
if isinstance(pretrained_loc, (list, tuple)):
custom_load = pretrained_cfg.get('custom_load', False)
if isinstance(custom_load, str) and custom_load == 'hf':
load_custom_from_hf(*pretrained_loc, model)
load_custom_from_hf(*pretrained_loc, model, cache_dir=cache_dir)
return
else:
state_dict = load_state_dict_from_hf(*pretrained_loc)
state_dict = load_state_dict_from_hf(*pretrained_loc, cache_dir=cache_dir)
else:
state_dict = load_state_dict_from_hf(pretrained_loc, weights_only=True)
state_dict = load_state_dict_from_hf(pretrained_loc, weights_only=True, cache_dir=cache_dir)
else:
model_name = pretrained_cfg.get('architecture', 'this model')
raise RuntimeError(f"No pretrained weights exist for {model_name}. Use `pretrained=False` for random init.")
@@ -319,8 +327,8 @@ def _update_default_model_kwargs(pretrained_cfg, kwargs, kwargs_filter):

def resolve_pretrained_cfg(
variant: str,
pretrained_cfg=None,
pretrained_cfg_overlay=None,
pretrained_cfg: Optional[Union[str, Dict[str, Any]]] = None,
pretrained_cfg_overlay: Optional[Dict[str, Any]] = None,
) -> PretrainedCfg:
model_with_tag = variant
pretrained_tag = None
@@ -362,6 +370,7 @@ def build_model_with_cfg(
feature_cfg: Optional[Dict] = None,
pretrained_strict: bool = True,
pretrained_filter_fn: Optional[Callable] = None,
cache_dir: Optional[Union[str, Path]] = None,
kwargs_filter: Optional[Tuple[str]] = None,
**kwargs,
):
@@ -374,16 +383,18 @@
* pruning config / model adaptation

Args:
model_cls: model class
variant: model variant name
pretrained: load pretrained weights
pretrained_cfg: model's pretrained weight/task config
model_cfg: model's architecture config
feature_cfg: feature extraction adapter config
pretrained_strict: load pretrained weights strictly
pretrained_filter_fn: filter callable for pretrained weights
kwargs_filter: kwargs to filter before passing to model
**kwargs: model args passed through to model __init__
model_cls: Model class
variant: Model variant name
pretrained: Load the pretrained weights
pretrained_cfg: Model's pretrained weight/task config
pretrained_cfg_overlay: Entries that will override those in pretrained_cfg
model_cfg: Model's architecture config
feature_cfg: Feature extraction adapter config
pretrained_strict: Load pretrained weights strictly
pretrained_filter_fn: Filter callable for pretrained weights
cache_dir: Override model cache dir for Hugging Face Hub and Torch checkpoints
kwargs_filter: Kwargs keys to filter (remove) before passing to model
**kwargs: Model args passed through to model __init__
"""
pruned = kwargs.pop('pruned', False)
features = False
@@ -395,8 +406,6 @@
pretrained_cfg=pretrained_cfg,
pretrained_cfg_overlay=pretrained_cfg_overlay
)

# FIXME converting back to dict, PretrainedCfg use should be propagated further, but not into model
pretrained_cfg = pretrained_cfg.to_dict()

_update_default_model_kwargs(pretrained_cfg, kwargs, kwargs_filter)
@@ -431,6 +440,7 @@ def build_model_with_cfg(
in_chans=kwargs.get('in_chans', 3),
filter_fn=pretrained_filter_fn,
strict=pretrained_strict,
cache_dir=cache_dir,
)

# Wrap the model in a feature extraction module if enabled
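As a quick illustration of the plumbing added above (a hedged sketch, not part of this PR): once `load_pretrained` accepts `cache_dir`, weights for an already-built model can be pulled into a caller-chosen directory. The model name and cache path below are illustrative assumptions.

```python
# Sketch only: exercise the new cache_dir argument on load_pretrained.
# Assumes a timm version containing this PR; the model name and the
# cache path are illustrative, not taken from the PR itself.
from pathlib import Path

import timm
from timm.models import load_pretrained

model = timm.create_model('resnet50', pretrained=False)
load_pretrained(
    model,
    pretrained_cfg=model.pretrained_cfg,  # default cfg attached at build time
    cache_dir=Path('/tmp/timm-weights'),  # new: overrides the default hub cache
)
```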
14 changes: 10 additions & 4 deletions timm/models/_factory.py
@@ -1,4 +1,5 @@
import os
from pathlib import Path
from typing import Any, Dict, Optional, Union
from urllib.parse import urlsplit

@@ -40,7 +41,8 @@ def create_model(
pretrained: bool = False,
pretrained_cfg: Optional[Union[str, Dict[str, Any], PretrainedCfg]] = None,
pretrained_cfg_overlay: Optional[Dict[str, Any]] = None,
checkpoint_path: str = '',
checkpoint_path: Optional[Union[str, Path]] = None,
cache_dir: Optional[Union[str, Path]] = None,
scriptable: Optional[bool] = None,
exportable: Optional[bool] = None,
no_jit: Optional[bool] = None,
@@ -50,17 +52,17 @@

Lookup model's entrypoint function and pass relevant args to create a new model.

<Tip>
Tip:
**kwargs will be passed through entrypoint fn to ``timm.models.build_model_with_cfg()``
and then the model class __init__(). kwargs values set to None are pruned before passing.
</Tip>

Args:
model_name: Name of model to instantiate.
pretrained: If set to `True`, load pretrained ImageNet-1k weights.
pretrained_cfg: Pass in an external pretrained_cfg for model.
pretrained_cfg_overlay: Replace key-values in base pretrained_cfg with these.
checkpoint_path: Path of checkpoint to load _after_ the model is initialized.
cache_dir: Override model cache dir for Hugging Face Hub and Torch checkpoints.
scriptable: Set layer config so that model is jit scriptable (not working for all models yet).
exportable: Set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet).
no_jit: Set layer config so that model doesn't utilize jit scripted layers (so far activations only).
@@ -99,7 +101,10 @@ def create_model(
assert not pretrained_cfg, 'pretrained_cfg should not be set when sourcing model from Hugging Face Hub.'
# For model names specified in the form `hf-hub:path/architecture_name@revision`,
# load model weights + pretrained_cfg from Hugging Face hub.
pretrained_cfg, model_name, model_args = load_model_config_from_hf(model_name)
pretrained_cfg, model_name, model_args = load_model_config_from_hf(
model_name,
cache_dir=cache_dir,
)
if model_args:
for k, v in model_args.items():
kwargs.setdefault(k, v)
@@ -118,6 +123,7 @@
pretrained=pretrained,
pretrained_cfg=pretrained_cfg,
pretrained_cfg_overlay=pretrained_cfg_overlay,
cache_dir=cache_dir,
**kwargs,
)

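For end users, the factory change means the checkpoint cache location can be set per call. A hedged usage sketch follows; the directory and model names are assumptions, not taken from the PR.

```python
# Sketch only: cache_dir on timm.create_model redirects where checkpoints
# from the Hugging Face Hub or torch.hub are stored for this load.
import timm

model = timm.create_model(
    'resnet50',
    pretrained=True,
    cache_dir='/data/shared/timm-cache',  # new argument added by this PR
)

# Also applies when the model config itself is sourced from the HF Hub:
hub_model = timm.create_model(
    'hf-hub:timm/resnet50.a1_in1k',
    pretrained=True,
    cache_dir='/data/shared/timm-cache',
)
```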
1 change: 0 additions & 1 deletion timm/models/_helpers.py
@@ -4,7 +4,6 @@
"""
import logging
import os
from collections import OrderedDict
from typing import Any, Callable, Dict, Optional, Union

import torch