-
Notifications
You must be signed in to change notification settings - Fork 536
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add model downloading from oras #857
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -5,8 +5,10 @@ | |||||
import argparse | ||||||
import logging | ||||||
import os | ||||||
import sys | ||||||
import shutil | ||||||
import subprocess | ||||||
|
||||||
import yaml | ||||||
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE | ||||||
|
||||||
from llmfoundry.utils.model_download_utils import (download_from_cache_server, | ||||||
|
@@ -17,35 +19,112 @@ | |||||
logging.basicConfig(format=f'%(asctime)s: %(levelname)s: %(name)s: %(message)s', | ||||||
level=logging.INFO) | ||||||
log = logging.getLogger(__name__) | ||||||
ORAS_PASSWD_PLACEHOLDER = '<placeholder_for_passwd>' | ||||||
ORAS_CLI = 'oras' | ||||||
|
||||||
if __name__ == '__main__': | ||||||
argparser = argparse.ArgumentParser() | ||||||
argparser.add_argument('--model', type=str, required=True) | ||||||
argparser.add_argument('--download-from', | ||||||
type=str, | ||||||
choices=['hf', 'cache'], | ||||||
default='hf') | ||||||
argparser.add_argument('--token', | ||||||
type=str, | ||||||
default=os.getenv(HF_TOKEN_ENV_VAR)) | ||||||
argparser.add_argument('--save-dir', | ||||||
type=str, | ||||||
default=HUGGINGFACE_HUB_CACHE) | ||||||
argparser.add_argument('--cache-url', type=str, default=None) | ||||||
argparser.add_argument('--ignore-cert', action='store_true', default=False) | ||||||
argparser.add_argument( | ||||||
|
||||||
def download_from_oras(model: str, save_dir: str, credentials_dirpath: str, | ||||||
config_file: str, concurrency: int): | ||||||
"""Download from an OCI-compliant registry using oras.""" | ||||||
if shutil.which(ORAS_CLI) is None: | ||||||
raise Exception( | ||||||
f'oras cli command `{ORAS_CLI}` is not found. Please install oras: https://oras.land/docs/installation ' | ||||||
) | ||||||
|
||||||
def validate_and_add_from_secret_file( | ||||||
secrets: dict[str, str], | ||||||
secret_name: str, | ||||||
secret_file_path: str, | ||||||
): | ||||||
try: | ||||||
with open(secret_file_path, encoding='utf-8') as f: | ||||||
secrets[secret_name] = f.read() | ||||||
except Exception as error: | ||||||
raise ValueError( | ||||||
f'secret_file {secret_file_path} is failed to be read; but got error' | ||||||
) from error | ||||||
|
||||||
secrets = {} | ||||||
validate_and_add_from_secret_file( | ||||||
secrets, 'username', os.path.join(credentials_dirpath, 'username')) | ||||||
validate_and_add_from_secret_file( | ||||||
secrets, 'password', os.path.join(credentials_dirpath, 'password')) | ||||||
validate_and_add_from_secret_file( | ||||||
secrets, 'registry', os.path.join(credentials_dirpath, 'registry')) | ||||||
|
||||||
with open(config_file, 'r', encoding='utf-8') as f: | ||||||
configs = yaml.safe_load(f.read()) | ||||||
|
||||||
path = configs[model] | ||||||
hostname = secrets['registry'] | ||||||
|
||||||
def get_oras_cmd_to_run(password: str): | ||||||
return [ | ||||||
ORAS_CLI, 'pull', '-o', save_dir, '--verbose', '--concurrency', | ||||||
str(concurrency), '-u', secrets['username'], '-p', password, | ||||||
f'{hostname}/{path}' | ||||||
] | ||||||
|
||||||
cmd_to_run_no_password = get_oras_cmd_to_run(ORAS_PASSWD_PLACEHOLDER) | ||||||
log.info(f'CMD for oras cli to run: {cmd_to_run_no_password}') | ||||||
cmd_to_run = get_oras_cmd_to_run(secrets['password']) | ||||||
try: | ||||||
subprocess.run(cmd_to_run, check=True) | ||||||
except subprocess.CalledProcessError as e: | ||||||
# Intercept the error and replace the cmd, which may have sensitive info. | ||||||
raise subprocess.CalledProcessError(e.returncode, cmd_to_run_no_password, e.output, e.stderr) | ||||||
|
||||||
|
||||||
def parse_args() -> argparse.Namespace: | ||||||
parser = argparse.ArgumentParser() | ||||||
|
||||||
# Add shared args | ||||||
parser.add_argument('--model', type=str, required=True) | ||||||
parser.add_argument('--download-from', | ||||||
type=str, | ||||||
choices=['hf', 'cache', 'oras'], | ||||||
default='hf') | ||||||
parser.add_argument('--save-dir', type=str, default=HUGGINGFACE_HUB_CACHE) | ||||||
|
||||||
# Add HF args | ||||||
parser.add_argument('--token', | ||||||
type=str, | ||||||
default=os.getenv(HF_TOKEN_ENV_VAR)) | ||||||
|
||||||
# Add cache args | ||||||
parser.add_argument('--cache-url', type=str, default=None) | ||||||
parser.add_argument('--ignore-cert', action='store_true', default=False) | ||||||
parser.add_argument( | ||||||
'--fallback', | ||||||
action='store_true', | ||||||
default=True, | ||||||
help= | ||||||
'Whether to fallback to downloading from Hugging Face if download from cache fails', | ||||||
type=str, | ||||||
choices=['hf', 'oras', None], | ||||||
default=None, | ||||||
help='Fallback target to download from if download from cache fails', | ||||||
) | ||||||
|
||||||
args = argparser.parse_args(sys.argv[1:]) | ||||||
# Add oras args | ||||||
parser.add_argument('--credentials-dirpath', type=str, default=None) | ||||||
parser.add_argument('--oras-config-file', type=str, default=None) | ||||||
parser.add_argument('--concurrency', type=int, default=10) | ||||||
|
||||||
return parser.parse_args() | ||||||
|
||||||
|
||||||
if __name__ == '__main__': | ||||||
args = parse_args() | ||||||
|
||||||
if args.download_from != 'cache' and args.fallback is not None: | ||||||
raise ValueError( | ||||||
f'Downloading from {args.download_from}, but fallback cannot be specified unless downloading from cache.' | ||||||
) | ||||||
|
||||||
if args.download_from == 'hf': | ||||||
download_from_hf_hub(args.model, | ||||||
save_dir=args.save_dir, | ||||||
token=args.token) | ||||||
elif args.download_from == 'oras': | ||||||
download_from_oras(args.model, args.save_dir, args.credentials_dirpath, | ||||||
args.oras_config_file, args.concurrency) | ||||||
else: | ||||||
try: | ||||||
download_from_cache_server( | ||||||
|
@@ -56,26 +135,33 @@ | |||||
ignore_cert=args.ignore_cert, | ||||||
) | ||||||
|
||||||
# A little hacky: run the Hugging Face download just to repair the symlinks in the HF cache file structure. | ||||||
# This shouldn't actually download any files if the cache server download was successful, but should address | ||||||
# a non-deterministic bug where the symlinks aren't repaired properly by the time the model is initialized. | ||||||
log.info('Repairing Hugging Face cache symlinks') | ||||||
if args.fallback == 'hf': | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @jerrychen109 is it okay to condition on this? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should we condition on the save_dir being the HF cache? symlinks are just for the HF cache, right? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
# A little hacky: run the Hugging Face download just to repair the symlinks in the HF cache file structure. | ||||||
# This shouldn't actually download any files if the cache server download was successful, but should address | ||||||
# a non-deterministic bug where the symlinks aren't repaired properly by the time the model is initialized. | ||||||
log.info('Repairing Hugging Face cache symlinks') | ||||||
|
||||||
# Hide some noisy logs that aren't important for just the symlink repair. | ||||||
old_level = logging.getLogger().level | ||||||
logging.getLogger().setLevel(logging.ERROR) | ||||||
download_from_hf_hub(args.model, | ||||||
save_dir=args.save_dir, | ||||||
token=args.token) | ||||||
logging.getLogger().setLevel(old_level) | ||||||
# Hide some noisy logs that aren't important for just the symlink repair. | ||||||
old_level = logging.getLogger().level | ||||||
logging.getLogger().setLevel(logging.ERROR) | ||||||
download_from_hf_hub(args.model, | ||||||
save_dir=args.save_dir, | ||||||
token=args.token) | ||||||
logging.getLogger().setLevel(old_level) | ||||||
|
||||||
except PermissionError: | ||||||
log.error(f'Not authorized to download {args.model}.') | ||||||
except Exception as e: | ||||||
if args.fallback: | ||||||
log.warning( | ||||||
f'Failed to download {args.model} from cache server. Falling back to Hugging Face Hub. Error: {e}' | ||||||
) | ||||||
log.warning( | ||||||
f'Failed to download {args.model} from cache server. Falling back to {args.fallback}. Error: {e}' | ||||||
) | ||||||
if args.fallback == 'oras': | ||||||
# save_dir, token, model, url, username, password, and config file | ||||||
download_from_oras(args.model, args.save_dir, | ||||||
args.credentials_dirpath, | ||||||
args.oras_config_file, args.concurrency) | ||||||
elif args.fallback == 'hf': | ||||||
# save_dir, token, model | ||||||
download_from_hf_hub(args.model, | ||||||
save_dir=args.save_dir, | ||||||
token=args.token) | ||||||
|
Check failure
Code scanning / CodeQL
Clear-text logging of sensitive information High