Skip to content

Commit

Permalink
Add flag to only download tokenizers from HF or oras
Browse files Browse the repository at this point in the history
  • Loading branch information
irenedea committed Jan 23, 2024
1 parent b2a0c03 commit 936d5e5
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 6 deletions.
26 changes: 20 additions & 6 deletions llmfoundry/utils/model_download_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@
]
PYTORCH_WEIGHTS_PATTERN = 'pytorch_model*.bin*'
SAFE_WEIGHTS_PATTERN = 'model*.safetensors*'
# File names that together make up a tokenizer. Passed as `allow_patterns`
# to the Hub snapshot download when only the tokenizer is requested.
TOKENIZER_FILES = [
    'special_tokens_map.json',
    'tokenizer.json',
    'tokenizer.model',
    'tokenizer_config.json',
]

ORAS_PASSWD_PLACEHOLDER = '<placeholder_for_passwd>'
ORAS_CLI = 'oras'
Expand All @@ -45,6 +51,7 @@ def download_from_hf_hub(
model: str,
save_dir: str,
prefer_safetensors: bool = True,
tokenizers_only: bool = False,
token: Optional[str] = None,
):
"""Downloads model files from a Hugging Face Hub model repo.
Expand All @@ -57,6 +64,7 @@ def download_from_hf_hub(
save_dir (str): The local path to the directory where the model files will be downloaded.
prefer_safetensors (bool): Whether to prefer Safetensors weights over PyTorch weights if both are
available. Defaults to True.
tokenizers_only (bool): If true, only download tokenizer files. Defaults to False.
token (str, optional): The HuggingFace API token. If not provided, the token will be read from the
`HUGGING_FACE_HUB_TOKEN` environment variable.
Expand Down Expand Up @@ -95,10 +103,13 @@ def download_from_hf_hub(
' Please make sure the repo contains either safetensors or pytorch weights.'
)

allow_patterns = TOKENIZER_FILES if tokenizers_only else None

download_start = time.time()
hf_hub.snapshot_download(model,
local_dir=save_dir,
ignore_patterns=ignore_patterns,
allow_patterns=allow_patterns,
token=token)
download_duration = time.time() - download_start
log.info(
Expand Down Expand Up @@ -221,16 +232,18 @@ def download_from_oras(model: str,
config_file: str,
credentials_dir: str,
save_dir: str,
tokenizer_only: bool,
concurrency: int = 10):
"""Download from an OCI-compliant registry using oras.
Args:
model: The name of the model to download.
config_file: Path to a YAML config file that maps model names to registry paths.
credentials_dir: Path to a directory containing credentials for the registry. It is expected to contain three
model (str): The name of the model to download.
config_file (str): Path to a YAML config file that maps model and tokenizer names to registry paths.
credentials_dir (str): Path to a directory containing credentials for the registry. It is expected to contain three
files: `username`, `password`, and `registry`, each of which contains the corresponding credential.
save_dir: Path to the directory where files will be downloaded.
concurrency: The number of concurrent downloads to run.
save_dir (str): Path to the directory where files will be downloaded.
tokenizer_only (bool): If true, only download the tokenizer files.
concurrency (int): The number of concurrent downloads to run.
"""
if shutil.which(ORAS_CLI) is None:
raise Exception(
Expand All @@ -253,7 +266,8 @@ def _read_secrets_file(secret_file_path: str,):
with open(config_file, 'r', encoding='utf-8') as f:
configs = yaml.safe_load(f.read())

path = configs['models'][model]
config_type = 'tokenizers' if tokenizer_only else 'models'
path = configs[config_type][model]
registry = secrets['registry']

def get_oras_cmd(username: Optional[str] = None,
Expand Down
8 changes: 8 additions & 0 deletions scripts/misc/download_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ def parse_args() -> argparse.Namespace:

base_parser = argparse.ArgumentParser(add_help=False)
base_parser.add_argument('--save-dir', type=str, required=True)
base_parser.add_argument('--tokenizer-only',
type=bool,
required=False,
default=False,
action='store_true')

# Add subparser for downloading from Hugging Face Hub.
hf_parser = subparsers.add_parser('hf', parents=[base_parser])
Expand Down Expand Up @@ -85,6 +90,9 @@ def parse_args() -> argparse.Namespace:
download_from = args.download_from

if download_from == 'http':
if args.tokenizer_only == True:
raise ValueError(
'tokenizer-only is not currently supported for http.')
try:
download_from_http_fileserver(args.url, args.save_dir,
args.ignore_cert)
Expand Down

0 comments on commit 936d5e5

Please sign in to comment.