Skip to content

Commit

Permalink
Add retries and unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jerrychen109 committed Nov 6, 2023
1 parent 148d6ba commit b80a56c
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 8 deletions.
21 changes: 21 additions & 0 deletions llmfoundry/utils/model_download_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from bs4 import BeautifulSoup
import huggingface_hub as hf_hub
import requests
import tenacity
from transformers.utils import (
WEIGHTS_NAME as PYTORCH_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME as PYTORCH_WEIGHTS_INDEX_NAME,
Expand All @@ -33,6 +34,12 @@
log = logging.getLogger(__name__)


@tenacity.retry(
retry=tenacity.retry_if_not_exception_type(
(ValueError, hf_hub.utils.RepositoryNotFoundError)),
stop=tenacity.stop_after_attempt(3),
wait=tenacity.wait_exponential(min=1, max=10)
)
def download_from_hf_hub(
repo_id: str,
save_dir: Optional[str] = None,
Expand All @@ -52,6 +59,10 @@ def download_from_hf_hub(
available. Defaults to True.
token (str, optional): The HuggingFace API token. If not provided, the token will be read from the
`HUGGING_FACE_HUB_TOKEN` environment variable.
Raises:
RepositoryNotFoundError: If the model repo doesn't exist or the token is unauthorized.
ValueError: If the model repo doesn't contain any supported model weights.
"""
repo_files = set(hf_hub.list_repo_files(repo_id))

Expand Down Expand Up @@ -129,6 +140,7 @@ def _recursive_download(
Raises:
PermissionError: If the remote server returns a 401 Unauthorized status code.
ValueError: If the remote server returns a 404 Not Found status code.
RuntimeError: If the remote server returns a status code other than 200 OK, 401 Unauthorized, or 404 Not Found.
"""

Expand All @@ -139,6 +151,10 @@ def _recursive_download(
raise PermissionError(
f'Not authorized to download file from {url}. Received status code {response.status_code}. '
)
elif response.status_code == HTTPStatus.NOT_FOUND:
raise ValueError(
f'Could not find file at {url}. Received status code {response.status_code}'
)
elif response.status_code != HTTPStatus.OK:
raise RuntimeError(
f'Could not download file from {url}. Received unexpected status code {response.status_code}'
Expand Down Expand Up @@ -166,6 +182,11 @@ def _recursive_download(
)


@tenacity.retry(
retry=tenacity.retry_if_not_exception_type((PermissionError, ValueError)),
stop=tenacity.stop_after_attempt(3),
wait=tenacity.wait_exponential(min=1, max=10)
)
def download_from_cache_server(
model_name: str,
cache_base_url: str,
Expand Down
9 changes: 6 additions & 3 deletions scripts/misc/download_hf_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@
argparser.add_argument(
'--download-from', type=str, choices=['hf', 'cache'], default='hf'
)
argparser.add_argument('--token', type=str, default=os.getenv(HF_TOKEN_ENV_VAR))
argparser.add_argument('--save-dir', type=str, default=HUGGINGFACE_HUB_CACHE)
argparser.add_argument('--token', type=str,
default=os.getenv(HF_TOKEN_ENV_VAR))
argparser.add_argument('--save-dir', type=str,
default=HUGGINGFACE_HUB_CACHE)
argparser.add_argument('--cache-url', type=str, default=None)
argparser.add_argument('--ignore-cert', action='store_true', default=False)
argparser.add_argument(
Expand All @@ -38,7 +40,8 @@

args = argparser.parse_args(sys.argv[1:])
if args.download_from == 'hf':
download_from_hf_hub(args.model, save_dir=args.save_dir, token=args.token)
download_from_hf_hub(
args.model, save_dir=args.save_dir, token=args.token)
else:
try:
download_from_cache_server(
Expand Down
4 changes: 3 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
'boto3>=1.21.45,<2',
'huggingface-hub>=0.17.0,<1.0',
'beautifulsoup4>=4.12.2,<5', # required for model download utils
'tenacity>=8.2.3,<9',
]

extra_deps = {}
Expand Down Expand Up @@ -102,7 +103,8 @@
extra_deps['peft'] = [
'loralib==0.1.1', # lora core
'bitsandbytes==0.39.1', # 8bit
'scipy>=1.10.0,<=1.11.0', # bitsandbytes dependency; TODO: eliminate when incorporated to bitsandbytes
# bitsandbytes dependency; TODO: eliminate when incorporated to bitsandbytes
'scipy>=1.10.0,<=1.11.0',
# TODO: pin peft when it stabilizes.
# PyPI does not support direct dependencies, so we remove this line before uploading from PyPI
'peft==0.4.0',
Expand Down
60 changes: 56 additions & 4 deletions tests/test_model_download_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,28 @@
from unittest.mock import MagicMock
from urllib.parse import urljoin

from huggingface_hub.utils import RepositoryNotFoundError
import pytest
import tenacity
from transformers.utils import (
WEIGHTS_NAME as PYTORCH_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME as PYTORCH_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
SAFE_WEIGHTS_INDEX_NAME,
)

from llmfoundry.utils.model_download_utils import (
download_from_cache_server,
download_from_hf_hub,
PYTORCH_WEIGHTS_NAME,
PYTORCH_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
SAFE_WEIGHTS_INDEX_NAME,
DEFAULT_IGNORE_PATTERNS,
PYTORCH_WEIGHTS_PATTERN,
SAFE_WEIGHTS_PATTERN
)


# ======================== download_from_hf_hub tests ========================


@pytest.mark.parametrize(['prefer_safetensors', 'repo_files', 'expected_ignore_patterns'], [
[ # Should use default ignore if only safetensors available
True,
Expand Down Expand Up @@ -129,6 +136,32 @@ def test_download_from_hf_hub_no_weights(
mock_snapshot_download.assert_not_called()


@pytest.mark.parametrize(
    'exception,expected_attempts',
    [
        (requests.exceptions.RequestException(), 3),
        (RepositoryNotFoundError(''), 1),
        (ValueError(), 1),
    ],
)
@mock.patch('tenacity.nap.time.sleep')
@mock.patch('huggingface_hub.snapshot_download')
@mock.patch('huggingface_hub.list_repo_files')
def test_download_from_hf_hub_retry(
    mock_list_repo_files: MagicMock,
    mock_snapshot_download: MagicMock,
    mock_sleep: MagicMock,
    exception: BaseException,
    expected_attempts: int,
):
    """Check how many times ``download_from_hf_hub`` attempts a failing download.

    ``snapshot_download`` is mocked to always raise ``exception``. Exception
    types excluded from the retry policy (``ValueError``,
    ``RepositoryNotFoundError``) should produce a single attempt; any other
    exception should be attempted three times before giving up.

    Args:
        mock_list_repo_files: Patched ``huggingface_hub.list_repo_files``.
        mock_snapshot_download: Patched ``huggingface_hub.snapshot_download``.
        mock_sleep: Patched ``time.sleep`` used by tenacity, so the retry
            wait doesn't actually wait.
        exception: The exception instance raised by the mocked download.
        expected_attempts: Number of download calls expected for ``exception``.
    """
    # Make the download fail every time; the repo listing itself succeeds and
    # advertises a safetensors index so the download step is reached.
    mock_snapshot_download.side_effect = exception
    mock_list_repo_files.return_value = [SAFE_WEIGHTS_INDEX_NAME]

    # Either tenacity wraps the failure in RetryError after the last attempt,
    # or the original exception propagates for non-retried types.
    acceptable_errors = (tenacity.RetryError, type(exception))
    with pytest.raises(acceptable_errors):
        download_from_hf_hub('test_repo_id')

    assert mock_snapshot_download.call_count == expected_attempts


# ======================== download_from_cache_server tests ========================

ROOT_HTML = b"""
<!DOCTYPE html>
<html>
Expand Down Expand Up @@ -200,3 +233,22 @@ def test_download_from_cache_server_unauthorized(mock_get: MagicMock):
mock_get.return_value = MagicMock(status_code=HTTPStatus.UNAUTHORIZED)
with pytest.raises(PermissionError):
download_from_cache_server(model_name, cache_url, save_dir)


@pytest.mark.parametrize(['exception', 'expected_attempts'], [
    [requests.exceptions.RequestException(), 3],
    [PermissionError(), 1],
    [ValueError(), 1],
])
@mock.patch('tenacity.nap.time.sleep')
@mock.patch('llmfoundry.utils.model_download_utils._recursive_download')
def test_download_from_cache_server_retry(
    mock_recursive_download: MagicMock,
    mock_sleep: MagicMock,  # so the retry wait doesn't actually wait
    exception: BaseException,
    expected_attempts: int,
):
    """Check how many times ``download_from_cache_server`` attempts a failing download.

    ``_recursive_download`` is mocked to always raise ``exception``. Exception
    types excluded from the retry policy (``PermissionError``, ``ValueError``)
    should produce a single attempt; any other exception should be attempted
    three times before giving up.

    Args:
        mock_recursive_download: Patched ``_recursive_download`` helper.
        mock_sleep: Patched ``time.sleep`` used by tenacity.
        exception: The exception instance raised by the mocked helper.
        expected_attempts: Number of attempts expected for ``exception``.
    """
    mock_recursive_download.side_effect = exception

    # Either tenacity wraps the failure in RetryError after the last attempt,
    # or the original exception propagates for non-retried types.
    with pytest.raises((tenacity.RetryError, exception.__class__)):
        download_from_cache_server('model', 'cache_url', 'save_dir')

    # Verify the retry count, mirroring test_download_from_hf_hub_retry.
    # NOTE(review): assumes download_from_cache_server reaches
    # _recursive_download exactly once per attempt - confirm against its body.
    assert mock_recursive_download.call_count == expected_attempts

0 comments on commit b80a56c

Please sign in to comment.