From a7b1e51c429b316b57df416f64df141c165d4095 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Sat, 7 Dec 2024 01:02:36 +0100 Subject: [PATCH 1/6] fix: Conditionally pass Hugging Face tokens for download when repo is private --- changelog_entry.yaml | 6 + policyengine_core/data/dataset.py | 13 +- policyengine_core/simulations/simulation.py | 7 +- policyengine_core/tools/hugging_face.py | 45 +++++- tests/core/tools/test_hugging_face.py | 168 +++++++++++++++----- 5 files changed, 181 insertions(+), 58 deletions(-) diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29b..5ab310e3 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,6 @@ +- bump: minor + changes: + changed: + - Refactor all Hugging Face downloads to use the same download function + - Attempt to determine whether a token is necessary before downloading from Hugging Face + - Add tests for this new functionality \ No newline at end of file diff --git a/policyengine_core/data/dataset.py b/policyengine_core/data/dataset.py index 235884c6..3bc86a54 100644 --- a/policyengine_core/data/dataset.py +++ b/policyengine_core/data/dataset.py @@ -497,13 +497,8 @@ def download_from_huggingface( file=sys.stderr, ) - token = get_or_prompt_hf_token() - - hf_hub_download( - repo_id=f"{owner_name}/{model_name}", - repo_type="model", - filename=file_name, - local_dir=self.file_path.parent, - revision=version, - token=token, + hf_download( + repo=f"{owner_name}/{model_name}", + repo_filename=file_name, + version=version, ) diff --git a/policyengine_core/simulations/simulation.py b/policyengine_core/simulations/simulation.py index 6d60e3c0..3a4349fe 100644 --- a/policyengine_core/simulations/simulation.py +++ b/policyengine_core/simulations/simulation.py @@ -1,7 +1,6 @@ import tempfile from typing import TYPE_CHECKING, Any, Dict, List, Type, Union -import numpy import numpy as np import pandas as pd from numpy.typing import ArrayLike @@ -159,7 +158,9 @@ def __init__( filename = filename.split("@")[0] else: version = None - dataset = download(owner + "/" + repo, filename, version) + dataset = hf_download( + owner + "/" + repo, filename, version + ) datasets_by_name = { dataset.name: dataset for dataset in self.datasets } @@ -1041,7 +1042,7 @@ def _cast_formula_result(self, value: Any, variable: str) -> ArrayLike: if variable.value_type == Enum and not isinstance(value, EnumArray): return variable.possible_values.encode(value) - if not isinstance(value, numpy.ndarray): + if not isinstance(value, np.ndarray): population = self.get_variable_population(variable.name) value = population.filled_array(value) diff --git a/policyengine_core/tools/hugging_face.py b/policyengine_core/tools/hugging_face.py index a5433e7f..b733e981 100644 --- a/policyengine_core/tools/hugging_face.py +++ b/policyengine_core/tools/hugging_face.py @@ -1,4 +1,11 @@ -from huggingface_hub import hf_hub_download, login, HfApi +from huggingface_hub import ( + hf_hub_download, + model_info, + login, + HfApi, + ModelInfo, +) +from huggingface_hub.errors import RepositoryNotFoundError from getpass import getpass import os import warnings @@ -7,10 +14,38 @@ warnings.simplefilter("ignore") -def download(repo: str, repo_filename: str, version: str = None): - token = os.environ.get( - "HUGGING_FACE_TOKEN", - ) +def hf_download(repo: str, repo_filename: str, version: str = None): + """ + Download a dataset from the Hugging Face Hub. + + Args: + repo (str): The Hugging Face repo name, in format "{org}/{repo}". + repo_filename (str): The filename of the dataset. + version (str, optional): The version of the dataset. Defaults to None. + """ + # Attempt connection to Hugging Face model_info endpoint + # (https://huggingface.co/docs/huggingface_hub/v0.26.5/en/package_reference/hf_api#huggingface_hub.HfApi.model_info) + # Unfortunately, this endpoint will 401 on a private repo, + # but also on a public repo with a malformed URL, etc. + # Assume a 401 means the token is required. + + try: + fetched_model_info: ModelInfo = model_info(repo) + is_repo_private: bool = fetched_model_info.private + except RepositoryNotFoundError as e: + # If this error type arises, it's likely the repo is private; see docs above + is_repo_private = True + pass + except Exception as e: + # Otherwise, there probably is just a download error + raise Exception( + f"Unable to download dataset {repo_filename} from Hugging Face. This may be because the repo " + + "is private, the URL is malformed, or the dataset does not exist." + ) + + token: str = None + if is_repo_private: + token: str = get_or_prompt_hf_token() return hf_hub_download( repo_id=repo, diff --git a/tests/core/tools/test_hugging_face.py b/tests/core/tools/test_hugging_face.py index acaceb30..59a37e1c 100644 --- a/tests/core/tools/test_hugging_face.py +++ b/tests/core/tools/test_hugging_face.py @@ -1,61 +1,147 @@ import os import pytest from unittest.mock import patch -from policyengine_core.tools.hugging_face import get_or_prompt_hf_token +from huggingface_hub import ModelInfo +from huggingface_hub.errors import RepositoryNotFoundError +from policyengine_core.tools.hugging_face import ( + get_or_prompt_hf_token, + hf_download, +) -def test_get_token_from_environment(): - """Test retrieving token when it exists in environment variables""" - test_token = "test_token_123" - with patch.dict( - os.environ, {"HUGGING_FACE_TOKEN": test_token}, clear=True - ): - result = get_or_prompt_hf_token() - assert result == test_token +class TestHfDownload: + def test_download_public_repo(self): + """Test downloading from a public repo""" + test_repo = "test_repo" + test_filename = "test_filename" + test_version = "test_version" + with patch( + "policyengine_core.tools.hugging_face.hf_hub_download" + ) as mock_download: + with patch( + "policyengine_core.tools.hugging_face.model_info" + ) as mock_model_info: + # Create mock ModelInfo object emulating public repo + test_id = 0 + mock_model_info.return_value = ModelInfo( + id=test_id, private=False + ) + + hf_download(test_repo, test_filename, test_version) + + mock_download.assert_called_with( + repo_id=test_repo, + repo_type="model", + filename=test_filename, + revision=test_version, + token=None, + ) -def test_get_token_from_user_input(): - """Test retrieving token via user input when not in environment""" - test_token = "user_input_token_456" + def test_download_private_repo(self): + """Test downloading from a private repo""" + test_repo = "test_repo" + test_filename = "test_filename" + test_version = "test_version" - # Mock both empty environment and user input - with patch.dict(os.environ, {}, clear=True): with patch( - "policyengine_core.tools.hugging_face.getpass", - return_value=test_token, - ): - result = get_or_prompt_hf_token() - assert result == test_token + "policyengine_core.tools.hugging_face.hf_hub_download" + ) as mock_download: + with patch( + "policyengine_core.tools.hugging_face.model_info" + ) as mock_model_info: + mock_model_info.side_effect = RepositoryNotFoundError( + "Test error" + ) + with patch( + "policyengine_core.tools.hugging_face.get_or_prompt_hf_token" + ) as mock_token: + mock_token.return_value = "test_token" - # Verify token was stored in environment - assert os.environ.get("HUGGING_FACE_TOKEN") == test_token + hf_download(test_repo, test_filename, test_version) + mock_download.assert_called_with( + repo_id=test_repo, + repo_type="model", + filename=test_filename, + revision=test_version, + token=mock_token.return_value, + ) + def test_download_private_repo_no_token(self): + """Test handling of private repo with no token""" + test_repo = "test_repo" + test_filename = "test_filename" + test_version = "test_version" -def test_empty_user_input(): - """Test handling of empty user input""" - with patch.dict(os.environ, {}, clear=True): with patch( - "policyengine_core.tools.hugging_face.getpass", return_value="" + "policyengine_core.tools.hugging_face.hf_hub_download" + ) as mock_download: + with patch( + "policyengine_core.tools.hugging_face.model_info" + ) as mock_model_info: + mock_model_info.side_effect = RepositoryNotFoundError( + "Test error" + ) + with patch( + "policyengine_core.tools.hugging_face.get_or_prompt_hf_token" + ) as mock_token: + mock_token.return_value = "" + + with pytest.raises(Exception): + hf_download(test_repo, test_filename, test_version) + mock_download.assert_not_called() + + +class TestGetOrPromptHfToken: + def test_get_token_from_environment(self): + """Test retrieving token when it exists in environment variables""" + test_token = "test_token_123" + with patch.dict( + os.environ, {"HUGGING_FACE_TOKEN": test_token}, clear=True ): result = get_or_prompt_hf_token() - assert result == "" - assert os.environ.get("HUGGING_FACE_TOKEN") == "" + assert result == test_token + def test_get_token_from_user_input(self): + """Test retrieving token via user input when not in environment""" + test_token = "user_input_token_456" -def test_environment_variable_persistence(): - """Test that environment variable persists across multiple calls""" - test_token = "persistence_test_token" + # Mock both empty environment and user input + with patch.dict(os.environ, {}, clear=True): + with patch( + "policyengine_core.tools.hugging_face.getpass", + return_value=test_token, + ): + result = get_or_prompt_hf_token() + assert result == test_token - # First call with no environment variable - with patch.dict(os.environ, {}, clear=True): - with patch( - "policyengine_core.tools.hugging_face.getpass", - return_value=test_token, - ): - first_result = get_or_prompt_hf_token() + # Verify token was stored in environment + assert os.environ.get("HUGGING_FACE_TOKEN") == test_token - # Second call should use environment variable - second_result = get_or_prompt_hf_token() + def test_empty_user_input(self): + """Test handling of empty user input""" + with patch.dict(os.environ, {}, clear=True): + with patch( + "policyengine_core.tools.hugging_face.getpass", return_value="" + ): + result = get_or_prompt_hf_token() + assert result == "" + assert os.environ.get("HUGGING_FACE_TOKEN") == "" - assert first_result == second_result == test_token - assert os.environ.get("HUGGING_FACE_TOKEN") == test_token + def test_environment_variable_persistence(self): + """Test that environment variable persists across multiple calls""" + test_token = "persistence_test_token" + + # First call with no environment variable + with patch.dict(os.environ, {}, clear=True): + with patch( + "policyengine_core.tools.hugging_face.getpass", + return_value=test_token, + ): + first_result = get_or_prompt_hf_token() + + # Second call should use environment variable + second_result = get_or_prompt_hf_token() + + assert first_result == second_result == test_token + assert os.environ.get("HUGGING_FACE_TOKEN") == test_token From d9a418b692c65217f1d21ad12195cdbf8aa94cdb Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Sat, 7 Dec 2024 10:16:05 +0100 Subject: [PATCH 2/6] fix: Change dataset download func name --- policyengine_core/data/dataset.py | 2 +- policyengine_core/simulations/simulation.py | 2 +- policyengine_core/tools/hugging_face.py | 2 +- tests/core/tools/test_hugging_face.py | 8 ++++---- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/policyengine_core/data/dataset.py b/policyengine_core/data/dataset.py index 3bc86a54..566a45a8 100644 --- a/policyengine_core/data/dataset.py +++ b/policyengine_core/data/dataset.py @@ -497,7 +497,7 @@ def download_from_huggingface( file=sys.stderr, ) - hf_download( + download_huggingface_dataset( repo=f"{owner_name}/{model_name}", repo_filename=file_name, version=version, diff --git a/policyengine_core/simulations/simulation.py b/policyengine_core/simulations/simulation.py index 3a4349fe..23fbe59e 100644 --- a/policyengine_core/simulations/simulation.py +++ b/policyengine_core/simulations/simulation.py @@ -158,7 +158,7 @@ def __init__( filename = filename.split("@")[0] else: version = None - dataset = hf_download( + dataset = download_huggingface_dataset( owner + "/" + repo, filename, version ) datasets_by_name = { diff --git a/policyengine_core/tools/hugging_face.py b/policyengine_core/tools/hugging_face.py index b733e981..47fb1906 100644 --- a/policyengine_core/tools/hugging_face.py +++ b/policyengine_core/tools/hugging_face.py @@ -14,7 +14,7 @@ warnings.simplefilter("ignore") -def hf_download(repo: str, repo_filename: str, version: str = None): +def download_huggingface_dataset(repo: str, repo_filename: str, version: str = None): """ Download a dataset from the Hugging Face Hub. diff --git a/tests/core/tools/test_hugging_face.py b/tests/core/tools/test_hugging_face.py index 59a37e1c..90f8cae9 100644 --- a/tests/core/tools/test_hugging_face.py +++ b/tests/core/tools/test_hugging_face.py @@ -5,7 +5,7 @@ from huggingface_hub.errors import RepositoryNotFoundError from policyengine_core.tools.hugging_face import ( get_or_prompt_hf_token, - hf_download, + download_huggingface_dataset, ) @@ -28,7 +28,7 @@ def test_download_public_repo(self): id=test_id, private=False ) - hf_download(test_repo, test_filename, test_version) + download_huggingface_dataset(test_repo, test_filename, test_version) mock_download.assert_called_with( repo_id=test_repo, @@ -58,7 +58,7 @@ def test_download_private_repo(self): ) as mock_token: mock_token.return_value = "test_token" - hf_download(test_repo, test_filename, test_version) + download_huggingface_dataset(test_repo, test_filename, test_version) mock_download.assert_called_with( repo_id=test_repo, repo_type="model", @@ -88,7 +88,7 @@ def test_download_private_repo_no_token(self): mock_token.return_value = "" with pytest.raises(Exception): - hf_download(test_repo, test_filename, test_version) + download_huggingface_dataset(test_repo, test_filename, test_version) mock_download.assert_not_called() From 17995fb4e22f71145ec8b82d39dd5f9f3a0ef39c Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Sat, 7 Dec 2024 10:18:19 +0100 Subject: [PATCH 3/6] fix: Change comment, change var name --- policyengine_core/tools/hugging_face.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/policyengine_core/tools/hugging_face.py b/policyengine_core/tools/hugging_face.py index 47fb1906..cbf41fdf 100644 --- a/policyengine_core/tools/hugging_face.py +++ b/policyengine_core/tools/hugging_face.py @@ -25,10 +25,9 @@ def download_huggingface_dataset(repo: str, repo_filename: str, version: str = N """ # Attempt connection to Hugging Face model_info endpoint # (https://huggingface.co/docs/huggingface_hub/v0.26.5/en/package_reference/hf_api#huggingface_hub.HfApi.model_info) - # Unfortunately, this endpoint will 401 on a private repo, - # but also on a public repo with a malformed URL, etc. - # Assume a 401 means the token is required. - + # Attempt to fetch model info to determine if repo is private + # A RepositoryNotFoundError & 401 likely means the repo is private, + # but this error will also surface for public repos with malformed URL, etc. try: fetched_model_info: ModelInfo = model_info(repo) is_repo_private: bool = fetched_model_info.private @@ -43,16 +42,16 @@ def download_huggingface_dataset(repo: str, repo_filename: str, version: str = N + "is private, the URL is malformed, or the dataset does not exist." ) - token: str = None + authentication_token: str = None if is_repo_private: - token: str = get_or_prompt_hf_token() + authentication_token: str = get_or_prompt_hf_token() return hf_hub_download( repo_id=repo, repo_type="model", filename=repo_filename, revision=version, - token=token, + token=authentication_token, ) From 85083e2d7bd0b158373e7823c35fca93037a2323 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Sat, 7 Dec 2024 10:21:33 +0100 Subject: [PATCH 4/6] fix: Change var names --- policyengine_core/simulations/simulation.py | 4 +++- tests/core/tools/test_hugging_face.py | 3 +-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/policyengine_core/simulations/simulation.py b/policyengine_core/simulations/simulation.py index 23fbe59e..9965735c 100644 --- a/policyengine_core/simulations/simulation.py +++ b/policyengine_core/simulations/simulation.py @@ -159,7 +159,9 @@ def __init__( else: version = None dataset = download_huggingface_dataset( - owner + "/" + repo, filename, version + repo=f"{owner}/{repo}", + repo_filename=filename, + version=version, ) datasets_by_name = { dataset.name: dataset for dataset in self.datasets diff --git a/tests/core/tools/test_hugging_face.py b/tests/core/tools/test_hugging_face.py index 90f8cae9..16c92f29 100644 --- a/tests/core/tools/test_hugging_face.py +++ b/tests/core/tools/test_hugging_face.py @@ -9,7 +9,7 @@ ) -class TestHfDownload: +class TestHuggingFaceDownload: def test_download_public_repo(self): """Test downloading from a public repo""" test_repo = "test_repo" @@ -91,7 +91,6 @@ def test_download_private_repo_no_token(self): download_huggingface_dataset(test_repo, test_filename, test_version) mock_download.assert_not_called() - class TestGetOrPromptHfToken: def test_get_token_from_environment(self): """Test retrieving token when it exists in environment variables""" From aae8cac15e98f2ebd18f2840684436f92c4059e7 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Sat, 7 Dec 2024 10:24:40 +0100 Subject: [PATCH 5/6] chore: Lint --- policyengine_core/tools/hugging_face.py | 4 +++- tests/core/tools/test_hugging_face.py | 13 ++++++++++--- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/policyengine_core/tools/hugging_face.py b/policyengine_core/tools/hugging_face.py index cbf41fdf..b2c5924c 100644 --- a/policyengine_core/tools/hugging_face.py +++ b/policyengine_core/tools/hugging_face.py @@ -14,7 +14,9 @@ warnings.simplefilter("ignore") -def download_huggingface_dataset(repo: str, repo_filename: str, version: str = None): +def download_huggingface_dataset( + repo: str, repo_filename: str, version: str = None +): """ Download a dataset from the Hugging Face Hub. diff --git a/tests/core/tools/test_hugging_face.py b/tests/core/tools/test_hugging_face.py index 16c92f29..813a1278 100644 --- a/tests/core/tools/test_hugging_face.py +++ b/tests/core/tools/test_hugging_face.py @@ -28,7 +28,9 @@ def test_download_public_repo(self): id=test_id, private=False ) - download_huggingface_dataset(test_repo, test_filename, test_version) + download_huggingface_dataset( + test_repo, test_filename, test_version + ) mock_download.assert_called_with( repo_id=test_repo, @@ -58,7 +60,9 @@ def test_download_private_repo(self): ) as mock_token: mock_token.return_value = "test_token" - download_huggingface_dataset(test_repo, test_filename, test_version) + download_huggingface_dataset( + test_repo, test_filename, test_version + ) mock_download.assert_called_with( repo_id=test_repo, repo_type="model", @@ -88,9 +92,12 @@ def test_download_private_repo_no_token(self): mock_token.return_value = "" with pytest.raises(Exception): - download_huggingface_dataset(test_repo, test_filename, test_version) + download_huggingface_dataset( + test_repo, test_filename, test_version + ) mock_download.assert_not_called() + class TestGetOrPromptHfToken: def test_get_token_from_environment(self): """Test retrieving token when it exists in environment variables""" From 6112925a6846858bcf1a02e498e4305c09620756 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Tue, 10 Dec 2024 17:12:10 +0100 Subject: [PATCH 6/6] fix: Remove unused imports --- policyengine_core/tools/hugging_face.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/policyengine_core/tools/hugging_face.py b/policyengine_core/tools/hugging_face.py index b2c5924c..2c18585a 100644 --- a/policyengine_core/tools/hugging_face.py +++ b/policyengine_core/tools/hugging_face.py @@ -1,8 +1,6 @@ from huggingface_hub import ( hf_hub_download, model_info, - login, - HfApi, ModelInfo, ) from huggingface_hub.errors import RepositoryNotFoundError