Skip to content

Commit

Permalink
Unit tests for download_from_cache_server
Browse files Browse the repository at this point in the history
  • Loading branch information
jerrychen109 committed Nov 6, 2023
1 parent 5c38bc5 commit 01d233b
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 4 deletions.
8 changes: 5 additions & 3 deletions llmfoundry/utils/model_download_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from typing import Optional
from http import HTTPStatus
from urllib.parse import urljoin

from bs4 import BeautifulSoup
import huggingface_hub as hf_hub
Expand Down Expand Up @@ -130,7 +131,8 @@ def _recursive_download(
PermissionError: If the remote server returns a 401 Unauthorized status code.
RuntimeError: If the remote server returns a status code other than 200 OK or 401 Unauthorized.
"""
url = f'{base_url}/{path}'

url = urljoin(base_url, path)
response = session.get(url, verify=(not ignore_cert))

if response.status_code == HTTPStatus.UNAUTHORIZED:
Expand All @@ -144,7 +146,7 @@ def _recursive_download(

# Assume that the URL points to a file if it does not end with a slash.
if not path.endswith('/'):
save_path = f'{save_dir}/{path}'
save_path = os.path.join(save_dir, path)
parent_dir = os.path.dirname(save_path)
if not os.path.exists(parent_dir):
os.makedirs(parent_dir)
Expand All @@ -160,7 +162,7 @@ def _recursive_download(
child_links = _extract_links_from_html(response.content.decode())
for child_link in child_links:
_recursive_download(
session, base_url, f'{path}/{child_link}', save_dir, ignore_cert=ignore_cert
session, base_url, urljoin(path, child_link), save_dir, ignore_cert=ignore_cert
)


Expand Down
79 changes: 78 additions & 1 deletion tests/test_model_download_utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
# Copyright 2023 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from typing import List
from typing import Any, Dict, List

from http import HTTPStatus
import os
import requests
import unittest.mock as mock
from unittest.mock import MagicMock
from urllib.parse import urljoin

import pytest

Expand Down Expand Up @@ -123,3 +127,76 @@ def test_download_from_hf_hub_no_weights(
download_from_hf_hub(test_repo_id)

mock_snapshot_download.assert_not_called()


ROOT_HTML = b"""
<!DOCTYPE html>
<html>
<body>
<ul>
<li><a href="file1">file1</a></li>
<li><a href="folder/">folder/</a></li>
</ul>
</body>
</html>
"""

SUBFOLDER_HTML = b"""
<!DOCTYPE html>
<html>
<body>
<ul>
<li><a href="file2">file2</a></li>
</ul>
</body>
</html>
"""


@mock.patch.object(requests.Session, 'get')
@mock.patch('os.makedirs')
@mock.patch('builtins.open')
def test_download_from_cache_server(
mock_open: MagicMock,
mock_makedirs: MagicMock,
mock_get: MagicMock
):
cache_url = 'https://cache.com/'
model_name = 'model'
formatted_model_name = 'models--model'
save_dir = 'save_dir/'

mock_open.return_value = MagicMock()

def _server_response(url: str, **kwargs: Dict[str, Any]):
if url == urljoin(cache_url, f'{formatted_model_name}/blobs/'):
return MagicMock(status_code=HTTPStatus.OK, content=ROOT_HTML)
if url == urljoin(cache_url, f'{formatted_model_name}/blobs/file1'):
return MagicMock(status_code=HTTPStatus.OK)
elif url == urljoin(cache_url, f'{formatted_model_name}/blobs/folder/'):
return MagicMock(status_code=HTTPStatus.OK, content=SUBFOLDER_HTML)
elif url == urljoin(cache_url, f'{formatted_model_name}/blobs/folder/file2'):
return MagicMock(status_code=HTTPStatus.OK)
else:
return MagicMock(status_code=HTTPStatus.NOT_FOUND)

mock_get.side_effect = _server_response
download_from_cache_server(model_name, cache_url, 'save_dir/')

mock_open.assert_has_calls([
mock.call(os.path.join(
save_dir, formatted_model_name, 'blobs/file1'), 'wb'),
mock.call(os.path.join(save_dir, formatted_model_name,
'blobs/folder/file2'), 'wb'),
], any_order=True)


@mock.patch.object(requests.Session, 'get')
def test_download_from_cache_server_unauthorized(mock_get: MagicMock):
cache_url = 'https://cache.com/'
model_name = 'model'
save_dir = 'save_dir/'

mock_get.return_value = MagicMock(status_code=HTTPStatus.UNAUTHORIZED)
with pytest.raises(PermissionError):
download_from_cache_server(model_name, cache_url, save_dir)

0 comments on commit 01d233b

Please sign in to comment.