Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add model downloading from oras #857

Closed
wants to merge 5 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 123 additions & 37 deletions scripts/misc/download_hf_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
import argparse
import logging
import os
import sys
import shutil
import subprocess

import yaml
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE

from llmfoundry.utils.model_download_utils import (download_from_cache_server,
Expand All @@ -17,35 +19,112 @@
logging.basicConfig(format=f'%(asctime)s: %(levelname)s: %(name)s: %(message)s',
level=logging.INFO)
log = logging.getLogger(__name__)
ORAS_PASSWD_PLACEHOLDER = '<placeholder_for_passwd>'
ORAS_CLI = 'oras'

if __name__ == '__main__':
argparser = argparse.ArgumentParser()
argparser.add_argument('--model', type=str, required=True)
argparser.add_argument('--download-from',
type=str,
choices=['hf', 'cache'],
default='hf')
argparser.add_argument('--token',
type=str,
default=os.getenv(HF_TOKEN_ENV_VAR))
argparser.add_argument('--save-dir',
type=str,
default=HUGGINGFACE_HUB_CACHE)
argparser.add_argument('--cache-url', type=str, default=None)
argparser.add_argument('--ignore-cert', action='store_true', default=False)
argparser.add_argument(

def download_from_oras(model: str, save_dir: str, credentials_dirpath: str,
config_file: str, concurrency: int):
"""Download from an OCI-compliant registry using oras."""
if shutil.which(ORAS_CLI) is None:
raise Exception(
f'oras cli command `{ORAS_CLI}` is not found. Please install oras: https://oras.land/docs/installation '
)

def validate_and_add_from_secret_file(
secrets: dict[str, str],
secret_name: str,
secret_file_path: str,
):
try:
with open(secret_file_path, encoding='utf-8') as f:
secrets[secret_name] = f.read()
except Exception as error:
raise ValueError(
f'secret_file {secret_file_path} is failed to be read; but got error'
) from error

secrets = {}
validate_and_add_from_secret_file(
secrets, 'username', os.path.join(credentials_dirpath, 'username'))
validate_and_add_from_secret_file(
secrets, 'password', os.path.join(credentials_dirpath, 'password'))
validate_and_add_from_secret_file(
secrets, 'registry', os.path.join(credentials_dirpath, 'registry'))

with open(config_file, 'r', encoding='utf-8') as f:
configs = yaml.safe_load(f.read())

path = configs[model]
hostname = secrets['registry']

def get_oras_cmd_to_run(password: str):
return [
ORAS_CLI, 'pull', '-o', save_dir, '--verbose', '--concurrency',
str(concurrency), '-u', secrets['username'], '-p', password,
f'{hostname}/{path}'
]

cmd_to_run_no_password = get_oras_cmd_to_run(ORAS_PASSWD_PLACEHOLDER)
log.info(f'CMD for oras cli to run: {cmd_to_run_no_password}')

Check failure

Code scanning / CodeQL

Clear-text logging of sensitive information High

This expression logs
sensitive data (password)
as clear text.
This expression logs
sensitive data (password)
as clear text.
This expression logs
sensitive data (password)
as clear text.
cmd_to_run = get_oras_cmd_to_run(secrets['password'])
try:
subprocess.run(cmd_to_run, check=True)
except subprocess.CalledProcessError as e:
# Intercept the error and replace the cmd, which may have sensitive info.
raise subprocess.CalledProcessError(e.returncode, cmd_to_run_no_password, e.output, e.stderr)


def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()

# Add shared args
parser.add_argument('--model', type=str, required=True)
parser.add_argument('--download-from',
type=str,
choices=['hf', 'cache', 'oras'],
default='hf')
parser.add_argument('--save-dir', type=str, default=HUGGINGFACE_HUB_CACHE)

# Add HF args
parser.add_argument('--token',
type=str,
default=os.getenv(HF_TOKEN_ENV_VAR))

# Add cache args
parser.add_argument('--cache-url', type=str, default=None)
parser.add_argument('--ignore-cert', action='store_true', default=False)
parser.add_argument(
'--fallback',
action='store_true',
default=True,
help=
'Whether to fallback to downloading from Hugging Face if download from cache fails',
type=str,
choices=['hf', 'oras', None],
default=None,
help='Fallback target to download from if download from cache fails',
)

args = argparser.parse_args(sys.argv[1:])
# Add oras args
parser.add_argument('--credentials-dirpath', type=str, default=None)
parser.add_argument('--oras-config-file', type=str, default=None)
parser.add_argument('--concurrency', type=int, default=10)

return parser.parse_args()


if __name__ == '__main__':
args = parse_args()

if args.download_from != 'cache' and args.fallback is not None:
raise ValueError(
f'Downloading from {args.download_from}, but fallback cannot be specified unless downloading from cache.'
)

if args.download_from == 'hf':
download_from_hf_hub(args.model,
save_dir=args.save_dir,
token=args.token)
elif args.download_from == 'oras':
download_from_oras(args.model, args.save_dir, args.credentials_dirpath,
args.oras_config_file, args.concurrency)
else:
try:
download_from_cache_server(
Expand All @@ -56,26 +135,33 @@
ignore_cert=args.ignore_cert,
)

# A little hacky: run the Hugging Face download just to repair the symlinks in the HF cache file structure.
# This shouldn't actually download any files if the cache server download was successful, but should address
# a non-deterministic bug where the symlinks aren't repaired properly by the time the model is initialized.
log.info('Repairing Hugging Face cache symlinks')
if args.fallback == 'hf':
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jerrychen109 is it okay to condition on this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we condition on the save_dir being the HF cache? symlinks are just for the HF cache, right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if args.fallback == 'hf':
if args.fallback == 'hf' and args.save_dir == HUGGINGFACE_HUB_CACHE:

# A little hacky: run the Hugging Face download just to repair the symlinks in the HF cache file structure.
# This shouldn't actually download any files if the cache server download was successful, but should address
# a non-deterministic bug where the symlinks aren't repaired properly by the time the model is initialized.
log.info('Repairing Hugging Face cache symlinks')

# Hide some noisy logs that aren't important for just the symlink repair.
old_level = logging.getLogger().level
logging.getLogger().setLevel(logging.ERROR)
download_from_hf_hub(args.model,
save_dir=args.save_dir,
token=args.token)
logging.getLogger().setLevel(old_level)
# Hide some noisy logs that aren't important for just the symlink repair.
old_level = logging.getLogger().level
logging.getLogger().setLevel(logging.ERROR)
download_from_hf_hub(args.model,
save_dir=args.save_dir,
token=args.token)
logging.getLogger().setLevel(old_level)

except PermissionError:
log.error(f'Not authorized to download {args.model}.')
except Exception as e:
if args.fallback:
log.warning(
f'Failed to download {args.model} from cache server. Falling back to Hugging Face Hub. Error: {e}'
)
log.warning(
f'Failed to download {args.model} from cache server. Falling back to {args.fallback}. Error: {e}'
)
if args.fallback == 'oras':
# save_dir, token, model, url, username, password, and config file
download_from_oras(args.model, args.save_dir,
args.credentials_dirpath,
args.oras_config_file, args.concurrency)
elif args.fallback == 'hf':
# save_dir, token, model
download_from_hf_hub(args.model,
save_dir=args.save_dir,
token=args.token)
Expand Down
Loading