diff --git a/llmfoundry/utils/__init__.py b/llmfoundry/utils/__init__.py index 7abe4dcf75..83af2153a2 100644 --- a/llmfoundry/utils/__init__.py +++ b/llmfoundry/utils/__init__.py @@ -12,7 +12,7 @@ log_config, pop_config, update_batch_size_info) from llmfoundry.utils.model_download_utils import ( - download_from_cache_server, download_from_hf_hub) + download_from_hf_hub, download_from_http_fileserver) except ImportError as e: raise ImportError( 'Please make sure to pip install . to get requirements for llm-foundry.' @@ -28,7 +28,7 @@ 'build_tokenizer', 'calculate_batch_size_info', 'convert_and_save_ft_weights', - 'download_from_cache_server', + 'download_from_http_fileserver', 'download_from_hf_hub', 'get_hf_tokenizer_from_composer_state_dict', 'update_batch_size_info', diff --git a/llmfoundry/utils/model_download_utils.py b/llmfoundry/utils/model_download_utils.py index 2104455e0f..5d8a413d91 100644 --- a/llmfoundry/utils/model_download_utils.py +++ b/llmfoundry/utils/model_download_utils.py @@ -5,6 +5,8 @@ import copy import logging import os +import shutil +import subprocess import time import warnings from http import HTTPStatus @@ -14,6 +16,7 @@ import huggingface_hub as hf_hub import requests import tenacity +import yaml from bs4 import BeautifulSoup from requests.packages.urllib3.exceptions import InsecureRequestWarning from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME @@ -28,6 +31,9 @@ PYTORCH_WEIGHTS_PATTERN = 'pytorch_model*.bin*' SAFE_WEIGHTS_PATTERN = 'model*.safetensors*' +ORAS_PASSWD_PLACEHOLDER = '' +ORAS_CLI = 'oras' + log = logging.getLogger(__name__) @@ -36,8 +42,8 @@ stop=tenacity.stop_after_attempt(3), wait=tenacity.wait_exponential(min=1, max=10)) def download_from_hf_hub( - repo_id: str, - save_dir: Optional[str] = None, + model: str, + save_dir: str, prefer_safetensors: bool = True, token: Optional[str] = None, ): @@ -48,8 +54,7 @@ def download_from_hf_hub( Args: repo_id (str): The Hugging Face Hub repo ID. - save_dir (str, optional): The path to the directory where the model files will be downloaded. If `None`, reads - from the `HUGGINGFACE_HUB_CACHE` environment variable or uses the default Hugging Face Hub cache directory. + save_dir (str, optional): The local path to the directory where the model files will be downloaded. prefer_safetensors (bool): Whether to prefer Safetensors weights over PyTorch weights if both are available. Defaults to True. token (str, optional): The HuggingFace API token. If not provided, the token will be read from the @@ -59,7 +64,7 @@ def download_from_hf_hub( RepositoryNotFoundError: If the model repo doesn't exist or the token is unauthorized. ValueError: If the model repo doesn't contain any supported model weights. """ - repo_files = set(hf_hub.list_repo_files(repo_id)) + repo_files = set(hf_hub.list_repo_files(model)) # Ignore TensorFlow, TensorFlow 2, and Flax weights as they are not supported by Composer. ignore_patterns = copy.deepcopy(DEFAULT_IGNORE_PATTERNS) @@ -86,18 +91,18 @@ def download_from_hf_hub( log.info('Only pytorch available. Ignoring weights preference.') else: raise ValueError( - f'No supported model weights found in repo {repo_id}.' + + f'No supported model weights found in repo {model}.' + ' Please make sure the repo contains either safetensors or pytorch weights.' ) download_start = time.time() - hf_hub.snapshot_download(repo_id, - cache_dir=save_dir, + hf_hub.snapshot_download(model, + local_dir=save_dir, ignore_patterns=ignore_patterns, token=token) download_duration = time.time() - download_start log.info( - f'Downloaded model {repo_id} from Hugging Face Hub in {download_duration} seconds' + f'Downloaded model {model} from Hugging Face Hub in {download_duration} seconds' ) @@ -140,6 +145,7 @@ def _recursive_download( RuntimeError: If the remote server returns a status code other than 200 OK or 401 Unauthorized. """ url = urljoin(base_url, path) + print(url) response = session.get(url, verify=(not ignore_cert)) if response.status_code == HTTPStatus.UNAUTHORIZED: @@ -156,7 +162,7 @@ def _recursive_download( ) # Assume that the URL points to a file if it does not end with a slash. - if not path.endswith('/'): + if not url.endswith('/'): save_path = os.path.join(save_dir, path) parent_dir = os.path.dirname(save_path) if not os.path.exists(parent_dir): @@ -171,6 +177,7 @@ def _recursive_download( # If the URL is a directory, the response should be an HTML directory listing that we can parse for additional links # to download. child_links = _extract_links_from_html(response.content.decode()) + print(child_links) for child_link in child_links: _recursive_download(session, base_url, @@ -183,53 +190,98 @@ def _recursive_download( (PermissionError, ValueError)), stop=tenacity.stop_after_attempt(3), wait=tenacity.wait_exponential(min=1, max=10)) -def download_from_cache_server( - model_name: str, - cache_base_url: str, +def download_from_http_fileserver( + url: str, save_dir: str, - token: Optional[str] = None, ignore_cert: bool = False, ): - """Downloads Hugging Face models from a mirror file server. - - The file server is expected to store the files in the same structure as the Hugging Face cache - structure. See https://huggingface.co/docs/huggingface_hub/guides/manage-cache. + """Downloads files from a remote HTTP file server. Args: - model_name: The name of the model to download. This should be the same as the repository ID in the Hugging Face - Hub. - cache_base_url: The base URL of the cache file server. This function will attempt to download all of the blob - files from `//blobs/`, where `formatted_model_name` is equal to - `models/` with all slashes replaced with `--`. - save_dir: The directory to save the downloaded files to. - token: The Hugging Face API token. If not provided, the token will be read from the `HUGGING_FACE_HUB_TOKEN` - environment variable. - ignore_cert: Whether or not to ignore the validity of the SSL certificate of the remote server. Defaults to - False. + url (str): The base URL where the files are located. + save_dir (str): The directory to save downloaded files to. + ignore_cert (bool): Whether or not to ignore the validity of the SSL certificate of the remote server. + Defaults to False. WARNING: Setting this to true is *not* secure, as no certificate verification will be performed. """ - formatted_model_name = f'models/{model_name}'.replace('/', '--') with requests.Session() as session: - session.headers.update({'Authorization': f'Bearer {token}'}) - - download_start = time.time() - # Temporarily suppress noisy SSL certificate verification warnings if ignore_cert is set to True with warnings.catch_warnings(): if ignore_cert: warnings.simplefilter('ignore', category=InsecureRequestWarning) - # Only downloads the blobs in order to avoid downloading model files twice due to the - # symlnks in the Hugging Face cache structure: - _recursive_download( - session, - cache_base_url, - # Trailing slash to indicate directory - f'{formatted_model_name}/blobs/', - save_dir, - ignore_cert=ignore_cert, - ) - download_duration = time.time() - download_start - log.info( - f'Downloaded model {model_name} from cache server in {download_duration} seconds' + _recursive_download(session, + url, + '', + save_dir, + ignore_cert=ignore_cert) + + +def download_from_oras(model: str, + config_file: str, + credentials_dir: str, + save_dir: str, + concurrency: int = 10): + """Download from an OCI-compliant registry using oras. + + Args: + model: The name of the model to download. + config_file: Path to a YAML config file that maps model names to registry paths. + credentials_dir: Path to a directory containing credentials for the registry. It is expected to contain three + files: `username`, `password`, and `registry`, each of which contains the corresponding credential. + save_dir: Path to the directory where files will be downloaded. + concurrency: The number of concurrent downloads to run. + """ + if shutil.which(ORAS_CLI) is None: + raise Exception( + f'oras cli command `{ORAS_CLI}` is not found. Please install oras: https://oras.land/docs/installation ' ) + + def _read_secrets_file(secret_file_path: str,): + try: + with open(secret_file_path, encoding='utf-8') as f: + return f.read().strip() + except Exception as error: + raise ValueError( + f'secrets file {secret_file_path} failed to be read') from error + + secrets = {} + for secret in ['username', 'password', 'registry']: + secrets[secret] = _read_secrets_file( + os.path.join(credentials_dir, secret)) + + with open(config_file, 'r', encoding='utf-8') as f: + configs = yaml.safe_load(f.read()) + + path = configs['models'][model] + registry = secrets['registry'] + + def get_oras_cmd(username: Optional[str] = None, + password: Optional[str] = None): + cmd = [ + ORAS_CLI, + 'pull', + f'{registry}/{path}', + '-o', + save_dir, + '--verbose', + '--concurrency', + str(concurrency), + ] + if username is not None: + cmd.extend(['--username', username]) + if password is not None: + cmd.extend(['--password', password]) + + return cmd + + cmd_without_creds = get_oras_cmd() + log.info(f'CMD for oras cli to run: {" ".join(cmd_without_creds)}') + cmd_to_run = get_oras_cmd(username=secrets['username'], + password=secrets['password']) + try: + subprocess.run(cmd_to_run, check=True) + except subprocess.CalledProcessError as e: + # Intercept the error and replace the cmd, which may have sensitive info. + raise subprocess.CalledProcessError(e.returncode, cmd_without_creds, + e.output, e.stderr) diff --git a/scripts/misc/download_hf_model.py b/scripts/misc/download_hf_model.py deleted file mode 100644 index 58c3445e7d..0000000000 --- a/scripts/misc/download_hf_model.py +++ /dev/null @@ -1,83 +0,0 @@ -# Copyright 2022 MosaicML LLM Foundry authors -# SPDX-License-Identifier: Apache-2.0 - -"""Script to download model weights from Hugging Face Hub or a cache server.""" -import argparse -import logging -import os -import sys - -from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE - -from llmfoundry.utils.model_download_utils import (download_from_cache_server, - download_from_hf_hub) - -HF_TOKEN_ENV_VAR = 'HUGGING_FACE_HUB_TOKEN' - -logging.basicConfig(format=f'%(asctime)s: %(levelname)s: %(name)s: %(message)s', - level=logging.INFO) -log = logging.getLogger(__name__) - -if __name__ == '__main__': - argparser = argparse.ArgumentParser() - argparser.add_argument('--model', type=str, required=True) - argparser.add_argument('--download-from', - type=str, - choices=['hf', 'cache'], - default='hf') - argparser.add_argument('--token', - type=str, - default=os.getenv(HF_TOKEN_ENV_VAR)) - argparser.add_argument('--save-dir', - type=str, - default=HUGGINGFACE_HUB_CACHE) - argparser.add_argument('--cache-url', type=str, default=None) - argparser.add_argument('--ignore-cert', action='store_true', default=False) - argparser.add_argument( - '--fallback', - action='store_true', - default=True, - help= - 'Whether to fallback to downloading from Hugging Face if download from cache fails', - ) - - args = argparser.parse_args(sys.argv[1:]) - if args.download_from == 'hf': - download_from_hf_hub(args.model, - save_dir=args.save_dir, - token=args.token) - else: - try: - download_from_cache_server( - args.model, - args.cache_url, - args.save_dir, - token=args.token, - ignore_cert=args.ignore_cert, - ) - - # A little hacky: run the Hugging Face download just to repair the symlinks in the HF cache file structure. - # This shouldn't actually download any files if the cache server download was successful, but should address - # a non-deterministic bug where the symlinks aren't repaired properly by the time the model is initialized. - log.info('Repairing Hugging Face cache symlinks') - - # Hide some noisy logs that aren't important for just the symlink repair. - old_level = logging.getLogger().level - logging.getLogger().setLevel(logging.ERROR) - download_from_hf_hub(args.model, - save_dir=args.save_dir, - token=args.token) - logging.getLogger().setLevel(old_level) - - except PermissionError: - log.error(f'Not authorized to download {args.model}.') - except Exception as e: - if args.fallback: - log.warning( - f'Failed to download {args.model} from cache server. Falling back to Hugging Face Hub. Error: {e}' - ) - download_from_hf_hub(args.model, - save_dir=args.save_dir, - token=args.token) - else: - raise e diff --git a/scripts/misc/download_model.py b/scripts/misc/download_model.py new file mode 100644 index 0000000000..1913267e20 --- /dev/null +++ b/scripts/misc/download_model.py @@ -0,0 +1,115 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +"""Script to download model weights from Hugging Face Hub or a cache server. + +Download from Hugging Face Hub: + python download_model.py hf --model mosaicml/mpt-7b --save-dir --token + +Download from ORAS registry: + python download_model.py oras --registry --path mosaicml/mpt-7b --save-dir + +Download from an HTTP file server: + python download_model.py http --host https://server.com --path mosaicml/mpt-7b --save-dir + +Download from an HTTP file server with fallback to Hugging Face Hub: + python download_model.py http --host https://server.com --path mosaicml/mpt-7b --save-dir \ + fallback-hf --model mosaicml/mpt-7b --token hf_token +""" +import argparse +import logging +import os + +from llmfoundry.utils.model_download_utils import ( + download_from_hf_hub, download_from_http_fileserver, download_from_oras) + +HF_TOKEN_ENV_VAR = 'HUGGING_FACE_HUB_TOKEN' + +logging.basicConfig(format=f'%(asctime)s: %(levelname)s: %(name)s: %(message)s', + level=logging.INFO) +log = logging.getLogger(__name__) + + +def add_hf_parser_arguments(parser: argparse.ArgumentParser) -> None: + parser.add_argument('--model', type=str, required=True) + parser.add_argument('--prefer-safetensors', type=bool, default=True) + parser.add_argument('--token', + type=str, + default=os.getenv(HF_TOKEN_ENV_VAR)) + + +def add_oras_parser_arguments(parser: argparse.ArgumentParser) -> None: + parser.add_argument('--model', type=str, required=True) + parser.add_argument('--config-file', type=str, required=True) + parser.add_argument('--credentials-dir', type=str, required=True) + parser.add_argument('--concurrency', type=int, default=10) + + +def add_http_parser_arguments(parser: argparse.ArgumentParser) -> None: + parser.add_argument('--url', type=str, required=True) + parser.add_argument('--ignore-cert', action='store_true', default=False) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser() + subparsers = parser.add_subparsers(dest='download_from', required=True) + + base_parser = argparse.ArgumentParser(add_help=False) + base_parser.add_argument('--save-dir', type=str, required=True) + + # Add subparser for downloading from Hugging Face Hub. + hf_parser = subparsers.add_parser('hf', parents=[base_parser]) + add_hf_parser_arguments(hf_parser) + + # Add subparser for downloading from ORAS registry. + oras_parser = subparsers.add_parser('oras', parents=[base_parser]) + add_oras_parser_arguments(oras_parser) + + # Add subparser for downloading from an HTTP file server. + http_parser = subparsers.add_parser('http', parents=[base_parser]) + add_http_parser_arguments(http_parser) + + # Add fallbacks for HTTP + fallback_subparsers = http_parser.add_subparsers(dest='fallback') + hf_fallback_parser = fallback_subparsers.add_parser('fallback-hf') + add_hf_parser_arguments(hf_fallback_parser) + + oras_fallback_parser = fallback_subparsers.add_parser('fallback-oras') + add_oras_parser_arguments(oras_fallback_parser) + + return parser.parse_args() + + +if __name__ == '__main__': + args = parse_args() + download_from = args.download_from + + if download_from == 'http': + try: + download_from_http_fileserver(args.url, args.save_dir, + args.ignore_cert) + except PermissionError as e: + log.error(f'Not authorized to download {args.model}.') + raise e + except Exception as e: + log.warning(f'Failed to download from HTTP server with error: {e}') + if args.fallback: + log.warning(f'Falling back to provided fallback destination.') + if args.fallback == 'fallback-hf': + download_from = 'hf' + elif args.fallback == 'fallback-oras': + download_from = 'oras' + else: + raise ValueError( + f'Invalid fallback destination {args.fallback}.') + else: + raise e + + if download_from == 'hf': + download_from_hf_hub(args.model, + save_dir=args.save_dir, + token=args.token, + prefer_safetensors=args.prefer_safetensors) + elif download_from == 'oras': + download_from_oras(args.model, args.config_file, args.credentials_dir, + args.save_dir, args.concurrency) diff --git a/tests/utils/test_model_download_utils.py b/tests/utils/test_model_download_utils.py index 27b9805cda..471a39dcdb 100644 --- a/tests/utils/test_model_download_utils.py +++ b/tests/utils/test_model_download_utils.py @@ -16,11 +16,9 @@ from transformers.utils import WEIGHTS_INDEX_NAME as PYTORCH_WEIGHTS_INDEX_NAME from transformers.utils import WEIGHTS_NAME as PYTORCH_WEIGHTS_NAME -from llmfoundry.utils.model_download_utils import (DEFAULT_IGNORE_PATTERNS, - PYTORCH_WEIGHTS_PATTERN, - SAFE_WEIGHTS_PATTERN, - download_from_cache_server, - download_from_hf_hub) +from llmfoundry.utils.model_download_utils import ( + DEFAULT_IGNORE_PATTERNS, PYTORCH_WEIGHTS_PATTERN, SAFE_WEIGHTS_PATTERN, + download_from_hf_hub, download_from_http_fileserver) # ======================== download_from_hf_hub tests ======================== @@ -103,15 +101,17 @@ def test_download_from_hf_hub_weights_pref(mock_list_repo_files: MagicMock, repo_files: List[str], expected_ignore_patterns: List[str]): test_repo_id = 'test_repo_id' + save_dir = 'save_dir' mock_list_repo_files.return_value = repo_files - download_from_hf_hub(test_repo_id, prefer_safetensors=prefer_safetensors) + download_from_hf_hub(test_repo_id, + save_dir=save_dir, + prefer_safetensors=prefer_safetensors) mock_snapshot_download.assert_called_once_with( test_repo_id, - cache_dir=None, + local_dir=save_dir, ignore_patterns=expected_ignore_patterns, - token=None, - ) + token=None) @mock.patch('huggingface_hub.snapshot_download') @@ -121,10 +121,11 @@ def test_download_from_hf_hub_no_weights( mock_snapshot_download: MagicMock, ): test_repo_id = 'test_repo_id' + save_dir = 'save_dir' mock_list_repo_files.return_value = [] with pytest.raises(ValueError): - download_from_hf_hub(test_repo_id) + download_from_hf_hub(test_repo_id, save_dir) mock_snapshot_download.assert_not_called() @@ -148,12 +149,12 @@ def test_download_from_hf_hub_retry( mock_snapshot_download.side_effect = exception with pytest.raises((tenacity.RetryError, exception.__class__)): - download_from_hf_hub('test_repo_id') + download_from_hf_hub('test_repo_id', 'save_dir') assert mock_snapshot_download.call_count == expected_attempts -# ======================== download_from_cache_server tests ======================== +# ======================== download_from_http_fileserver tests ======================== ROOT_HTML = b""" @@ -182,51 +183,47 @@ def test_download_from_hf_hub_retry( @mock.patch.object(requests.Session, 'get') @mock.patch('os.makedirs') @mock.patch('builtins.open') -def test_download_from_cache_server(mock_open: MagicMock, - mock_makedirs: MagicMock, - mock_get: MagicMock): - cache_url = 'https://cache.com/' - model_name = 'model' - formatted_model_name = 'models--model' +def test_download_from_http_fileserver(mock_open: MagicMock, + mock_makedirs: MagicMock, + mock_get: MagicMock): + model_url = f'https://cache.com/models/model/' save_dir = 'save_dir/' mock_open.return_value = MagicMock() def _server_response(url: str, **kwargs: Dict[str, Any]): - if url == urljoin(cache_url, f'{formatted_model_name}/blobs/'): + if url == model_url: return MagicMock(status_code=HTTPStatus.OK, content=ROOT_HTML) - if url == urljoin(cache_url, f'{formatted_model_name}/blobs/file1'): + if url == urljoin(model_url, 'file1'): return MagicMock(status_code=HTTPStatus.OK) - elif url == urljoin(cache_url, f'{formatted_model_name}/blobs/folder/'): + elif url == urljoin(model_url, 'folder/'): return MagicMock(status_code=HTTPStatus.OK, content=SUBFOLDER_HTML) - elif url == urljoin(cache_url, - f'{formatted_model_name}/blobs/folder/file2'): + elif url == urljoin(model_url, 'folder/file2'): return MagicMock(status_code=HTTPStatus.OK) else: return MagicMock(status_code=HTTPStatus.NOT_FOUND) mock_get.side_effect = _server_response - download_from_cache_server(model_name, cache_url, 'save_dir/') + download_from_http_fileserver(model_url, save_dir) - mock_open.assert_has_calls([ - mock.call(os.path.join(save_dir, formatted_model_name, 'blobs/file1'), - 'wb'), - mock.call( - os.path.join(save_dir, formatted_model_name, 'blobs/folder/file2'), - 'wb'), - ], - any_order=True) + mock_open.assert_has_calls( + [ + mock.call(os.path.join(save_dir, 'file1'), 'wb'), + mock.call(os.path.join(save_dir, 'folder/file2'), 'wb'), + ], + any_order=True, + ) @mock.patch.object(requests.Session, 'get') -def test_download_from_cache_server_unauthorized(mock_get: MagicMock): - cache_url = 'https://cache.com/' +def test_download_from_http_fileserver_unauthorized(mock_get: MagicMock): model_name = 'model' + cache_url = f'https://cache.com/models--{model_name}/blobs/' save_dir = 'save_dir/' mock_get.return_value = MagicMock(status_code=HTTPStatus.UNAUTHORIZED) with pytest.raises(PermissionError): - download_from_cache_server(model_name, cache_url, save_dir) + download_from_http_fileserver(cache_url, save_dir) @pytest.mark.parametrize(['exception', 'expected_attempts'], [ @@ -236,7 +233,7 @@ def test_download_from_cache_server_unauthorized(mock_get: MagicMock): ]) @mock.patch('tenacity.nap.time.sleep') @mock.patch('llmfoundry.utils.model_download_utils._recursive_download') -def test_download_from_cache_server_retry( +def test_download_from_http_fileserver_retry( mock_recursive_download: MagicMock, mock_sleep: MagicMock, # so the retry wait doesn't actually wait exception: BaseException, @@ -245,4 +242,4 @@ def test_download_from_cache_server_retry( mock_recursive_download.side_effect = exception with pytest.raises((tenacity.RetryError, exception.__class__)): - download_from_cache_server('model', 'cache_url', 'save_dir') + download_from_http_fileserver('cache_url', 'save_dir')