Skip to content

Commit

Permalink
Merge pull request #321 from anth-volk/fix/publicize-huggingface
Browse files Browse the repository at this point in the history
Conditionally pass Hugging Face tokens for download when repo is private
  • Loading branch information
MaxGhenis authored Dec 18, 2024
2 parents c96bedf + 6112925 commit 54b1f18
Show file tree
Hide file tree
Showing 5 changed files with 189 additions and 59 deletions.
6 changes: 6 additions & 0 deletions changelog_entry.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
- bump: minor
changes:
changed:
- Refactor all Hugging Face downloads to use the same download function
- Attempt to determine whether a token is necessary before downloading from Hugging Face
- Add tests for this new functionality
13 changes: 4 additions & 9 deletions policyengine_core/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,13 +497,8 @@ def download_from_huggingface(
file=sys.stderr,
)

token = get_or_prompt_hf_token()

hf_hub_download(
repo_id=f"{owner_name}/{model_name}",
repo_type="model",
filename=file_name,
local_dir=self.file_path.parent,
revision=version,
token=token,
download_huggingface_dataset(
repo=f"{owner_name}/{model_name}",
repo_filename=file_name,
version=version,
)
9 changes: 6 additions & 3 deletions policyengine_core/simulations/simulation.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import tempfile
from typing import TYPE_CHECKING, Any, Dict, List, Type, Union

import numpy
import numpy as np
import pandas as pd
from numpy.typing import ArrayLike
Expand Down Expand Up @@ -159,7 +158,11 @@ def __init__(
filename = filename.split("@")[0]
else:
version = None
dataset = download(owner + "/" + repo, filename, version)
dataset = download_huggingface_dataset(
repo=f"{owner}/{repo}",
repo_filename=filename,
version=version,
)
datasets_by_name = {
dataset.name: dataset for dataset in self.datasets
}
Expand Down Expand Up @@ -1041,7 +1044,7 @@ def _cast_formula_result(self, value: Any, variable: str) -> ArrayLike:
if variable.value_type == Enum and not isinstance(value, EnumArray):
return variable.possible_values.encode(value)

if not isinstance(value, numpy.ndarray):
if not isinstance(value, np.ndarray):
population = self.get_variable_population(variable.name)
value = population.filled_array(value)

Expand Down
46 changes: 40 additions & 6 deletions policyengine_core/tools/hugging_face.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
from huggingface_hub import hf_hub_download, login, HfApi
from huggingface_hub import (
hf_hub_download,
model_info,
ModelInfo,
)
from huggingface_hub.errors import RepositoryNotFoundError
from getpass import getpass
import os
import warnings
Expand All @@ -7,17 +12,46 @@
warnings.simplefilter("ignore")


def download(repo: str, repo_filename: str, version: str = None):
token = os.environ.get(
"HUGGING_FACE_TOKEN",
)
def download_huggingface_dataset(
    repo: str, repo_filename: str, version: str = None
):
    """
    Download a dataset from the Hugging Face Hub.

    The repo's model info is fetched first to decide whether a token is
    needed: a token is only obtained (via ``get_or_prompt_hf_token``) when
    the repo reports itself as private or cannot be found anonymously.

    Args:
        repo (str): The Hugging Face repo name, in format "{org}/{repo}".
        repo_filename (str): The filename of the dataset.
        version (str, optional): The version of the dataset, passed to
            ``hf_hub_download`` as ``revision``. Defaults to None.

    Returns:
        Whatever ``hf_hub_download`` returns for the requested file.

    Raises:
        Exception: If fetching the model info fails for any reason other
            than ``RepositoryNotFoundError``. The original error is chained
            as ``__cause__``.
    """
    # Attempt connection to Hugging Face model_info endpoint
    # (https://huggingface.co/docs/huggingface_hub/v0.26.5/en/package_reference/hf_api#huggingface_hub.HfApi.model_info)
    # to determine if the repo is private.
    # A RepositoryNotFoundError & 401 likely means the repo is private,
    # but this error will also surface for public repos with malformed URL, etc.
    try:
        fetched_model_info: ModelInfo = model_info(repo)
        is_repo_private: bool = fetched_model_info.private
    except RepositoryNotFoundError:
        # If this error type arises, it's likely the repo is private; see docs above
        is_repo_private = True
    except Exception as e:
        # Otherwise, there probably is just a download error; keep the
        # underlying exception attached so the root cause is not lost.
        raise Exception(
            f"Unable to download dataset {repo_filename} from Hugging Face. This may be because the repo "
            + "is private, the URL is malformed, or the dataset does not exist."
        ) from e

    # Only ask for a token when one appears to be required, so that public
    # repos can be downloaded without any credentials configured.
    authentication_token = (
        get_or_prompt_hf_token() if is_repo_private else None
    )

    return hf_hub_download(
        repo_id=repo,
        repo_type="model",
        filename=repo_filename,
        revision=version,
        token=authentication_token,
    )


Expand Down
174 changes: 133 additions & 41 deletions tests/core/tools/test_hugging_face.py
Original file line number Diff line number Diff line change
@@ -1,61 +1,153 @@
import os
import pytest
from unittest.mock import patch
from policyengine_core.tools.hugging_face import get_or_prompt_hf_token
from huggingface_hub import ModelInfo
from huggingface_hub.errors import RepositoryNotFoundError
from policyengine_core.tools.hugging_face import (
get_or_prompt_hf_token,
download_huggingface_dataset,
)


def test_get_token_from_environment():
"""Test retrieving token when it exists in environment variables"""
test_token = "test_token_123"
with patch.dict(
os.environ, {"HUGGING_FACE_TOKEN": test_token}, clear=True
):
result = get_or_prompt_hf_token()
assert result == test_token
class TestHuggingFaceDownload:
    """Tests for ``download_huggingface_dataset`` with all Hub calls mocked."""

    def test_download_public_repo(self):
        """Test downloading from a public repo"""
        test_repo = "test_repo"
        test_filename = "test_filename"
        test_version = "test_version"

        with patch(
            "policyengine_core.tools.hugging_face.hf_hub_download"
        ) as mock_download:
            with patch(
                "policyengine_core.tools.hugging_face.model_info"
            ) as mock_model_info:
                # Create mock ModelInfo object emulating public repo
                test_id = 0
                mock_model_info.return_value = ModelInfo(
                    id=test_id, private=False
                )

                download_huggingface_dataset(
                    test_repo, test_filename, test_version
                )

                # A public repo must be downloaded without a token
                mock_download.assert_called_with(
                    repo_id=test_repo,
                    repo_type="model",
                    filename=test_filename,
                    revision=test_version,
                    token=None,
                )

    def test_download_private_repo(self):
        """Test downloading from a private repo"""
        test_repo = "test_repo"
        test_filename = "test_filename"
        test_version = "test_version"

        with patch(
            "policyengine_core.tools.hugging_face.hf_hub_download"
        ) as mock_download:
            with patch(
                "policyengine_core.tools.hugging_face.model_info"
            ) as mock_model_info:
                # RepositoryNotFoundError is what the function under test
                # treats as "repo is private"
                mock_model_info.side_effect = RepositoryNotFoundError(
                    "Test error"
                )
                with patch(
                    "policyengine_core.tools.hugging_face.get_or_prompt_hf_token"
                ) as mock_token:
                    mock_token.return_value = "test_token"

                    download_huggingface_dataset(
                        test_repo, test_filename, test_version
                    )
                    mock_download.assert_called_with(
                        repo_id=test_repo,
                        repo_type="model",
                        filename=test_filename,
                        revision=test_version,
                        token=mock_token.return_value,
                    )

    def test_download_private_repo_no_token(self):
        """Test handling of private repo with no token"""
        test_repo = "test_repo"
        test_filename = "test_filename"
        test_version = "test_version"

        with patch(
            "policyengine_core.tools.hugging_face.hf_hub_download"
        ) as mock_download:
            with patch(
                "policyengine_core.tools.hugging_face.model_info"
            ) as mock_model_info:
                mock_model_info.side_effect = RepositoryNotFoundError(
                    "Test error"
                )
                with patch(
                    "policyengine_core.tools.hugging_face.get_or_prompt_hf_token"
                ) as mock_token:
                    mock_token.return_value = ""

                    # NOTE(review): with hf_hub_download mocked, the download
                    # call does not appear to raise on an empty token, so the
                    # AssertionError from assert_not_called() is presumably
                    # what satisfies pytest.raises here -- confirm that this
                    # is the intended check.
                    with pytest.raises(Exception):
                        download_huggingface_dataset(
                            test_repo, test_filename, test_version
                        )
                        mock_download.assert_not_called()


class TestGetOrPromptHfToken:
    """Tests for ``get_or_prompt_hf_token`` using a patched environment."""

    def test_get_token_from_environment(self):
        """Test retrieving token when it exists in environment variables"""
        test_token = "test_token_123"
        with patch.dict(
            os.environ, {"HUGGING_FACE_TOKEN": test_token}, clear=True
        ):
            result = get_or_prompt_hf_token()
            assert result == test_token

    def test_get_token_from_user_input(self):
        """Test retrieving token via user input when not in environment"""
        test_token = "user_input_token_456"

        # Mock both empty environment and user input
        with patch.dict(os.environ, {}, clear=True):
            with patch(
                "policyengine_core.tools.hugging_face.getpass",
                return_value=test_token,
            ):
                result = get_or_prompt_hf_token()
                assert result == test_token

                # Verify token was stored in environment
                assert os.environ.get("HUGGING_FACE_TOKEN") == test_token

    def test_empty_user_input(self):
        """Test handling of empty user input"""
        with patch.dict(os.environ, {}, clear=True):
            with patch(
                "policyengine_core.tools.hugging_face.getpass", return_value=""
            ):
                result = get_or_prompt_hf_token()
                assert result == ""
                assert os.environ.get("HUGGING_FACE_TOKEN") == ""

    def test_environment_variable_persistence(self):
        """Test that environment variable persists across multiple calls"""
        test_token = "persistence_test_token"

        # First call with no environment variable
        with patch.dict(os.environ, {}, clear=True):
            with patch(
                "policyengine_core.tools.hugging_face.getpass",
                return_value=test_token,
            ):
                first_result = get_or_prompt_hf_token()

            # Second call should use environment variable; it runs after the
            # getpass patch has been released but while the patched
            # environment is still active.
            second_result = get_or_prompt_hf_token()

            assert first_result == second_result == test_token
            assert os.environ.get("HUGGING_FACE_TOKEN") == test_token

0 comments on commit 54b1f18

Please sign in to comment.