Skip to content

Commit

Permalink
Add tokenizer-only flag to only download tokenizers from HF or oras (
Browse files Browse the repository at this point in the history
  • Loading branch information
irenedea authored Jan 23, 2024
1 parent 02c44ad commit 07d6db3
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 10 deletions.
26 changes: 20 additions & 6 deletions llmfoundry/utils/model_download_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@
]
PYTORCH_WEIGHTS_PATTERN = 'pytorch_model*.bin*'
SAFE_WEIGHTS_PATTERN = 'model*.safetensors*'
TOKENIZER_FILES = [
'special_tokens_map.json',
'tokenizer.json',
'tokenizer.model',
'tokenizer_config.json',
]

ORAS_PASSWD_PLACEHOLDER = '<placeholder_for_passwd>'
ORAS_CLI = 'oras'
Expand All @@ -45,6 +51,7 @@ def download_from_hf_hub(
model: str,
save_dir: str,
prefer_safetensors: bool = True,
tokenizer_only: bool = False,
token: Optional[str] = None,
):
"""Downloads model files from a Hugging Face Hub model repo.
Expand All @@ -57,6 +64,7 @@ def download_from_hf_hub(
save_dir (str, optional): The local path to the directory where the model files will be downloaded.
prefer_safetensors (bool): Whether to prefer Safetensors weights over PyTorch weights if both are
available. Defaults to True.
tokenizer_only (bool): If true, only download tokenizer files.
token (str, optional): The HuggingFace API token. If not provided, the token will be read from the
`HUGGING_FACE_HUB_TOKEN` environment variable.
Expand Down Expand Up @@ -95,10 +103,13 @@ def download_from_hf_hub(
' Please make sure the repo contains either safetensors or pytorch weights.'
)

allow_patterns = TOKENIZER_FILES if tokenizer_only else None

download_start = time.time()
hf_hub.snapshot_download(model,
local_dir=save_dir,
ignore_patterns=ignore_patterns,
allow_patterns=allow_patterns,
token=token)
download_duration = time.time() - download_start
log.info(
Expand Down Expand Up @@ -221,16 +232,18 @@ def download_from_oras(model: str,
config_file: str,
credentials_dir: str,
save_dir: str,
tokenizer_only: bool = False,
concurrency: int = 10):
"""Download from an OCI-compliant registry using oras.
Args:
model: The name of the model to download.
config_file: Path to a YAML config file that maps model names to registry paths.
credentials_dir: Path to a directory containing credentials for the registry. It is expected to contain three
model (str): The name of the model to download.
config_file (str): Path to a YAML config file that maps model and tokenizer names to registry paths.
credentials_dir (str): Path to a directory containing credentials for the registry. It is expected to contain three
files: `username`, `password`, and `registry`, each of which contains the corresponding credential.
save_dir: Path to the directory where files will be downloaded.
concurrency: The number of concurrent downloads to run.
save_dir (str): Path to the directory where files will be downloaded.
tokenizer_only (bool): If true, only download the tokenzier files.
concurrency (int): The number of concurrent downloads to run.
"""
if shutil.which(ORAS_CLI) is None:
raise Exception(
Expand All @@ -253,7 +266,8 @@ def _read_secrets_file(secret_file_path: str,):
with open(config_file, 'r', encoding='utf-8') as f:
configs = yaml.safe_load(f.read())

path = configs['models'][model]
config_type = 'tokenizers' if tokenizer_only else 'models'
path = configs[config_type][model]
registry = secrets['registry']

def get_oras_cmd(username: Optional[str] = None,
Expand Down
20 changes: 16 additions & 4 deletions scripts/misc/download_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@
python download_model.py hf --model mosaicml/mpt-7b --save-dir <save_dir> --token <token>
Download from ORAS registry:
python download_model.py oras --registry <registry> --path mosaicml/mpt-7b --save-dir <save_dir>
python download_model.py oras --model mosaicml/mpt-7b --config-file <config_file> \
--credentials-dir <credentials_dir> --save-dir <save_dir>
Download from an HTTP file server:
python download_model.py http --host https://server.com --path mosaicml/mpt-7b --save-dir <save_dir>
python download_model.py http --url https://server.com/models/mosaicml/mpt-7b/ --save-dir <save_dir>
Download from an HTTP file server with fallback to Hugging Face Hub:
python download_model.py http --host https://server.com --path mosaicml/mpt-7b --save-dir <save_dir> \
Expand Down Expand Up @@ -56,6 +57,9 @@ def parse_args() -> argparse.Namespace:

base_parser = argparse.ArgumentParser(add_help=False)
base_parser.add_argument('--save-dir', type=str, required=True)
base_parser.add_argument('--tokenizer-only',
default=False,
action='store_true')

# Add subparser for downloading from Hugging Face Hub.
hf_parser = subparsers.add_parser('hf', parents=[base_parser])
Expand Down Expand Up @@ -85,6 +89,9 @@ def parse_args() -> argparse.Namespace:
download_from = args.download_from

if download_from == 'http':
if args.tokenizer_only:
raise ValueError(
'tokenizer-only is not currently supported for http.')
try:
download_from_http_fileserver(args.url, args.save_dir,
args.ignore_cert)
Expand All @@ -109,7 +116,12 @@ def parse_args() -> argparse.Namespace:
download_from_hf_hub(args.model,
save_dir=args.save_dir,
token=args.token,
tokenizer_only=args.tokenizer_only,
prefer_safetensors=args.prefer_safetensors)
elif download_from == 'oras':
download_from_oras(args.model, args.config_file, args.credentials_dir,
args.save_dir, args.concurrency)
download_from_oras(args.model,
args.config_file,
args.credentials_dir,
args.save_dir,
tokenizer_only=args.tokenizer_only,
concurrency=args.concurrency)
1 change: 1 addition & 0 deletions tests/utils/test_model_download_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def test_download_from_hf_hub_weights_pref(mock_list_repo_files: MagicMock,
mock_snapshot_download.assert_called_once_with(
test_repo_id,
local_dir=save_dir,
allow_patterns=None,
ignore_patterns=expected_ignore_patterns,
token=None)

Expand Down

0 comments on commit 07d6db3

Please sign in to comment.