diff --git a/llmfoundry/utils/model_download_utils.py b/llmfoundry/utils/model_download_utils.py index 15c6d268c2..17f0914ac4 100644 --- a/llmfoundry/utils/model_download_utils.py +++ b/llmfoundry/utils/model_download_utils.py @@ -16,6 +16,7 @@ import huggingface_hub as hf_hub import requests import tenacity +import yaml from bs4 import BeautifulSoup from requests.packages.urllib3.exceptions import InsecureRequestWarning from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME @@ -144,6 +145,7 @@ def _recursive_download( RuntimeError: If the remote server returns a status code other than 200 OK or 401 Unauthorized. """ url = urljoin(base_url, path) + print(url) response = session.get(url, verify=(not ignore_cert)) if response.status_code == HTTPStatus.UNAUTHORIZED: @@ -160,7 +162,7 @@ def _recursive_download( ) # Assume that the URL points to a file if it does not end with a slash. - if not path.endswith('/'): + if not url.endswith('/'): save_path = os.path.join(save_dir, path) parent_dir = os.path.dirname(save_path) if not os.path.exists(parent_dir): @@ -175,6 +177,7 @@ def _recursive_download( # If the URL is a directory, the response should be an HTML directory listing that we can parse for additional links # to download. child_links = _extract_links_from_html(response.content.decode()) + print(child_links) for child_link in child_links: _recursive_download(session, base_url, @@ -214,20 +217,19 @@ def download_from_http_fileserver( ignore_cert=ignore_cert) -def download_from_oras(registry: str, - path: str, +def download_from_oras(model: str, + config_file: str, + credentials_dir: str, save_dir: str, - username: str, - password: str, concurrency: int = 10): """Download from an OCI-compliant registry using oras. Args: - registry: The registry to download from. - path: The path to the model in the registry. - save_dir: The directory to save the downloaded files to. - username: The username to authenticate with. - password: The password to authenticate with. + model: The name of the model to download. + config_file: Path to a YAML config file that maps model names to registry paths. + credentials_dir: Path to a directory containing credentials for the registry. It is expected to contain three + files: `username`, `password`, and `registry`, each of which contains the corresponding credential. + save_dir: Path to the directory where files will be downloaded. concurrency: The number of concurrent downloads to run. """ if shutil.which(ORAS_CLI) is None: @@ -235,6 +237,25 @@ def download_from_oras(registry: str, f'oras cli command `{ORAS_CLI}` is not found. Please install oras: https://oras.land/docs/installation ' ) + def _read_secrets_file(secret_file_path: str,): + try: + with open(secret_file_path, encoding='utf-8') as f: + return f.read().strip() + except Exception as error: + raise ValueError( + f'secrets file {secret_file_path} failed to be read') from error + + secrets = {} + for secret in ['username', 'password', 'registry']: + secrets[secret] = _read_secrets_file( + os.path.join(credentials_dir, secret)) + + with open(config_file, 'r', encoding='utf-8') as f: + configs = yaml.safe_load(f.read()) + + path = configs[model] + registry = secrets['registry'] + def get_oras_cmd(username: Optional[str] = None, password: Optional[str] = None): cmd = [ @@ -256,7 +277,8 @@ def get_oras_cmd(username: Optional[str] = None, cmd_without_creds = get_oras_cmd() log.info(f'CMD for oras cli to run: {" ".join(cmd_without_creds)}') - cmd_to_run = get_oras_cmd(username=username, password=password) + cmd_to_run = get_oras_cmd(username=secrets['username'], + password=secrets['password']) try: subprocess.run(cmd_to_run, check=True) except subprocess.CalledProcessError as e: diff --git a/scripts/misc/download_model.py b/scripts/misc/download_model.py index 441f8eab04..1913267e20 100644 --- a/scripts/misc/download_model.py +++ b/scripts/misc/download_model.py @@ -39,10 +39,9 @@ def add_hf_parser_arguments(parser: argparse.ArgumentParser) -> None: def add_oras_parser_arguments(parser: argparse.ArgumentParser) -> None: - parser.add_argument('--registry', type=str, required=True) - parser.add_argument('--path', type=str, required=True) - parser.add_argument('--username', type=str, default='') - parser.add_argument('--password', type=str, default='') + parser.add_argument('--model', type=str, required=True) + parser.add_argument('--config-file', type=str, required=True) + parser.add_argument('--credentials-dir', type=str, required=True) parser.add_argument('--concurrency', type=int, default=10) @@ -87,7 +86,7 @@ def parse_args() -> argparse.Namespace: if download_from == 'http': try: - download_from_http_fileserver(args.host, args.path, args.save_dir, + download_from_http_fileserver(args.url, args.save_dir, args.ignore_cert) except PermissionError as e: log.error(f'Not authorized to download {args.model}.') @@ -112,5 +111,5 @@ def parse_args() -> argparse.Namespace: token=args.token, prefer_safetensors=args.prefer_safetensors) elif download_from == 'oras': - download_from_oras(args.registry, args.path, args.save_dir, - args.username, args.password, args.concurrency) + download_from_oras(args.model, args.config_file, args.credentials_dir, + args.save_dir, args.concurrency) diff --git a/tests/utils/test_model_download_utils.py b/tests/utils/test_model_download_utils.py index c15a1c0b6a..471a39dcdb 100644 --- a/tests/utils/test_model_download_utils.py +++ b/tests/utils/test_model_download_utils.py @@ -186,55 +186,44 @@ def test_download_from_hf_hub_retry( def test_download_from_http_fileserver(mock_open: MagicMock, mock_makedirs: MagicMock, mock_get: MagicMock): - cache_url = 'https://cache.com/' - formatted_model_name = 'models--model' + model_url = f'https://cache.com/models/model/' save_dir = 'save_dir/' mock_open.return_value = MagicMock() def _server_response(url: str, **kwargs: Dict[str, Any]): - if url == urljoin(cache_url, f'{formatted_model_name}/blobs/'): + if url == model_url: return MagicMock(status_code=HTTPStatus.OK, content=ROOT_HTML) - if url == urljoin(cache_url, f'{formatted_model_name}/blobs/file1'): + if url == urljoin(model_url, 'file1'): return MagicMock(status_code=HTTPStatus.OK) - elif url == urljoin(cache_url, f'{formatted_model_name}/blobs/folder/'): + elif url == urljoin(model_url, 'folder/'): return MagicMock(status_code=HTTPStatus.OK, content=SUBFOLDER_HTML) - elif url == urljoin(cache_url, - f'{formatted_model_name}/blobs/folder/file2'): + elif url == urljoin(model_url, 'folder/file2'): return MagicMock(status_code=HTTPStatus.OK) else: return MagicMock(status_code=HTTPStatus.NOT_FOUND) mock_get.side_effect = _server_response - download_from_http_fileserver( - cache_url, - formatted_model_name + '/blobs/', - 'save_dir/', - ) + download_from_http_fileserver(model_url, save_dir) mock_open.assert_has_calls( [ - mock.call( - os.path.join(save_dir, formatted_model_name, 'blobs/file1'), - 'wb'), - mock.call( - os.path.join(save_dir, formatted_model_name, - 'blobs/folder/file2'), 'wb'), + mock.call(os.path.join(save_dir, 'file1'), 'wb'), + mock.call(os.path.join(save_dir, 'folder/file2'), 'wb'), ], any_order=True, ) @mock.patch.object(requests.Session, 'get') -def test_download_from_cache_server_unauthorized(mock_get: MagicMock): - cache_url = 'https://cache.com/' +def test_download_from_http_fileserver_unauthorized(mock_get: MagicMock): model_name = 'model' + cache_url = f'https://cache.com/models--{model_name}/blobs/' save_dir = 'save_dir/' mock_get.return_value = MagicMock(status_code=HTTPStatus.UNAUTHORIZED) with pytest.raises(PermissionError): - download_from_http_fileserver(cache_url, f'models--{model_name}/blobs/', - save_dir) + download_from_http_fileserver(cache_url, save_dir) @pytest.mark.parametrize(['exception', 'expected_attempts'], [ @@ -244,7 +233,7 @@ def test_download_from_cache_server_unauthorized(mock_get: MagicMock): ]) @mock.patch('tenacity.nap.time.sleep') @mock.patch('llmfoundry.utils.model_download_utils._recursive_download') -def test_download_from_cache_server_retry( +def test_download_from_http_fileserver_retry( mock_recursive_download: MagicMock, mock_sleep: MagicMock, # so the retry wait doesn't actually wait exception: BaseException, @@ -253,4 +242,4 @@ def test_download_from_cache_server_retry( mock_recursive_download.side_effect = exception with pytest.raises((tenacity.RetryError, exception.__class__)): - download_from_http_fileserver('cache_url', 'models--model/', 'save_dir') + download_from_http_fileserver('cache_url', 'save_dir')