Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
Co-authored-by: Jerry Chen <[email protected]>
  • Loading branch information
irenedea and jerrychen109 authored Jan 23, 2024
1 parent da27b1f commit 7d25936
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
6 changes: 3 additions & 3 deletions llmfoundry/utils/model_download_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def download_from_hf_hub(
save_dir (str, optional): The local path to the directory where the model files will be downloaded.
prefer_safetensors (bool): Whether to prefer Safetensors weights over PyTorch weights if both are
available. Defaults to True.
tokenizers_only (bool): If true, only download tokenzier files.
tokenizer_only (bool): If true, only download tokenizer files.
token (str, optional): The HuggingFace API token. If not provided, the token will be read from the
`HUGGING_FACE_HUB_TOKEN` environment variable.
Expand Down Expand Up @@ -103,7 +103,7 @@ def download_from_hf_hub(
' Please make sure the repo contains either safetensors or pytorch weights.'
)

allow_patterns = TOKENIZER_FILES if tokenizers_only else None
allow_patterns = TOKENIZER_FILES if tokenizer_only else None

download_start = time.time()
hf_hub.snapshot_download(model,
Expand Down Expand Up @@ -232,7 +232,7 @@ def download_from_oras(model: str,
config_file: str,
credentials_dir: str,
save_dir: str,
tokenizer_only: bool,
tokenizer_only: bool = False,
concurrency: int = 10):
"""Download from an OCI-compliant registry using oras.
Expand Down
6 changes: 3 additions & 3 deletions scripts/misc/download_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
--credentials-dir <credentials_dir> --save-dir <save_dir>
Download from an HTTP file server:
python download_model.py http --url https://server.com/path --save-dir <save_dir>
python download_model.py http --url https://server.com/models/mosaicml/mpt-7b/ --save-dir <save_dir>
Download from an HTTP file server with fallback to Hugging Face Hub:
python download_model.py http --host https://server.com --path mosaicml/mpt-7b --save-dir <save_dir> \
Expand Down Expand Up @@ -116,8 +116,8 @@ def parse_args() -> argparse.Namespace:
download_from_hf_hub(args.model,
save_dir=args.save_dir,
token=args.token,
tokenizers_only=args.tokenizer_only,
tokenizer_only=args.tokenizer_only,
prefer_safetensors=args.prefer_safetensors)
elif download_from == 'oras':
download_from_oras(args.model, args.config_file, args.credentials_dir,
args.save_dir, args.tokenizer_only, args.concurrency)
args.save_dir, tokenizers_only=args.tokenizer_only, args.concurrency)

0 comments on commit 7d25936

Please sign in to comment.