diff --git a/src/open_clip/factory.py b/src/open_clip/factory.py
index 88dc03aef..358b51fb0 100644
--- a/src/open_clip/factory.py
+++ b/src/open_clip/factory.py
@@ -67,14 +67,25 @@ def add_model_config(path):
 
 
 def get_model_config(model_name):
+    """ Fetch model config from builtin (local library) configs.
+    """
     if model_name in _MODEL_CONFIGS:
         return deepcopy(_MODEL_CONFIGS[model_name])
     else:
         return None
 
 
-def _get_hf_config(model_id, cache_dir=None):
-    config_path = download_pretrained_from_hf(model_id, filename='open_clip_config.json', cache_dir=cache_dir)
+def _get_hf_config(
+        model_id: str,
+        cache_dir: Optional[str] = None,
+):
+    """ Fetch model config from HuggingFace Hub.
+    """
+    config_path = download_pretrained_from_hf(
+        model_id,
+        filename='open_clip_config.json',
+        cache_dir=cache_dir,
+    )
     with open(config_path, 'r', encoding='utf-8') as f:
         config = json.load(f)
     return config
@@ -83,16 +94,18 @@ def _get_hf_config(model_id, cache_dir=None):
 def get_tokenizer(
         model_name: str = '',
         context_length: Optional[int] = None,
+        cache_dir: Optional[str] = None,
         **kwargs,
 ):
     if model_name.startswith(HF_HUB_PREFIX):
         model_name = model_name[len(HF_HUB_PREFIX):]
         try:
-            config = _get_hf_config(model_name)['model_cfg']
+            config = _get_hf_config(model_name, cache_dir=cache_dir)['model_cfg']
         except Exception:
             tokenizer = HFTokenizer(
                 model_name,
                 context_length=context_length or DEFAULT_CONTEXT_LENGTH,
+                cache_dir=cache_dir,
                 **kwargs,
             )
             return tokenizer
@@ -113,6 +126,7 @@ def get_tokenizer(
         tokenizer = HFTokenizer(
             text_config['hf_tokenizer_name'],
             context_length=context_length,
+            cache_dir=cache_dir,
             **tokenizer_kwargs,
         )
     else:
@@ -265,7 +279,7 @@ def create_model(
     if has_hf_hub_prefix:
         model_id = model_name[len(HF_HUB_PREFIX):]
         checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
-        config = _get_hf_config(model_id, cache_dir)
+        config = _get_hf_config(model_id, cache_dir=cache_dir)
         preprocess_cfg = merge_preprocess_dict(preprocess_cfg, config['preprocess_cfg'])
         model_cfg = config['model_cfg']
         pretrained_hf = False  # override, no need to load original HF text weights
@@ -456,10 +470,16 @@ def create_model_and_transforms(
         pretrained_hf: bool = True,
         cache_dir: Optional[str] = None,
         output_dict: Optional[bool] = None,
+        load_weights_only: bool = True,
         **model_kwargs,
 ):
     force_preprocess_cfg = merge_preprocess_kwargs(
-        {}, mean=image_mean, std=image_std, interpolation=image_interpolation, resize_mode=image_resize_mode)
+        {},
+        mean=image_mean,
+        std=image_std,
+        interpolation=image_interpolation,
+        resize_mode=image_resize_mode,
+    )
 
     model = create_model(
         model_name,
@@ -476,6 +496,7 @@ def create_model_and_transforms(
         pretrained_hf=pretrained_hf,
         cache_dir=cache_dir,
         output_dict=output_dict,
+        load_weights_only=load_weights_only,
         **model_kwargs,
     )
 
@@ -509,10 +530,16 @@ def create_model_from_pretrained(
         image_resize_mode: Optional[str] = None,  # only effective for inference
         return_transform: bool = True,
         cache_dir: Optional[str] = None,
+        load_weights_only: bool = True,
         **model_kwargs,
 ):
     force_preprocess_cfg = merge_preprocess_kwargs(
-        {}, mean=image_mean, std=image_std, interpolation=image_interpolation, resize_mode=image_resize_mode)
+        {},
+        mean=image_mean,
+        std=image_std,
+        interpolation=image_interpolation,
+        resize_mode=image_resize_mode,
+    )
 
     model = create_model(
         model_name,
@@ -526,6 +553,7 @@ def create_model_from_pretrained(
         force_preprocess_cfg=force_preprocess_cfg,
         cache_dir=cache_dir,
         require_pretrained=True,
+        load_weights_only=load_weights_only,
         **model_kwargs,
     )
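Reviewer note: a minimal usage sketch of the two new keyword arguments wired through factory.py above. The hf-hub model id is only an illustrative example, and the comment on load_weights_only reflects its name and default rather than code shown in these hunks.

    import open_clip

    cache = '/mnt/cache/open_clip'  # hypothetical shared cache location

    # cache_dir now also reaches tokenizer file downloads, not just config/weights.
    tokenizer = open_clip.get_tokenizer(
        'hf-hub:laion/CLIP-ViT-B-32-laion2B-s34B-b79K',
        cache_dir=cache,
    )

    # load_weights_only defaults to True and is forwarded to create_model(),
    # presumably to gate safe (weights-only) checkpoint loading.
    model, _, preprocess = open_clip.create_model_and_transforms(
        'hf-hub:laion/CLIP-ViT-B-32-laion2B-s34B-b79K',
        cache_dir=cache,
        load_weights_only=True,
    )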
diff --git a/src/open_clip/pretrained.py b/src/open_clip/pretrained.py
index aac87619d..24c27ef3a 100644
--- a/src/open_clip/pretrained.py
+++ b/src/open_clip/pretrained.py
@@ -651,7 +651,7 @@ def get_pretrained_url(model: str, tag: str):
 
 def download_pretrained_from_url(
         url: str,
-        cache_dir: Union[str, None] = None,
+        cache_dir: Optional[str] = None,
 ):
     if not cache_dir:
         cache_dir = os.path.expanduser("~/.cache/clip")
@@ -712,7 +712,7 @@ def _get_safe_alternatives(filename: str) -> Iterable[str]:
     if filename == HF_WEIGHTS_NAME:
         yield HF_SAFE_WEIGHTS_NAME
 
-    if filename not in (HF_WEIGHTS_NAME,) and filename.endswith(".bin") or filename.endswith(".pth"):
+    if filename not in (HF_WEIGHTS_NAME,) and (filename.endswith(".bin") or filename.endswith(".pth")):
         yield filename[:-4] + ".safetensors"
 
 
@@ -750,7 +750,7 @@ def download_pretrained_from_hf(
         )
         return cached_file  # Return the path to the downloaded file if successful
     except Exception as e:
-        raise FileNotFoundError(f"Failed to download any files for {model_id}. Last error: {e}")
+        raise FileNotFoundError(f"Failed to download file ({filename}) for {model_id}. Last error: {e}")
 
 
 def download_pretrained(
diff --git a/src/open_clip/tokenizer.py b/src/open_clip/tokenizer.py
index 3b762c2fa..872c1833b 100644
--- a/src/open_clip/tokenizer.py
+++ b/src/open_clip/tokenizer.py
@@ -410,10 +410,11 @@ def __init__(
             clean: str = 'whitespace',
             strip_sep_token: bool = False,
             language: Optional[str] = None,
+            cache_dir: Optional[str] = None,
             **kwargs
     ):
         from transformers import AutoTokenizer
-        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, **kwargs)
+        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, cache_dir=cache_dir, **kwargs)
         set_lang_fn = getattr(self.tokenizer, 'set_src_lang_special_tokens', None)
         if callable(set_lang_fn):
             self.set_lang_fn = set_lang_fn
@@ -462,6 +463,9 @@ def set_language(self, src_lang):
 
 class SigLipTokenizer:
     """HuggingFace tokenizer wrapper for SigLIP T5 compatible sentencepiece vocabs
+
+    NOTE: this is not needed in normal library use, but is used to import new sentencepiece tokenizers
+    into OpenCLIP. Leaving code here in case future models use new tokenizers.
     """
     VOCAB_FILES = {
         # english, vocab_size=32_000
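Reviewer note: the _get_safe_alternatives change is an operator-precedence fix. Python binds "and" tighter than "or", so the old condition grouped as (A and B) or C instead of the intended A and (B or C). With the current constants the two forms happen to agree, since HF_WEIGHTS_NAME ends in ".bin" rather than ".pth", but the parenthesized form now matches the intent. A small sketch of where the groupings would diverge:

    # A = filename not in (HF_WEIGHTS_NAME,); B = endswith('.bin'); C = endswith('.pth')
    A, B, C = False, False, True        # a hypothetical excluded filename ending in '.pth'
    old = A and B or C                  # True: exclusion bypassed
    new = A and (B or C)                # False: exclusion respected
    assert old != new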
""" VOCAB_FILES = { # english, vocab_size=32_000 diff --git a/src/open_clip_train/main.py b/src/open_clip_train/main.py index 1aa0750fc..7c244ae35 100644 --- a/src/open_clip_train/main.py +++ b/src/open_clip_train/main.py @@ -236,6 +236,7 @@ def main(args): aug_cfg=args.aug_cfg, pretrained_image=args.pretrained_image, output_dict=True, + cache_dir=args.cache_dir, **model_kwargs, ) if args.distill: @@ -246,6 +247,7 @@ def main(args): device=device, precision=args.precision, output_dict=True, + cache_dir=args.cache_dir, ) if args.use_bnb_linear is not None: print('=> using a layer from bitsandbytes.\n' @@ -357,7 +359,7 @@ def main(args): logging.info(f"=> loaded checkpoint '{args.resume}' (epoch {start_epoch})") # initialize datasets - tokenizer = get_tokenizer(args.model) + tokenizer = get_tokenizer(args.model, cache_dir=args.cache_dir) data = get_data( args, (preprocess_train, preprocess_val), diff --git a/src/open_clip_train/params.py b/src/open_clip_train/params.py index b36ae7bec..2d94b7e21 100644 --- a/src/open_clip_train/params.py +++ b/src/open_clip_train/params.py @@ -101,6 +101,12 @@ def parse_args(args): default=None, help="Path to imagenet v2 for conducting zero shot evaluation.", ) + parser.add_argument( + "--cache-dir", + type=str, + default=None, + help="Override system default cache path for model & tokenizer file downloads.", + ) parser.add_argument( "--logs", type=str,