diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29b..5ab310e3 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,6 @@ +- bump: minor + changes: + changed: + - Refactor all Hugging Face downloads to use the same download function + - Attempt to determine whether a token is necessary before downloading from Hugging Face + - Add tests for this new functionality \ No newline at end of file diff --git a/policyengine_core/data/dataset.py b/policyengine_core/data/dataset.py index 235884c6..566a45a8 100644 --- a/policyengine_core/data/dataset.py +++ b/policyengine_core/data/dataset.py @@ -497,13 +497,8 @@ def download_from_huggingface( file=sys.stderr, ) - token = get_or_prompt_hf_token() - - hf_hub_download( - repo_id=f"{owner_name}/{model_name}", - repo_type="model", - filename=file_name, - local_dir=self.file_path.parent, - revision=version, - token=token, + download_huggingface_dataset( + repo=f"{owner_name}/{model_name}", + repo_filename=file_name, + version=version, ) diff --git a/policyengine_core/simulations/simulation.py b/policyengine_core/simulations/simulation.py index 6d60e3c0..9965735c 100644 --- a/policyengine_core/simulations/simulation.py +++ b/policyengine_core/simulations/simulation.py @@ -1,7 +1,6 @@ import tempfile from typing import TYPE_CHECKING, Any, Dict, List, Type, Union -import numpy import numpy as np import pandas as pd from numpy.typing import ArrayLike @@ -159,7 +158,11 @@ def __init__( filename = filename.split("@")[0] else: version = None - dataset = download(owner + "/" + repo, filename, version) + dataset = download_huggingface_dataset( + repo=f"{owner}/{repo}", + repo_filename=filename, + version=version, + ) datasets_by_name = { dataset.name: dataset for dataset in self.datasets } @@ -1041,7 +1044,7 @@ def _cast_formula_result(self, value: Any, variable: str) -> ArrayLike: if variable.value_type == Enum and not isinstance(value, EnumArray): return variable.possible_values.encode(value) - if not isinstance(value, numpy.ndarray): + if not isinstance(value, np.ndarray): population = self.get_variable_population(variable.name) value = population.filled_array(value) diff --git a/policyengine_core/tools/hugging_face.py b/policyengine_core/tools/hugging_face.py index a5433e7f..2c18585a 100644 --- a/policyengine_core/tools/hugging_face.py +++ b/policyengine_core/tools/hugging_face.py @@ -1,4 +1,9 @@ -from huggingface_hub import hf_hub_download, login, HfApi +from huggingface_hub import ( + hf_hub_download, + model_info, + ModelInfo, +) +from huggingface_hub.errors import RepositoryNotFoundError from getpass import getpass import os import warnings @@ -7,17 +12,46 @@ warnings.simplefilter("ignore") -def download(repo: str, repo_filename: str, version: str = None): - token = os.environ.get( - "HUGGING_FACE_TOKEN", - ) +def download_huggingface_dataset( + repo: str, repo_filename: str, version: str = None +): + """ + Download a dataset from the Hugging Face Hub. + + Args: + repo (str): The Hugging Face repo name, in format "{org}/{repo}". + repo_filename (str): The filename of the dataset. + version (str, optional): The version of the dataset. Defaults to None. + """ + # Attempt connection to Hugging Face model_info endpoint + # (https://huggingface.co/docs/huggingface_hub/v0.26.5/en/package_reference/hf_api#huggingface_hub.HfApi.model_info) + # Attempt to fetch model info to determine if repo is private + # A RepositoryNotFoundError & 401 likely means the repo is private, + # but this error will also surface for public repos with malformed URL, etc. + try: + fetched_model_info: ModelInfo = model_info(repo) + is_repo_private: bool = fetched_model_info.private + except RepositoryNotFoundError as e: + # If this error type arises, it's likely the repo is private; see docs above + is_repo_private = True + pass + except Exception as e: + # Otherwise, there probably is just a download error + raise Exception( + f"Unable to download dataset {repo_filename} from Hugging Face. This may be because the repo " + + "is private, the URL is malformed, or the dataset does not exist." + ) + + authentication_token: str = None + if is_repo_private: + authentication_token: str = get_or_prompt_hf_token() return hf_hub_download( repo_id=repo, repo_type="model", filename=repo_filename, revision=version, - token=token, + token=authentication_token, ) diff --git a/tests/core/tools/test_hugging_face.py b/tests/core/tools/test_hugging_face.py index acaceb30..813a1278 100644 --- a/tests/core/tools/test_hugging_face.py +++ b/tests/core/tools/test_hugging_face.py @@ -1,61 +1,153 @@ import os import pytest from unittest.mock import patch -from policyengine_core.tools.hugging_face import get_or_prompt_hf_token +from huggingface_hub import ModelInfo +from huggingface_hub.errors import RepositoryNotFoundError +from policyengine_core.tools.hugging_face import ( + get_or_prompt_hf_token, + download_huggingface_dataset, +) -def test_get_token_from_environment(): - """Test retrieving token when it exists in environment variables""" - test_token = "test_token_123" - with patch.dict( - os.environ, {"HUGGING_FACE_TOKEN": test_token}, clear=True - ): - result = get_or_prompt_hf_token() - assert result == test_token +class TestHuggingFaceDownload: + def test_download_public_repo(self): + """Test downloading from a public repo""" + test_repo = "test_repo" + test_filename = "test_filename" + test_version = "test_version" + with patch( + "policyengine_core.tools.hugging_face.hf_hub_download" + ) as mock_download: + with patch( + "policyengine_core.tools.hugging_face.model_info" + ) as mock_model_info: + # Create mock ModelInfo object emulating public repo + test_id = 0 + mock_model_info.return_value = ModelInfo( + id=test_id, private=False + ) + + download_huggingface_dataset( + test_repo, test_filename, test_version + ) + + mock_download.assert_called_with( + repo_id=test_repo, + repo_type="model", + filename=test_filename, + revision=test_version, + token=None, + ) -def test_get_token_from_user_input(): - """Test retrieving token via user input when not in environment""" - test_token = "user_input_token_456" + def test_download_private_repo(self): + """Test downloading from a private repo""" + test_repo = "test_repo" + test_filename = "test_filename" + test_version = "test_version" - # Mock both empty environment and user input - with patch.dict(os.environ, {}, clear=True): with patch( - "policyengine_core.tools.hugging_face.getpass", - return_value=test_token, - ): - result = get_or_prompt_hf_token() - assert result == test_token + "policyengine_core.tools.hugging_face.hf_hub_download" + ) as mock_download: + with patch( + "policyengine_core.tools.hugging_face.model_info" + ) as mock_model_info: + mock_model_info.side_effect = RepositoryNotFoundError( + "Test error" + ) + with patch( + "policyengine_core.tools.hugging_face.get_or_prompt_hf_token" + ) as mock_token: + mock_token.return_value = "test_token" - # Verify token was stored in environment - assert os.environ.get("HUGGING_FACE_TOKEN") == test_token + download_huggingface_dataset( + test_repo, test_filename, test_version + ) + mock_download.assert_called_with( + repo_id=test_repo, + repo_type="model", + filename=test_filename, + revision=test_version, + token=mock_token.return_value, + ) + def test_download_private_repo_no_token(self): + """Test handling of private repo with no token""" + test_repo = "test_repo" + test_filename = "test_filename" + test_version = "test_version" -def test_empty_user_input(): - """Test handling of empty user input""" - with patch.dict(os.environ, {}, clear=True): with patch( - "policyengine_core.tools.hugging_face.getpass", return_value="" + "policyengine_core.tools.hugging_face.hf_hub_download" + ) as mock_download: + with patch( + "policyengine_core.tools.hugging_face.model_info" + ) as mock_model_info: + mock_model_info.side_effect = RepositoryNotFoundError( + "Test error" + ) + with patch( + "policyengine_core.tools.hugging_face.get_or_prompt_hf_token" + ) as mock_token: + mock_token.return_value = "" + + with pytest.raises(Exception): + download_huggingface_dataset( + test_repo, test_filename, test_version + ) + mock_download.assert_not_called() + + +class TestGetOrPromptHfToken: + def test_get_token_from_environment(self): + """Test retrieving token when it exists in environment variables""" + test_token = "test_token_123" + with patch.dict( + os.environ, {"HUGGING_FACE_TOKEN": test_token}, clear=True ): result = get_or_prompt_hf_token() - assert result == "" - assert os.environ.get("HUGGING_FACE_TOKEN") == "" + assert result == test_token + def test_get_token_from_user_input(self): + """Test retrieving token via user input when not in environment""" + test_token = "user_input_token_456" -def test_environment_variable_persistence(): - """Test that environment variable persists across multiple calls""" - test_token = "persistence_test_token" + # Mock both empty environment and user input + with patch.dict(os.environ, {}, clear=True): + with patch( + "policyengine_core.tools.hugging_face.getpass", + return_value=test_token, + ): + result = get_or_prompt_hf_token() + assert result == test_token - # First call with no environment variable - with patch.dict(os.environ, {}, clear=True): - with patch( - "policyengine_core.tools.hugging_face.getpass", - return_value=test_token, - ): - first_result = get_or_prompt_hf_token() + # Verify token was stored in environment + assert os.environ.get("HUGGING_FACE_TOKEN") == test_token - # Second call should use environment variable - second_result = get_or_prompt_hf_token() + def test_empty_user_input(self): + """Test handling of empty user input""" + with patch.dict(os.environ, {}, clear=True): + with patch( + "policyengine_core.tools.hugging_face.getpass", return_value="" + ): + result = get_or_prompt_hf_token() + assert result == "" + assert os.environ.get("HUGGING_FACE_TOKEN") == "" - assert first_result == second_result == test_token - assert os.environ.get("HUGGING_FACE_TOKEN") == test_token + def test_environment_variable_persistence(self): + """Test that environment variable persists across multiple calls""" + test_token = "persistence_test_token" + + # First call with no environment variable + with patch.dict(os.environ, {}, clear=True): + with patch( + "policyengine_core.tools.hugging_face.getpass", + return_value=test_token, + ): + first_result = get_or_prompt_hf_token() + + # Second call should use environment variable + second_result = get_or_prompt_hf_token() + + assert first_result == second_result == test_token + assert os.environ.get("HUGGING_FACE_TOKEN") == test_token