Skip to content

Commit

Permalink
Use files for ORAS
Browse files (browse the repository at this point in the history)
  • Loading branch information
jerrychen109 committed Jan 17, 2024
1 parent: 1b2ad71; commit: 697638f
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 42 deletions.
44 changes: 33 additions & 11 deletions llmfoundry/utils/model_download_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import huggingface_hub as hf_hub
import requests
import tenacity
import yaml
from bs4 import BeautifulSoup
from requests.packages.urllib3.exceptions import InsecureRequestWarning
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME
Expand Down Expand Up @@ -144,6 +145,7 @@ def _recursive_download(
RuntimeError: If the remote server returns a status code other than 200 OK or 401 Unauthorized.
"""
url = urljoin(base_url, path)
print(url)
response = session.get(url, verify=(not ignore_cert))

if response.status_code == HTTPStatus.UNAUTHORIZED:
Expand All @@ -160,7 +162,7 @@ def _recursive_download(
)

# Assume that the URL points to a file if it does not end with a slash.
if not path.endswith('/'):
if not url.endswith('/'):
save_path = os.path.join(save_dir, path)
parent_dir = os.path.dirname(save_path)
if not os.path.exists(parent_dir):
Expand All @@ -175,6 +177,7 @@ def _recursive_download(
# If the URL is a directory, the response should be an HTML directory listing that we can parse for additional links
# to download.
child_links = _extract_links_from_html(response.content.decode())
print(child_links)
for child_link in child_links:
_recursive_download(session,
base_url,
Expand Down Expand Up @@ -214,27 +217,45 @@ def download_from_http_fileserver(
ignore_cert=ignore_cert)


def download_from_oras(registry: str,
path: str,
def download_from_oras(model: str,
config_file: str,
credentials_dir: str,
save_dir: str,
username: str,
password: str,
concurrency: int = 10):
"""Download from an OCI-compliant registry using oras.
Args:
registry: The registry to download from.
path: The path to the model in the registry.
save_dir: The directory to save the downloaded files to.
username: The username to authenticate with.
password: The password to authenticate with.
model: The name of the model to download.
config_file: Path to a YAML config file that maps model names to registry paths.
credentials_dir: Path to a directory containing credentials for the registry. It is expected to contain three
files: `username`, `password`, and `registry`, each of which contains the corresponding credential.
save_dir: Path to the directory where files will be downloaded.
concurrency: The number of concurrent downloads to run.
"""
if shutil.which(ORAS_CLI) is None:
raise Exception(
f'oras cli command `{ORAS_CLI}` is not found. Please install oras: https://oras.land/docs/installation '
)

def _read_secrets_file(secret_file_path: str,):
try:
with open(secret_file_path, encoding='utf-8') as f:
return f.read().strip()
except Exception as error:
raise ValueError(
f'secrets file {secret_file_path} failed to be read') from error

secrets = {}
for secret in ['username', 'password', 'registry']:
secrets[secret] = _read_secrets_file(
os.path.join(credentials_dir, secret))

with open(config_file, 'r', encoding='utf-8') as f:
configs = yaml.safe_load(f.read())

path = configs[model]
registry = secrets['registry']

def get_oras_cmd(username: Optional[str] = None,
password: Optional[str] = None):
cmd = [
Expand All @@ -256,7 +277,8 @@ def get_oras_cmd(username: Optional[str] = None,

cmd_without_creds = get_oras_cmd()
log.info(f'CMD for oras cli to run: {" ".join(cmd_without_creds)}')
cmd_to_run = get_oras_cmd(username=username, password=password)
cmd_to_run = get_oras_cmd(username=secrets['username'],
password=secrets['password'])
try:
subprocess.run(cmd_to_run, check=True)
except subprocess.CalledProcessError as e:
Expand Down
13 changes: 6 additions & 7 deletions scripts/misc/download_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,9 @@ def add_hf_parser_arguments(parser: argparse.ArgumentParser) -> None:


def add_oras_parser_arguments(parser: argparse.ArgumentParser) -> None:
parser.add_argument('--registry', type=str, required=True)
parser.add_argument('--path', type=str, required=True)
parser.add_argument('--username', type=str, default='')
parser.add_argument('--password', type=str, default='')
parser.add_argument('--model', type=str, required=True)
parser.add_argument('--config-file', type=str, required=True)
parser.add_argument('--credentials-dir', type=str, required=True)
parser.add_argument('--concurrency', type=int, default=10)


Expand Down Expand Up @@ -87,7 +86,7 @@ def parse_args() -> argparse.Namespace:

if download_from == 'http':
try:
download_from_http_fileserver(args.host, args.path, args.save_dir,
download_from_http_fileserver(args.url, args.save_dir,
args.ignore_cert)
except PermissionError as e:
log.error(f'Not authorized to download {args.model}.')
Expand All @@ -112,5 +111,5 @@ def parse_args() -> argparse.Namespace:
token=args.token,
prefer_safetensors=args.prefer_safetensors)
elif download_from == 'oras':
download_from_oras(args.registry, args.path, args.save_dir,
args.username, args.password, args.concurrency)
download_from_oras(args.model, args.config_file, args.credentials_dir,
args.save_dir, args.concurrency)
37 changes: 13 additions & 24 deletions tests/utils/test_model_download_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,55 +186,44 @@ def test_download_from_hf_hub_retry(
def test_download_from_http_fileserver(mock_open: MagicMock,
mock_makedirs: MagicMock,
mock_get: MagicMock):
cache_url = 'https://cache.com/'
formatted_model_name = 'models--model'
model_url = f'https://cache.com/models/model/'
save_dir = 'save_dir/'

mock_open.return_value = MagicMock()

def _server_response(url: str, **kwargs: Dict[str, Any]):
if url == urljoin(cache_url, f'{formatted_model_name}/blobs/'):
if url == model_url:
return MagicMock(status_code=HTTPStatus.OK, content=ROOT_HTML)
if url == urljoin(cache_url, f'{formatted_model_name}/blobs/file1'):
if url == urljoin(model_url, 'file1'):
return MagicMock(status_code=HTTPStatus.OK)
elif url == urljoin(cache_url, f'{formatted_model_name}/blobs/folder/'):
elif url == urljoin(model_url, 'folder/'):
return MagicMock(status_code=HTTPStatus.OK, content=SUBFOLDER_HTML)
elif url == urljoin(cache_url,
f'{formatted_model_name}/blobs/folder/file2'):
elif url == urljoin(model_url, 'folder/file2'):
return MagicMock(status_code=HTTPStatus.OK)
else:
return MagicMock(status_code=HTTPStatus.NOT_FOUND)

mock_get.side_effect = _server_response
download_from_http_fileserver(
cache_url,
formatted_model_name + '/blobs/',
'save_dir/',
)
download_from_http_fileserver(model_url, save_dir)

mock_open.assert_has_calls(
[
mock.call(
os.path.join(save_dir, formatted_model_name, 'blobs/file1'),
'wb'),
mock.call(
os.path.join(save_dir, formatted_model_name,
'blobs/folder/file2'), 'wb'),
mock.call(os.path.join(save_dir, 'file1'), 'wb'),
mock.call(os.path.join(save_dir, 'folder/file2'), 'wb'),
],
any_order=True,
)


@mock.patch.object(requests.Session, 'get')
def test_download_from_cache_server_unauthorized(mock_get: MagicMock):
cache_url = 'https://cache.com/'
def test_download_from_http_fileserver_unauthorized(mock_get: MagicMock):
model_name = 'model'
cache_url = f'https://cache.com/models--{model_name}/blobs/'
save_dir = 'save_dir/'

mock_get.return_value = MagicMock(status_code=HTTPStatus.UNAUTHORIZED)
with pytest.raises(PermissionError):
download_from_http_fileserver(cache_url, f'models--{model_name}/blobs/',
save_dir)
download_from_http_fileserver(cache_url, save_dir)


@pytest.mark.parametrize(['exception', 'expected_attempts'], [
Expand All @@ -244,7 +233,7 @@ def test_download_from_cache_server_unauthorized(mock_get: MagicMock):
])
@mock.patch('tenacity.nap.time.sleep')
@mock.patch('llmfoundry.utils.model_download_utils._recursive_download')
def test_download_from_cache_server_retry(
def test_download_from_http_fileserver_retry(
mock_recursive_download: MagicMock,
mock_sleep: MagicMock, # so the retry wait doesn't actually wait
exception: BaseException,
Expand All @@ -253,4 +242,4 @@ def test_download_from_cache_server_retry(
mock_recursive_download.side_effect = exception

with pytest.raises((tenacity.RetryError, exception.__class__)):
download_from_http_fileserver('cache_url', 'models--model/', 'save_dir')
download_from_http_fileserver('cache_url', 'save_dir')

0 comments on commit 697638f

Please sign in to comment.