Improve cache_dir behaviour
* make use of cache_dir for HF tokenizer wrapper explicit
* add missing use of cache_dir for an _get_hf_config call
* add cache_dir as argument to train/val script
rwightman committed Oct 24, 2024
1 parent 1b01224 commit d11e54a
Showing 5 changed files with 51 additions and 11 deletions.
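As a usage note (not part of the commit itself): with these changes the download cache location can be controlled from the Python API. A minimal sketch, where the hub id and cache path are illustrative:

    import open_clip

    cache_dir = '/data/open_clip_cache'  # hypothetical cache location

    # get_tokenizer() now forwards cache_dir both to the open_clip_config.json
    # download (_get_hf_config) and to the HFTokenizer wrapper it constructs.
    tokenizer = open_clip.get_tokenizer(
        'hf-hub:laion/CLIP-ViT-B-32-laion2B-s34B-b79K',
        cache_dir=cache_dir,
    )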
40 changes: 34 additions & 6 deletions src/open_clip/factory.py
@@ -67,14 +67,25 @@ def add_model_config(path):


def get_model_config(model_name):
""" Fetch model config from builtin (local library) configs.
"""
if model_name in _MODEL_CONFIGS:
return deepcopy(_MODEL_CONFIGS[model_name])
else:
return None


def _get_hf_config(model_id, cache_dir=None):
config_path = download_pretrained_from_hf(model_id, filename='open_clip_config.json', cache_dir=cache_dir)
def _get_hf_config(
model_id: str,
cache_dir: Optional[str] = None,
):
""" Fetch model config from HuggingFace Hub.
"""
config_path = download_pretrained_from_hf(
model_id,
filename='open_clip_config.json',
cache_dir=cache_dir,
)
with open(config_path, 'r', encoding='utf-8') as f:
config = json.load(f)
return config
@@ -83,16 +94,18 @@ def _get_hf_config(model_id, cache_dir=None):
def get_tokenizer(
model_name: str = '',
context_length: Optional[int] = None,
cache_dir: Optional[str] = None,
**kwargs,
):
if model_name.startswith(HF_HUB_PREFIX):
model_name = model_name[len(HF_HUB_PREFIX):]
try:
config = _get_hf_config(model_name)['model_cfg']
config = _get_hf_config(model_name, cache_dir=cache_dir)['model_cfg']
except Exception:
tokenizer = HFTokenizer(
model_name,
context_length=context_length or DEFAULT_CONTEXT_LENGTH,
cache_dir=cache_dir,
**kwargs,
)
return tokenizer
@@ -113,6 +126,7 @@ def get_tokenizer(
tokenizer = HFTokenizer(
text_config['hf_tokenizer_name'],
context_length=context_length,
cache_dir=cache_dir,
**tokenizer_kwargs,
)
else:
@@ -265,7 +279,7 @@ def create_model(
if has_hf_hub_prefix:
model_id = model_name[len(HF_HUB_PREFIX):]
checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
config = _get_hf_config(model_id, cache_dir)
config = _get_hf_config(model_id, cache_dir=cache_dir)
preprocess_cfg = merge_preprocess_dict(preprocess_cfg, config['preprocess_cfg'])
model_cfg = config['model_cfg']
pretrained_hf = False # override, no need to load original HF text weights
@@ -456,10 +470,16 @@ def create_model_and_transforms(
pretrained_hf: bool = True,
cache_dir: Optional[str] = None,
output_dict: Optional[bool] = None,
load_weights_only: bool = True,
**model_kwargs,
):
force_preprocess_cfg = merge_preprocess_kwargs(
{}, mean=image_mean, std=image_std, interpolation=image_interpolation, resize_mode=image_resize_mode)
{},
mean=image_mean,
std=image_std,
interpolation=image_interpolation,
resize_mode=image_resize_mode,
)

model = create_model(
model_name,
@@ -476,6 +496,7 @@ def create_model_and_transforms(
pretrained_hf=pretrained_hf,
cache_dir=cache_dir,
output_dict=output_dict,
load_weights_only=load_weights_only,
**model_kwargs,
)

@@ -509,10 +530,16 @@ def create_model_from_pretrained(
image_resize_mode: Optional[str] = None, # only effective for inference
return_transform: bool = True,
cache_dir: Optional[str] = None,
load_weights_only: bool = True,
**model_kwargs,
):
force_preprocess_cfg = merge_preprocess_kwargs(
{}, mean=image_mean, std=image_std, interpolation=image_interpolation, resize_mode=image_resize_mode)
{},
mean=image_mean,
std=image_std,
interpolation=image_interpolation,
resize_mode=image_resize_mode,
)

model = create_model(
model_name,
@@ -526,6 +553,7 @@ def create_model_from_pretrained(
force_preprocess_cfg=force_preprocess_cfg,
cache_dir=cache_dir,
require_pretrained=True,
load_weights_only=load_weights_only,
**model_kwargs,
)

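The model-building entry points accept the same argument, forwarded to download_pretrained_from_hf and _get_hf_config as shown above. A minimal sketch, with an illustrative model name and pretrained tag:

    import open_clip

    model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(
        'ViT-B-32',                      # illustrative built-in config name
        pretrained='laion2b_s34b_b79k',  # illustrative pretrained tag
        cache_dir='/data/open_clip_cache',
        load_weights_only=True,          # default, forwarded to create_model as in the diff
    )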
6 changes: 3 additions & 3 deletions src/open_clip/pretrained.py
@@ -651,7 +651,7 @@ def get_pretrained_url(model: str, tag: str):

def download_pretrained_from_url(
url: str,
cache_dir: Union[str, None] = None,
cache_dir: Optional[str] = None,
):
if not cache_dir:
cache_dir = os.path.expanduser("~/.cache/clip")
@@ -712,7 +712,7 @@ def _get_safe_alternatives(filename: str) -> Iterable[str]:
if filename == HF_WEIGHTS_NAME:
yield HF_SAFE_WEIGHTS_NAME

if filename not in (HF_WEIGHTS_NAME,) and filename.endswith(".bin") or filename.endswith(".pth"):
if filename not in (HF_WEIGHTS_NAME,) and (filename.endswith(".bin") or filename.endswith(".pth")):
yield filename[:-4] + ".safetensors"


@@ -750,7 +750,7 @@ def download_pretrained_from_hf(
)
return cached_file # Return the path to the downloaded file if successful
except Exception as e:
raise FileNotFoundError(f"Failed to download any files for {model_id}. Last error: {e}")
raise FileNotFoundError(f"Failed to download file ({filename}) for {model_id}. Last error: {e}")


def download_pretrained(
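A note on the _get_safe_alternatives change above: in Python, "and" binds more tightly than "or", so the old condition grouped as (filename not in (HF_WEIGHTS_NAME,) and filename.endswith(".bin")) or filename.endswith(".pth"). The added parentheses express the intended check, a .bin or .pth file other than the canonical weights name. A generic illustration of the precedence (not code from the repo):

    a, b, c = False, False, True
    print(a and b or c)    # True  -> parsed as (a and b) or c
    print(a and (b or c))  # False -> the guard in a now covers both suffixes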
6 changes: 5 additions & 1 deletion src/open_clip/tokenizer.py
@@ -410,10 +410,11 @@ def __init__(
clean: str = 'whitespace',
strip_sep_token: bool = False,
language: Optional[str] = None,
cache_dir: Optional[str] = None,
**kwargs
):
from transformers import AutoTokenizer
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, **kwargs)
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, cache_dir=cache_dir, **kwargs)
set_lang_fn = getattr(self.tokenizer, 'set_src_lang_special_tokens', None)
if callable(set_lang_fn):
self.set_lang_fn = set_lang_fn
@@ -462,6 +463,9 @@ def set_language(self, src_lang):

class SigLipTokenizer:
"""HuggingFace tokenizer wrapper for SigLIP T5 compatible sentencepiece vocabs
NOTE: this is not needed in normal library use, but is used to import new sentencepiece tokenizers
into OpenCLIP. Leaving code here in case future models use new tokenizers.
"""
VOCAB_FILES = {
# english, vocab_size=32_000
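With the explicit parameter, the wrapper passes cache_dir straight through to transformers' AutoTokenizer.from_pretrained. A minimal sketch, with an illustrative tokenizer name and a hypothetical path:

    from open_clip.tokenizer import HFTokenizer

    tokenizer = HFTokenizer(
        'bert-base-uncased',         # illustrative HF tokenizer name
        context_length=77,
        cache_dir='/data/hf_cache',  # hypothetical cache location
    )
    tokens = tokenizer(['a photo of a cat'])  # token id tensor of shape [1, 77]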
4 changes: 3 additions & 1 deletion src/open_clip_train/main.py
@@ -236,6 +236,7 @@ def main(args):
aug_cfg=args.aug_cfg,
pretrained_image=args.pretrained_image,
output_dict=True,
cache_dir=args.cache_dir,
**model_kwargs,
)
if args.distill:
@@ -246,6 +247,7 @@ def main(args):
device=device,
precision=args.precision,
output_dict=True,
cache_dir=args.cache_dir,
)
if args.use_bnb_linear is not None:
print('=> using a layer from bitsandbytes.\n'
@@ -357,7 +359,7 @@ def main(args):
logging.info(f"=> loaded checkpoint '{args.resume}' (epoch {start_epoch})")

# initialize datasets
tokenizer = get_tokenizer(args.model)
tokenizer = get_tokenizer(args.model, cache_dir=args.cache_dir)
data = get_data(
args,
(preprocess_train, preprocess_val),
6 changes: 6 additions & 0 deletions src/open_clip_train/params.py
@@ -101,6 +101,12 @@ def parse_args(args):
default=None,
help="Path to imagenet v2 for conducting zero shot evaluation.",
)
parser.add_argument(
"--cache-dir",
type=str,
default=None,
help="Override system default cache path for model & tokenizer file downloads.",
)
parser.add_argument(
"--logs",
type=str,
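The new flag threads through parse_args into main(), which, as shown above, passes args.cache_dir on to create_model_and_transforms and get_tokenizer. A minimal parsing sketch with a hypothetical path:

    from open_clip_train.params import parse_args

    args = parse_args(['--cache-dir', '/data/open_clip_cache'])
    assert args.cache_dir == '/data/open_clip_cache'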
