diff --git a/llmfoundry/utils/model_download_utils.py b/llmfoundry/utils/model_download_utils.py index 5d8a413d91..b51856d5fc 100644 --- a/llmfoundry/utils/model_download_utils.py +++ b/llmfoundry/utils/model_download_utils.py @@ -30,6 +30,12 @@ ] PYTORCH_WEIGHTS_PATTERN = 'pytorch_model*.bin*' SAFE_WEIGHTS_PATTERN = 'model*.safetensors*' +TOKENIZER_FILES = [ + 'special_tokens_map.json', + 'tokenizer.json', + 'tokenizer.model', + 'tokenizer_config.json', +] ORAS_PASSWD_PLACEHOLDER = '' ORAS_CLI = 'oras' @@ -45,6 +51,7 @@ def download_from_hf_hub( model: str, save_dir: str, prefer_safetensors: bool = True, + tokenizers_only: bool = False, token: Optional[str] = None, ): """Downloads model files from a Hugging Face Hub model repo. @@ -57,6 +64,7 @@ def download_from_hf_hub( save_dir (str, optional): The local path to the directory where the model files will be downloaded. prefer_safetensors (bool): Whether to prefer Safetensors weights over PyTorch weights if both are available. Defaults to True. + tokenizers_only (bool): If true, only download tokenzier files. token (str, optional): The HuggingFace API token. If not provided, the token will be read from the `HUGGING_FACE_HUB_TOKEN` environment variable. @@ -95,10 +103,13 @@ def download_from_hf_hub( ' Please make sure the repo contains either safetensors or pytorch weights.' ) + allow_patterns = TOKENIZER_FILES if tokenizers_only else None + download_start = time.time() hf_hub.snapshot_download(model, local_dir=save_dir, ignore_patterns=ignore_patterns, + allow_patterns=allow_patterns, token=token) download_duration = time.time() - download_start log.info( @@ -221,16 +232,18 @@ def download_from_oras(model: str, config_file: str, credentials_dir: str, save_dir: str, + tokenizer_only: bool, concurrency: int = 10): """Download from an OCI-compliant registry using oras. Args: - model: The name of the model to download. - config_file: Path to a YAML config file that maps model names to registry paths. - credentials_dir: Path to a directory containing credentials for the registry. It is expected to contain three + model (str): The name of the model to download. + config_file (str): Path to a YAML config file that maps model and tokenizer names to registry paths. + credentials_dir (str): Path to a directory containing credentials for the registry. It is expected to contain three files: `username`, `password`, and `registry`, each of which contains the corresponding credential. - save_dir: Path to the directory where files will be downloaded. - concurrency: The number of concurrent downloads to run. + save_dir (str): Path to the directory where files will be downloaded. + tokenizer_only (bool): If true, only download the tokenzier files. + concurrency (int): The number of concurrent downloads to run. """ if shutil.which(ORAS_CLI) is None: raise Exception( @@ -253,7 +266,8 @@ def _read_secrets_file(secret_file_path: str,): with open(config_file, 'r', encoding='utf-8') as f: configs = yaml.safe_load(f.read()) - path = configs['models'][model] + config_type = 'tokenizers' if tokenizer_only else 'models' + path = configs[config_type][model] registry = secrets['registry'] def get_oras_cmd(username: Optional[str] = None, diff --git a/scripts/misc/download_model.py b/scripts/misc/download_model.py index 1913267e20..90faff64e8 100644 --- a/scripts/misc/download_model.py +++ b/scripts/misc/download_model.py @@ -56,6 +56,11 @@ def parse_args() -> argparse.Namespace: base_parser = argparse.ArgumentParser(add_help=False) base_parser.add_argument('--save-dir', type=str, required=True) + base_parser.add_argument('--tokenizer-only', + type=bool, + required=False, + default=False, + action='store_true') # Add subparser for downloading from Hugging Face Hub. hf_parser = subparsers.add_parser('hf', parents=[base_parser]) @@ -85,6 +90,9 @@ def parse_args() -> argparse.Namespace: download_from = args.download_from if download_from == 'http': + if args.tokenizer_only == True: + raise ValueError( + 'tokenizer-only is not currently supported for http.') try: download_from_http_fileserver(args.url, args.save_dir, args.ignore_cert)