From 907fd772174a6ab781d259eddf8d179395243234 Mon Sep 17 00:00:00 2001 From: Irene Dea Date: Wed, 10 Jan 2024 18:12:36 -0800 Subject: [PATCH 1/4] wip --- scripts/misc/download_hf_model.py | 146 +++++++++++++++++++++++------- 1 file changed, 112 insertions(+), 34 deletions(-) diff --git a/scripts/misc/download_hf_model.py b/scripts/misc/download_hf_model.py index 58c3445e7d..dcd66eb7ff 100644 --- a/scripts/misc/download_hf_model.py +++ b/scripts/misc/download_hf_model.py @@ -5,10 +5,14 @@ import argparse import logging import os -import sys + +import yaml +import subprocess from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE +import shutil + from llmfoundry.utils.model_download_utils import (download_from_cache_server, download_from_hf_hub) @@ -17,35 +21,104 @@ logging.basicConfig(format=f'%(asctime)s: %(levelname)s: %(name)s: %(message)s', level=logging.INFO) log = logging.getLogger(__name__) +ORAS_PASSWD_PLACEHOLDER = '' +ORAS_CLI = 'oras' -if __name__ == '__main__': - argparser = argparse.ArgumentParser() - argparser.add_argument('--model', type=str, required=True) - argparser.add_argument('--download-from', +def download_from_oras(model: str, save_dir: str, url: str, credentials_dirpath: str, model_to_path_map_file: str, concurrency: int): + """ Download from an OCI-compliant registry using oras.""" + if shutil.which(ORAS_CLI) is None: + raise Exception(f'oras cli command `${ORAS_CLI}` is not found. Please install oras: https://oras.land/docs/installation ') + def validate_and_add_from_secret_file(secrets: str, secret_name: str, secret_file_path: str): + try: + with open(secret_file_path, encoding="utf-8") as f: + secrets[secret_name] = f.read() + except Exception as error: + raise ValueError(f'secret_file {secret_file_path} is failed to be read; but got error') from error + + secrets = {} + validate_and_add_from_secret_file(secrets, 'username', + os.path.join(credentials_dirpath, + 'username')) + validate_and_add_from_secret_file(secrets, 'password', + os.path.join(credentials_dirpath, + 'password')) + + with open(model_to_path_map_file, 'r', encoding='utf-8') as f: + model_to_path = yaml.safe_load(f.read()) + + path = model_to_path[model] + + def get_oras_cmd_to_run(password: str): + return [ORAS_CLI, + 'pull', + '-o', + save_dir, + '--verbose', + '--concurrency', + concurrency, + '-u', + secrets['username'], + '-p', + password, + f'{url}/{path}' + ] + + cmd_to_run = get_oras_cmd_to_run(ORAS_PASSWD_PLACEHOLDER) + log.info(f'CMD for oras cli to run: {cmd_to_run}') + cmd_to_run = get_oras_cmd_to_run(secrets['password']) + subprocess.run(cmd_to_run, check=True) + +def parse_args()-> argparse.Namespace: + parser = argparse.ArgumentParser() + + # Add shared args + parser.add_argument('--model', type=str, required=True) + parser.add_argument('--download-from', type=str, - choices=['hf', 'cache'], + choices=['hf', 'cache', 'oras'], default='hf') - argparser.add_argument('--token', - type=str, - default=os.getenv(HF_TOKEN_ENV_VAR)) - argparser.add_argument('--save-dir', + parser.add_argument('--save-dir', type=str, default=HUGGINGFACE_HUB_CACHE) - argparser.add_argument('--cache-url', type=str, default=None) - argparser.add_argument('--ignore-cert', action='store_true', default=False) - argparser.add_argument( + + # Add HF args + parser.add_argument('--token', + type=str, + default=os.getenv(HF_TOKEN_ENV_VAR)) + + # Add cache args + parser.add_argument('--cache-url', type=str, default=None) + parser.add_argument('--ignore-cert', action='store_true', default=False) + parser.add_argument( '--fallback', - action='store_true', - default=True, + type=str, + choices=['hf', 'oras', None], + default=None, help= - 'Whether to fallback to downloading from Hugging Face if download from cache fails', + 'Fallback target to download from if download from cache fails', ) - args = argparser.parse_args(sys.argv[1:]) + # Add oras args + parser.add_argument('--oras-url', type=str, default=None) + parser.add_argument('--credentials_dirpath', type=str, default=None) + parser.add_argument('--model-to-path-map-file', type=str, default=None) + parser.add_argument('--concurrency', type=int, default=10) + + return parser.parse_args() + + +if __name__ == '__main__': + args = parse_args() + + if args.download_from != 'cache' and args.fallback is not None: + raise ValueError(f'Downloading from {args.download_from}, but fallback cannot be specified unless downloading from cache.') + if args.download_from == 'hf': download_from_hf_hub(args.model, save_dir=args.save_dir, token=args.token) + elif args.download_from == 'oras': + download_from_oras(args.model, args.save_dir, args.oras_url, args.credentials_dirpath, args.model_to_path_map_file, args.concurrency) else: try: download_from_cache_server( @@ -56,28 +129,33 @@ ignore_cert=args.ignore_cert, ) - # A little hacky: run the Hugging Face download just to repair the symlinks in the HF cache file structure. - # This shouldn't actually download any files if the cache server download was successful, but should address - # a non-deterministic bug where the symlinks aren't repaired properly by the time the model is initialized. - log.info('Repairing Hugging Face cache symlinks') + if args.fallback == 'hf': + # A little hacky: run the Hugging Face download just to repair the symlinks in the HF cache file structure. + # This shouldn't actually download any files if the cache server download was successful, but should address + # a non-deterministic bug where the symlinks aren't repaired properly by the time the model is initialized. + log.info('Repairing Hugging Face cache symlinks') - # Hide some noisy logs that aren't important for just the symlink repair. - old_level = logging.getLogger().level - logging.getLogger().setLevel(logging.ERROR) - download_from_hf_hub(args.model, - save_dir=args.save_dir, - token=args.token) - logging.getLogger().setLevel(old_level) + # Hide some noisy logs that aren't important for just the symlink repair. + old_level = logging.getLogger().level + logging.getLogger().setLevel(logging.ERROR) + download_from_hf_hub(args.model, + save_dir=args.save_dir, + token=args.token) + logging.getLogger().setLevel(old_level) except PermissionError: log.error(f'Not authorized to download {args.model}.') except Exception as e: - if args.fallback: - log.warning( - f'Failed to download {args.model} from cache server. Falling back to Hugging Face Hub. Error: {e}' - ) + log.warning( + f'Failed to download {args.model} from cache server. Falling back to {args.fallback}. Error: {e}' + ) + if args.fallback == 'oras': + # save_dir, token, model, url, username, password, and config file + download_from_oras(args.model, args.save_dir, args.oras_url, args.credentials_dirpath, args.model_to_path_map_file, args.concurrency) + elif args.fallback == 'hf': + # save_dir, token, model download_from_hf_hub(args.model, - save_dir=args.save_dir, - token=args.token) + save_dir=args.save_dir, + token=args.token) else: raise e From 9f7a866231f3e12046c702e6ed40eefc58fd8977 Mon Sep 17 00:00:00 2001 From: Irene Dea Date: Wed, 10 Jan 2024 20:29:14 -0800 Subject: [PATCH 2/4] wip --- scripts/misc/download_hf_model.py | 116 +++++++++++++++--------------- 1 file changed, 58 insertions(+), 58 deletions(-) diff --git a/scripts/misc/download_hf_model.py b/scripts/misc/download_hf_model.py index dcd66eb7ff..a6c7c8bb33 100644 --- a/scripts/misc/download_hf_model.py +++ b/scripts/misc/download_hf_model.py @@ -5,14 +5,13 @@ import argparse import logging import os - -import yaml +import shutil import subprocess +from typing import Dict +import yaml from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE -import shutil - from llmfoundry.utils.model_download_utils import (download_from_cache_server, download_from_hf_hub) @@ -24,68 +23,66 @@ ORAS_PASSWD_PLACEHOLDER = '' ORAS_CLI = 'oras' -def download_from_oras(model: str, save_dir: str, url: str, credentials_dirpath: str, model_to_path_map_file: str, concurrency: int): - """ Download from an OCI-compliant registry using oras.""" + +def download_from_oras(model: str, save_dir: str, credentials_dirpath: str, + config_file: str, concurrency: int): + """Download from an OCI-compliant registry using oras.""" if shutil.which(ORAS_CLI) is None: - raise Exception(f'oras cli command `${ORAS_CLI}` is not found. Please install oras: https://oras.land/docs/installation ') - def validate_and_add_from_secret_file(secrets: str, secret_name: str, secret_file_path: str): + raise Exception( + f'oras cli command `{ORAS_CLI}` is not found. Please install oras: https://oras.land/docs/installation ' + ) + + def validate_and_add_from_secret_file(secrets: Dict[str, str], secret_name: str, + secret_file_path: str): try: - with open(secret_file_path, encoding="utf-8") as f: + with open(secret_file_path, encoding='utf-8') as f: secrets[secret_name] = f.read() except Exception as error: - raise ValueError(f'secret_file {secret_file_path} is failed to be read; but got error') from error + raise ValueError( + f'secret_file {secret_file_path} is failed to be read; but got error' + ) from error secrets = {} - validate_and_add_from_secret_file(secrets, 'username', - os.path.join(credentials_dirpath, - 'username')) - validate_and_add_from_secret_file(secrets, 'password', - os.path.join(credentials_dirpath, - 'password')) - - with open(model_to_path_map_file, 'r', encoding='utf-8') as f: - model_to_path = yaml.safe_load(f.read()) - - path = model_to_path[model] - + validate_and_add_from_secret_file( + secrets, 'username', os.path.join(credentials_dirpath, 'username')) + validate_and_add_from_secret_file( + secrets, 'password', os.path.join(credentials_dirpath, 'password')) + + with open(config_file, 'r', encoding='utf-8') as f: + configs = yaml.safe_load(f.read()) + + path = configs['models'][model] + hostname = configs['hostname'] + def get_oras_cmd_to_run(password: str): - return [ORAS_CLI, - 'pull', - '-o', - save_dir, - '--verbose', - '--concurrency', - concurrency, - '-u', - secrets['username'], - '-p', - password, - f'{url}/{path}' - ] + return [ + ORAS_CLI, 'pull', '-o', save_dir, '--verbose', '--concurrency', + str(concurrency), '-u', secrets['username'], '-p', password, + f'{hostname}/{path}' + ] cmd_to_run = get_oras_cmd_to_run(ORAS_PASSWD_PLACEHOLDER) log.info(f'CMD for oras cli to run: {cmd_to_run}') cmd_to_run = get_oras_cmd_to_run(secrets['password']) subprocess.run(cmd_to_run, check=True) -def parse_args()-> argparse.Namespace: + +def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser() # Add shared args parser.add_argument('--model', type=str, required=True) parser.add_argument('--download-from', - type=str, - choices=['hf', 'cache', 'oras'], - default='hf') - parser.add_argument('--save-dir', - type=str, - default=HUGGINGFACE_HUB_CACHE) - + type=str, + choices=['hf', 'cache', 'oras'], + default='hf') + parser.add_argument('--save-dir', type=str, default=HUGGINGFACE_HUB_CACHE) + # Add HF args parser.add_argument('--token', - type=str, - default=os.getenv(HF_TOKEN_ENV_VAR)) - + type=str, + default=os.getenv(HF_TOKEN_ENV_VAR)) + # Add cache args parser.add_argument('--cache-url', type=str, default=None) parser.add_argument('--ignore-cert', action='store_true', default=False) @@ -94,14 +91,12 @@ def parse_args()-> argparse.Namespace: type=str, choices=['hf', 'oras', None], default=None, - help= - 'Fallback target to download from if download from cache fails', + help='Fallback target to download from if download from cache fails', ) # Add oras args - parser.add_argument('--oras-url', type=str, default=None) - parser.add_argument('--credentials_dirpath', type=str, default=None) - parser.add_argument('--model-to-path-map-file', type=str, default=None) + parser.add_argument('--credentials-dirpath', type=str, default=None) + parser.add_argument('--oras-config-file', type=str, default=None) parser.add_argument('--concurrency', type=int, default=10) return parser.parse_args() @@ -111,14 +106,17 @@ def parse_args()-> argparse.Namespace: args = parse_args() if args.download_from != 'cache' and args.fallback is not None: - raise ValueError(f'Downloading from {args.download_from}, but fallback cannot be specified unless downloading from cache.') + raise ValueError( + f'Downloading from {args.download_from}, but fallback cannot be specified unless downloading from cache.' + ) if args.download_from == 'hf': download_from_hf_hub(args.model, save_dir=args.save_dir, token=args.token) elif args.download_from == 'oras': - download_from_oras(args.model, args.save_dir, args.oras_url, args.credentials_dirpath, args.model_to_path_map_file, args.concurrency) + download_from_oras(args.model, args.save_dir, args.credentials_dirpath, + args.oras_config_file, args.concurrency) else: try: download_from_cache_server( @@ -139,8 +137,8 @@ def parse_args()-> argparse.Namespace: old_level = logging.getLogger().level logging.getLogger().setLevel(logging.ERROR) download_from_hf_hub(args.model, - save_dir=args.save_dir, - token=args.token) + save_dir=args.save_dir, + token=args.token) logging.getLogger().setLevel(old_level) except PermissionError: @@ -151,11 +149,13 @@ def parse_args()-> argparse.Namespace: ) if args.fallback == 'oras': # save_dir, token, model, url, username, password, and config file - download_from_oras(args.model, args.save_dir, args.oras_url, args.credentials_dirpath, args.model_to_path_map_file, args.concurrency) + download_from_oras(args.model, args.save_dir, + args.credentials_dirpath, + args.oras_config_file, args.concurrency) elif args.fallback == 'hf': # save_dir, token, model download_from_hf_hub(args.model, - save_dir=args.save_dir, - token=args.token) + save_dir=args.save_dir, + token=args.token) else: raise e From a163283d69525da831bdba9edb7fdc989cb01742 Mon Sep 17 00:00:00 2001 From: Irene Dea Date: Thu, 11 Jan 2024 14:49:17 -0800 Subject: [PATCH 3/4] Accept registry file for hostname --- scripts/misc/download_hf_model.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/scripts/misc/download_hf_model.py b/scripts/misc/download_hf_model.py index a6c7c8bb33..42e20d79da 100644 --- a/scripts/misc/download_hf_model.py +++ b/scripts/misc/download_hf_model.py @@ -7,7 +7,6 @@ import os import shutil import subprocess -from typing import Dict import yaml from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE @@ -32,8 +31,11 @@ def download_from_oras(model: str, save_dir: str, credentials_dirpath: str, f'oras cli command `{ORAS_CLI}` is not found. Please install oras: https://oras.land/docs/installation ' ) - def validate_and_add_from_secret_file(secrets: Dict[str, str], secret_name: str, - secret_file_path: str): + def validate_and_add_from_secret_file( + secrets: dict[str, str], + secret_name: str, + secret_file_path: str, + ): try: with open(secret_file_path, encoding='utf-8') as f: secrets[secret_name] = f.read() @@ -47,12 +49,14 @@ def validate_and_add_from_secret_file(secrets: Dict[str, str], secret_name: str, secrets, 'username', os.path.join(credentials_dirpath, 'username')) validate_and_add_from_secret_file( secrets, 'password', os.path.join(credentials_dirpath, 'password')) + validate_and_add_from_secret_file( + secrets, 'registry', os.path.join(credentials_dirpath, 'registry')) with open(config_file, 'r', encoding='utf-8') as f: configs = yaml.safe_load(f.read()) - path = configs['models'][model] - hostname = configs['hostname'] + path = configs[model] + hostname = secrets['registry'] def get_oras_cmd_to_run(password: str): return [ From 0e46f948c2da720d416cae02bd21294ad725086f Mon Sep 17 00:00:00 2001 From: Irene Dea Date: Thu, 11 Jan 2024 15:00:26 -0800 Subject: [PATCH 4/4] Make sure no sensitive info is surfaced in subprocess error --- scripts/misc/download_hf_model.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/scripts/misc/download_hf_model.py b/scripts/misc/download_hf_model.py index 42e20d79da..f0614fe7d3 100644 --- a/scripts/misc/download_hf_model.py +++ b/scripts/misc/download_hf_model.py @@ -65,10 +65,14 @@ def get_oras_cmd_to_run(password: str): f'{hostname}/{path}' ] - cmd_to_run = get_oras_cmd_to_run(ORAS_PASSWD_PLACEHOLDER) - log.info(f'CMD for oras cli to run: {cmd_to_run}') + cmd_to_run_no_password = get_oras_cmd_to_run(ORAS_PASSWD_PLACEHOLDER) + log.info(f'CMD for oras cli to run: {cmd_to_run_no_password}') cmd_to_run = get_oras_cmd_to_run(secrets['password']) - subprocess.run(cmd_to_run, check=True) + try: + subprocess.run(cmd_to_run, check=True) + except subprocess.CalledProcessError as e: + # Intercept the error and replace the cmd, which may have sensitive info. + raise subprocess.CalledProcessError(e.returncode, cmd_to_run_no_password, e.output, e.stderr) def parse_args() -> argparse.Namespace: