diff --git a/llmfoundry/utils/model_download_utils.py b/llmfoundry/utils/model_download_utils.py index d95935d875..e076bd5b8f 100644 --- a/llmfoundry/utils/model_download_utils.py +++ b/llmfoundry/utils/model_download_utils.py @@ -64,7 +64,7 @@ def download_from_hf_hub( save_dir (str, optional): The local path to the directory where the model files will be downloaded. prefer_safetensors (bool): Whether to prefer Safetensors weights over PyTorch weights if both are available. Defaults to True. - tokenizers_only (bool): If true, only download tokenzier files. + tokenizer_only (bool): If true, only download tokenizer files. token (str, optional): The HuggingFace API token. If not provided, the token will be read from the `HUGGING_FACE_HUB_TOKEN` environment variable. @@ -103,7 +103,7 @@ def download_from_hf_hub( ' Please make sure the repo contains either safetensors or pytorch weights.' ) - allow_patterns = TOKENIZER_FILES if tokenizers_only else None + allow_patterns = TOKENIZER_FILES if tokenizer_only else None download_start = time.time() hf_hub.snapshot_download(model, @@ -232,7 +232,7 @@ def download_from_oras(model: str, config_file: str, credentials_dir: str, save_dir: str, - tokenizer_only: bool, + tokenizer_only: bool = False, concurrency: int = 10): """Download from an OCI-compliant registry using oras. diff --git a/scripts/misc/download_model.py b/scripts/misc/download_model.py index 40dfc4775d..cdb1da7506 100644 --- a/scripts/misc/download_model.py +++ b/scripts/misc/download_model.py @@ -11,7 +11,7 @@ --credentials-dir --save-dir Download from an HTTP file server: - python download_model.py http --url https://server.com/path --save-dir + python download_model.py http --url https://server.com/models/mosaicml/mpt-7b/ --save-dir Download from an HTTP file server with fallback to Hugging Face Hub: python download_model.py http --host https://server.com --path mosaicml/mpt-7b --save-dir \ @@ -116,8 +116,8 @@ def parse_args() -> argparse.Namespace: download_from_hf_hub(args.model, save_dir=args.save_dir, token=args.token, - tokenizers_only=args.tokenizer_only, + tokenizer_only=args.tokenizer_only, prefer_safetensors=args.prefer_safetensors) elif download_from == 'oras': download_from_oras(args.model, args.config_file, args.credentials_dir, - args.save_dir, args.tokenizer_only, args.concurrency) + args.save_dir, tokenizers_only=args.tokenizer_only, args.concurrency)