From 71849b972a17bf43782c20f41bfb6a2e71f8ef67 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Wed, 4 Dec 2024 22:02:40 -0800
Subject: [PATCH 1/2] Punch cache_dir through model factory / builder / pretrain helpers. Improve some annotations in related code.

---
 timm/models/_builder.py | 36 +++++++++++-------
 timm/models/_factory.py | 14 +++++--
 timm/models/_helpers.py |  1 -
 timm/models/_hub.py     | 84 ++++++++++++++++++++++++++++++++---------
 4 files changed, 99 insertions(+), 36 deletions(-)

diff --git a/timm/models/_builder.py b/timm/models/_builder.py
index 8d124afe40..72b237faa8 100644
--- a/timm/models/_builder.py
+++ b/timm/models/_builder.py
@@ -2,7 +2,8 @@
 import logging
 import os
 from copy import deepcopy
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from pathlib import Path
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 from torch import nn as nn
 from torch.hub import load_state_dict_from_url
@@ -90,6 +91,7 @@ def load_custom_pretrained(
         model: nn.Module,
         pretrained_cfg: Optional[Dict] = None,
         load_fn: Optional[Callable] = None,
+        cache_dir: Optional[Union[str, Path]] = None,
 ):
     r"""Loads a custom (read non .pth) weight file
 
@@ -102,9 +104,9 @@ def load_custom_pretrained(
 
     Args:
         model: The instantiated model to load weights into
-        pretrained_cfg (dict): Default pretrained model cfg
+        pretrained_cfg: Default pretrained model cfg
         load_fn: An external standalone fn that loads weights into provided model, otherwise a fn named
-            'laod_pretrained' on the model will be called if it exists
+            'load_pretrained' on the model will be called if it exists
     """
     pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None)
     if not pretrained_cfg:
@@ -122,6 +124,7 @@ def load_custom_pretrained(
             pretrained_loc,
             check_hash=_CHECK_HASH,
             progress=_DOWNLOAD_PROGRESS,
+            cache_dir=cache_dir,
         )
 
     if load_fn is not None:
@@ -139,17 +142,18 @@ def load_pretrained(
         in_chans: int = 3,
         filter_fn: Optional[Callable] = None,
         strict: bool = True,
+        cache_dir: Optional[Union[str, Path]] = None,
 ):
     """ Load pretrained checkpoint
 
     Args:
-        model (nn.Module) : PyTorch model module
-        pretrained_cfg (Optional[Dict]): configuration for pretrained weights / target dataset
-        num_classes (int): num_classes for target model
-        in_chans (int): in_chans for target model
-        filter_fn (Optional[Callable]): state_dict filter fn for load (takes state_dict, model as args)
-        strict (bool): strict load of checkpoint
-
+        model: PyTorch module
+        pretrained_cfg: configuration for pretrained weights / target dataset
+        num_classes: number of classes for target model
+        in_chans: number of input chans for target model
+        filter_fn: state_dict filter fn for load (takes state_dict, model as args)
+        strict: strict load of checkpoint
+        cache_dir: override path to cache dir for this load
     """
     pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None)
     if not pretrained_cfg:
@@ -173,6 +177,7 @@ def load_pretrained(
                 pretrained_loc,
                 progress=_DOWNLOAD_PROGRESS,
                 check_hash=_CHECK_HASH,
+                cache_dir=cache_dir,
             )
             model.load_pretrained(pretrained_loc)
             return
@@ -184,6 +189,7 @@ def load_pretrained(
                     progress=_DOWNLOAD_PROGRESS,
                     check_hash=_CHECK_HASH,
                     weights_only=True,
+                    model_dir=cache_dir,
                 )
             except TypeError:
                 state_dict = load_state_dict_from_url(
@@ -191,18 +197,19 @@ def load_pretrained(
                     map_location='cpu',
                     progress=_DOWNLOAD_PROGRESS,
                     check_hash=_CHECK_HASH,
+                    model_dir=cache_dir,
                 )
     elif load_from == 'hf-hub':
         _logger.info(f'Loading pretrained weights from Hugging Face hub ({pretrained_loc})')
         if isinstance(pretrained_loc, (list, tuple)):
             custom_load = pretrained_cfg.get('custom_load', False)
             if isinstance(custom_load, str) and custom_load == 'hf':
-                load_custom_from_hf(*pretrained_loc, model)
+                load_custom_from_hf(*pretrained_loc, model, cache_dir=cache_dir)
                 return
             else:
-                state_dict = load_state_dict_from_hf(*pretrained_loc)
+                state_dict = load_state_dict_from_hf(*pretrained_loc, cache_dir=cache_dir)
         else:
-            state_dict = load_state_dict_from_hf(pretrained_loc, weights_only=True)
+            state_dict = load_state_dict_from_hf(pretrained_loc, weights_only=True, cache_dir=cache_dir)
     else:
         model_name = pretrained_cfg.get('architecture', 'this model')
         raise RuntimeError(f"No pretrained weights exist for {model_name}. Use `pretrained=False` for random init.")
@@ -362,6 +369,7 @@ def build_model_with_cfg(
         feature_cfg: Optional[Dict] = None,
         pretrained_strict: bool = True,
         pretrained_filter_fn: Optional[Callable] = None,
+        cache_dir: Optional[Union[str, Path]] = None,
         kwargs_filter: Optional[Tuple[str]] = None,
         **kwargs,
 ):
@@ -382,6 +390,7 @@ def build_model_with_cfg(
         feature_cfg: feature extraction adapter config
         pretrained_strict: load pretrained weights strictly
         pretrained_filter_fn: filter callable for pretrained weights
+        cache_dir: Override system cache dir for Hugging Face Hub and Torch checkpoint locations
         kwargs_filter: kwargs to filter before passing to model
         **kwargs: model args passed through to model __init__
     """
@@ -431,6 +440,7 @@ def build_model_with_cfg(
             in_chans=kwargs.get('in_chans', 3),
             filter_fn=pretrained_filter_fn,
             strict=pretrained_strict,
+            cache_dir=cache_dir,
         )
 
     # Wrap the model in a feature extraction module if enabled
diff --git a/timm/models/_factory.py b/timm/models/_factory.py
index bff15b9a64..2a62ee411d 100644
--- a/timm/models/_factory.py
+++ b/timm/models/_factory.py
@@ -1,4 +1,5 @@
 import os
+from pathlib import Path
 from typing import Any, Dict, Optional, Union
 from urllib.parse import urlsplit
 
@@ -40,7 +41,8 @@ def create_model(
         pretrained: bool = False,
         pretrained_cfg: Optional[Union[str, Dict[str, Any], PretrainedCfg]] = None,
         pretrained_cfg_overlay: Optional[Dict[str, Any]] = None,
-        checkpoint_path: str = '',
+        checkpoint_path: Optional[Union[str, Path]] = None,
+        cache_dir: Optional[Union[str, Path]] = None,
         scriptable: Optional[bool] = None,
         exportable: Optional[bool] = None,
         no_jit: Optional[bool] = None,
@@ -50,10 +52,9 @@ def create_model(
 
     Lookup model's entrypoint function and pass relevant args to create a new model.
 
-    <Tip>
+    Tip:
         **kwargs will be passed through entrypoint fn to ``timm.models.build_model_with_cfg()``
         and then the model class __init__(). kwargs values set to None are pruned before passing.
-    </Tip>
 
     Args:
         model_name: Name of model to instantiate.
@@ -61,6 +62,7 @@ def create_model(
         pretrained_cfg: Pass in an external pretrained_cfg for model.
         pretrained_cfg_overlay: Replace key-values in base pretrained_cfg with these.
         checkpoint_path: Path of checkpoint to load _after_ the model is initialized.
+        cache_dir: Override system cache dir for Hugging Face Hub and Torch checkpoint locations
         scriptable: Set layer config so that model is jit scriptable (not working for all models yet).
         exportable: Set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet).
         no_jit: Set layer config so that model doesn't utilize jit scripted layers (so far activations only).
@@ -99,7 +101,10 @@ def create_model(
         assert not pretrained_cfg, 'pretrained_cfg should not be set when sourcing model from Hugging Face Hub.'
         # For model names specified in the form `hf-hub:path/architecture_name@revision`,
         # load model weights + pretrained_cfg from Hugging Face hub.
-        pretrained_cfg, model_name, model_args = load_model_config_from_hf(model_name)
+        pretrained_cfg, model_name, model_args = load_model_config_from_hf(
+            model_name,
+            cache_dir=cache_dir,
+        )
         if model_args:
             for k, v in model_args.items():
                 kwargs.setdefault(k, v)
@@ -118,6 +123,7 @@ def create_model(
             pretrained=pretrained,
             pretrained_cfg=pretrained_cfg,
             pretrained_cfg_overlay=pretrained_cfg_overlay,
+            cache_dir=cache_dir,
             **kwargs,
         )
 
diff --git a/timm/models/_helpers.py b/timm/models/_helpers.py
index 3622b104e3..ca5dc2445e 100644
--- a/timm/models/_helpers.py
+++ b/timm/models/_helpers.py
@@ -4,7 +4,6 @@
 """
 import logging
 import os
-from collections import OrderedDict
 from typing import Any, Callable, Dict, Optional, Union
 
 import torch
diff --git a/timm/models/_hub.py b/timm/models/_hub.py
index 19b3dcf074..4922dfc09a 100644
--- a/timm/models/_hub.py
+++ b/timm/models/_hub.py
@@ -5,7 +5,7 @@
 from functools import partial
 from pathlib import Path
 from tempfile import TemporaryDirectory
-from typing import Iterable, Optional, Union
+from typing import Iterable, List, Optional, Tuple, Union
 
 import torch
 from torch.hub import HASH_REGEX, download_url_to_file, urlparse
@@ -53,7 +53,7 @@
 HF_OPEN_CLIP_SAFE_WEIGHTS_NAME = "open_clip_model.safetensors"  # safetensors version
 
 
-def get_cache_dir(child_dir=''):
+def get_cache_dir(child_dir: str = ''):
     """
     Returns the location of the directory where models are cached (and creates it if necessary).
     """
@@ -68,13 +68,22 @@ def get_cache_dir(child_dir=''):
     return model_dir
 
 
-def download_cached_file(url, check_hash=True, progress=False):
+def download_cached_file(
+        url: Union[str, List[str], Tuple[str, str]],
+        check_hash: bool = True,
+        progress: bool = False,
+        cache_dir: Optional[Union[str, Path]] = None,
+):
     if isinstance(url, (list, tuple)):
         url, filename = url
     else:
         parts = urlparse(url)
         filename = os.path.basename(parts.path)
-    cached_file = os.path.join(get_cache_dir(), filename)
+    if cache_dir:
+        os.makedirs(cache_dir, exist_ok=True)
+    else:
+        cache_dir = get_cache_dir()
+    cached_file = os.path.join(cache_dir, filename)
     if not os.path.exists(cached_file):
         _logger.info('Downloading: "{}" to {}\n'.format(url, cached_file))
         hash_prefix = None
@@ -85,13 +94,19 @@ def download_cached_file(url, check_hash=True, progress=False):
     return cached_file
 
 
-def check_cached_file(url, check_hash=True):
+def check_cached_file(
+        url: Union[str, List[str], Tuple[str, str]],
+        check_hash: bool = True,
+        cache_dir: Optional[Union[str, Path]] = None,
+):
     if isinstance(url, (list, tuple)):
         url, filename = url
     else:
         parts = urlparse(url)
         filename = os.path.basename(parts.path)
-    cached_file = os.path.join(get_cache_dir(), filename)
+    if not cache_dir:
+        cache_dir = get_cache_dir()
+    cached_file = os.path.join(cache_dir, filename)
     if os.path.exists(cached_file):
         if check_hash:
             r = HASH_REGEX.search(filename)  # r is Optional[Match[str]]
@@ -105,7 +120,7 @@ def check_cached_file(url, check_hash=True):
     return False
 
 
-def has_hf_hub(necessary=False):
+def has_hf_hub(necessary: bool = False):
     if not _has_hf_hub and necessary:
         # if no HF Hub module installed, and it is necessary to continue, raise error
         raise RuntimeError(
@@ -122,20 +137,32 @@ def hf_split(hf_id: str):
     return hf_model_id, hf_revision
 
 
-def load_cfg_from_json(json_file: Union[str, os.PathLike]):
+def load_cfg_from_json(json_file: Union[str, Path]):
     with open(json_file, "r", encoding="utf-8") as reader:
         text = reader.read()
     return json.loads(text)
 
 
-def download_from_hf(model_id: str, filename: str):
+def download_from_hf(
+        model_id: str,
+        filename: str,
+        cache_dir: Optional[Union[str, Path]] = None,
+):
     hf_model_id, hf_revision = hf_split(model_id)
-    return hf_hub_download(hf_model_id, filename, revision=hf_revision)
+    return hf_hub_download(
+        hf_model_id,
+        filename,
+        revision=hf_revision,
+        cache_dir=cache_dir,
+    )
 
 
-def load_model_config_from_hf(model_id: str):
+def load_model_config_from_hf(
+        model_id: str,
+        cache_dir: Optional[Union[str, Path]] = None,
+):
     assert has_hf_hub(True)
-    cached_file = download_from_hf(model_id, 'config.json')
+    cached_file = download_from_hf(model_id, 'config.json', cache_dir=cache_dir)
 
     hf_config = load_cfg_from_json(cached_file)
     if 'pretrained_cfg' not in hf_config:
@@ -172,6 +199,7 @@ def load_state_dict_from_hf(
         model_id: str,
         filename: str = HF_WEIGHTS_NAME,
         weights_only: bool = False,
+        cache_dir: Optional[Union[str, Path]] = None,
 ):
     assert has_hf_hub(True)
     hf_model_id, hf_revision = hf_split(model_id)
@@ -180,7 +208,12 @@ def load_state_dict_from_hf(
     if _has_safetensors:
         for safe_filename in _get_safe_alternatives(filename):
             try:
-                cached_safe_file = hf_hub_download(repo_id=hf_model_id, filename=safe_filename, revision=hf_revision)
+                cached_safe_file = hf_hub_download(
+                    repo_id=hf_model_id,
+                    filename=safe_filename,
+                    revision=hf_revision,
+                    cache_dir=cache_dir,
+                )
                 _logger.info(
                     f"[{model_id}] Safe alternative available for '{filename}' "
                     f"(as '{safe_filename}'). Loading weights using safetensors.")
@@ -189,7 +222,12 @@ def load_state_dict_from_hf(
                 pass
 
     # Otherwise, load using pytorch.load
-    cached_file = hf_hub_download(hf_model_id, filename=filename, revision=hf_revision)
+    cached_file = hf_hub_download(
+        hf_model_id,
+        filename=filename,
+        revision=hf_revision,
+        cache_dir=cache_dir,
+    )
     _logger.debug(f"[{model_id}] Safe alternative not found for '{filename}'. Loading weights using default pytorch.")
     try:
         state_dict = torch.load(cached_file, map_location='cpu', weights_only=weights_only)
@@ -198,15 +236,25 @@ def load_state_dict_from_hf(
     return state_dict
 
 
-def load_custom_from_hf(model_id: str, filename: str, model: torch.nn.Module):
+def load_custom_from_hf(
+        model_id: str,
+        filename: str,
+        model: torch.nn.Module,
+        cache_dir: Optional[Union[str, Path]] = None,
+):
     assert has_hf_hub(True)
     hf_model_id, hf_revision = hf_split(model_id)
-    cached_file = hf_hub_download(hf_model_id, filename=filename, revision=hf_revision)
+    cached_file = hf_hub_download(
+        hf_model_id,
+        filename=filename,
+        revision=hf_revision,
+        cache_dir=cache_dir,
+    )
     return model.load_pretrained(cached_file)
 
 
 def save_config_for_hf(
-        model,
+        model: torch.nn.Module,
         config_path: str,
         model_config: Optional[dict] = None,
         model_args: Optional[dict] = None
@@ -255,7 +303,7 @@ def save_config_for_hf(
 
 
 def save_for_hf(
-        model,
+        model: torch.nn.Module,
         save_directory: str,
         model_config: Optional[dict] = None,
         model_args: Optional[dict] = None,

From b09f81c8cbe08a1ea38da8a1de09cf09991b82cd Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Fri, 6 Dec 2024 08:58:02 -0800
Subject: [PATCH 2/2] More tweaks to docstrings for hub/builder

---
 timm/models/_builder.py | 40 ++++++++++++++++++++--------------------
 timm/models/_factory.py |  2 +-
 2 files changed, 21 insertions(+), 21 deletions(-)

diff --git a/timm/models/_builder.py b/timm/models/_builder.py
index 72b237faa8..482d370a94 100644
--- a/timm/models/_builder.py
+++ b/timm/models/_builder.py
@@ -107,6 +107,7 @@ def load_custom_pretrained(
         pretrained_cfg: Default pretrained model cfg
         load_fn: An external standalone fn that loads weights into provided model, otherwise a fn named
             'load_pretrained' on the model will be called if it exists
+        cache_dir: Override model checkpoint cache dir for this load
     """
     pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None)
     if not pretrained_cfg:
@@ -148,12 +149,12 @@ def load_pretrained(
 
     Args:
         model: PyTorch module
-        pretrained_cfg: configuration for pretrained weights / target dataset
-        num_classes: number of classes for target model
-        in_chans: number of input chans for target model
+        pretrained_cfg: Configuration for pretrained weights / target dataset
+        num_classes: Number of classes for target model. Will adapt pretrained if different.
+        in_chans: Number of input chans for target model. Will adapt pretrained if different.
         filter_fn: state_dict filter fn for load (takes state_dict, model as args)
-        strict: strict load of checkpoint
-        cache_dir: override path to cache dir for this load
+        strict: Strict load of checkpoint
+        cache_dir: Override model checkpoint cache dir for this load
     """
     pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None)
     if not pretrained_cfg:
@@ -326,8 +327,8 @@ def _update_default_model_kwargs(pretrained_cfg, kwargs, kwargs_filter):
 
 def resolve_pretrained_cfg(
         variant: str,
-        pretrained_cfg=None,
-        pretrained_cfg_overlay=None,
+        pretrained_cfg: Optional[Union[str, Dict[str, Any]]] = None,
+        pretrained_cfg_overlay: Optional[Dict[str, Any]] = None,
 ) -> PretrainedCfg:
     model_with_tag = variant
     pretrained_tag = None
@@ -382,17 +383,18 @@ def build_model_with_cfg(
      * pruning config / model adaptation
 
     Args:
-        model_cls: model class
-        variant: model variant name
-        pretrained: load pretrained weights
-        pretrained_cfg: model's pretrained weight/task config
-        model_cfg: model's architecture config
-        feature_cfg: feature extraction adapter config
-        pretrained_strict: load pretrained weights strictly
-        pretrained_filter_fn: filter callable for pretrained weights
-        cache_dir: Override system cache dir for Hugging Face Hub and Torch checkpoint locations
-        kwargs_filter: kwargs to filter before passing to model
-        **kwargs: model args passed through to model __init__
+        model_cls: Model class
+        variant: Model variant name
+        pretrained: Load the pretrained weights
+        pretrained_cfg: Model's pretrained weight/task config
+        pretrained_cfg_overlay: Entries that will override those in pretrained_cfg
+        model_cfg: Model's architecture config
+        feature_cfg: Feature extraction adapter config
+        pretrained_strict: Load pretrained weights strictly
+        pretrained_filter_fn: Filter callable for pretrained weights
+        cache_dir: Override model cache dir for Hugging Face Hub and Torch checkpoints
+        kwargs_filter: Kwargs keys to filter (remove) before passing to model
+        **kwargs: Model args passed through to model __init__
     """
     pruned = kwargs.pop('pruned', False)
     features = False
@@ -404,8 +406,6 @@ def build_model_with_cfg(
         pretrained_cfg=pretrained_cfg,
         pretrained_cfg_overlay=pretrained_cfg_overlay
     )
-
-    # FIXME converting back to dict, PretrainedCfg use should be propagated further, but not into model
     pretrained_cfg = pretrained_cfg.to_dict()
 
     _update_default_model_kwargs(pretrained_cfg, kwargs, kwargs_filter)
diff --git a/timm/models/_factory.py b/timm/models/_factory.py
index 2a62ee411d..8fe730ead9 100644
--- a/timm/models/_factory.py
+++ b/timm/models/_factory.py
@@ -62,7 +62,7 @@ def create_model(
         pretrained_cfg: Pass in an external pretrained_cfg for model.
         pretrained_cfg_overlay: Replace key-values in base pretrained_cfg with these.
         checkpoint_path: Path of checkpoint to load _after_ the model is initialized.
-        cache_dir: Override system cache dir for Hugging Face Hub and Torch checkpoint locations
+        cache_dir: Override model cache dir for Hugging Face Hub and Torch checkpoints.
        scriptable: Set layer config so that model is jit scriptable (not working for all models yet).
         exportable: Set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet).
         no_jit: Set layer config so that model doesn't utilize jit scripted layers (so far activations only).
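
Usage sketch (not part of the patches above; the model name and directory below are illustrative): with these two commits applied, the new `cache_dir` argument to `timm.create_model()` redirects the Hugging Face Hub config/weight downloads and torch.hub URL checkpoints for that call into the given directory instead of the default system caches.

    import timm

    # Hypothetical example: fetch pretrained weights into a project-local cache
    # rather than the default huggingface_hub / torch.hub cache locations.
    model = timm.create_model(
        'resnet50.a1_in1k',    # any pretrained timm model; this name is only an example
        pretrained=True,
        cache_dir='./model-cache',
    )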