Skip to content

Commit

Permalink
Merge pull request #325 from anth-volk/fix/huggingface-download-error
Browse files Browse the repository at this point in the history
Set local directory when downloading datasets from Hugging Face
  • Loading branch information
anth-volk authored Dec 20, 2024
2 parents cd0f149 + 0b8dc71 commit d4e1ebe
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 4 deletions.
4 changes: 4 additions & 0 deletions changelog_entry.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
- bump: patch
changes:
changed:
- Explicitly set local directory when downloading datasets from Hugging Face
1 change: 1 addition & 0 deletions policyengine_core/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,4 +501,5 @@ def download_from_huggingface(
repo=f"{owner_name}/{model_name}",
repo_filename=file_name,
version=version,
local_dir=self.file_path.parent,
)
7 changes: 6 additions & 1 deletion policyengine_core/tools/hugging_face.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@


def download_huggingface_dataset(
repo: str, repo_filename: str, version: str = None
repo: str,
repo_filename: str,
version: str = None,
local_dir: str | None = None,
):
"""
Download a dataset from the Hugging Face Hub.
Expand All @@ -22,6 +25,7 @@ def download_huggingface_dataset(
repo (str): The Hugging Face repo name, in format "{org}/{repo}".
repo_filename (str): The filename of the dataset.
version (str, optional): The version of the dataset. Defaults to None.
local_dir (str, optional): The local directory to save the dataset to. Defaults to None.
"""
# Attempt connection to Hugging Face model_info endpoint
# (https://huggingface.co/docs/huggingface_hub/v0.26.5/en/package_reference/hf_api#huggingface_hub.HfApi.model_info)
Expand Down Expand Up @@ -52,6 +56,7 @@ def download_huggingface_dataset(
filename=repo_filename,
revision=version,
token=authentication_token,
local_dir=local_dir,
)


Expand Down
11 changes: 8 additions & 3 deletions tests/core/tools/test_hugging_face.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def test_download_public_repo(self):
test_repo = "test_repo"
test_filename = "test_filename"
test_version = "test_version"
test_dir = "test_dir"

with patch(
"policyengine_core.tools.hugging_face.hf_hub_download"
Expand All @@ -29,14 +30,15 @@ def test_download_public_repo(self):
)

download_huggingface_dataset(
test_repo, test_filename, test_version
test_repo, test_filename, test_version, test_dir
)

mock_download.assert_called_with(
repo_id=test_repo,
repo_type="model",
filename=test_filename,
revision=test_version,
local_dir=test_dir,
token=None,
)

Expand All @@ -45,6 +47,7 @@ def test_download_private_repo(self):
test_repo = "test_repo"
test_filename = "test_filename"
test_version = "test_version"
test_dir = "test_dir"

with patch(
"policyengine_core.tools.hugging_face.hf_hub_download"
Expand All @@ -61,21 +64,23 @@ def test_download_private_repo(self):
mock_token.return_value = "test_token"

download_huggingface_dataset(
test_repo, test_filename, test_version
test_repo, test_filename, test_version, test_dir
)
mock_download.assert_called_with(
repo_id=test_repo,
repo_type="model",
filename=test_filename,
revision=test_version,
token=mock_token.return_value,
local_dir=test_dir,
)

def test_download_private_repo_no_token(self):
"""Test handling of private repo with no token"""
test_repo = "test_repo"
test_filename = "test_filename"
test_version = "test_version"
test_dir = "test_dir"

with patch(
"policyengine_core.tools.hugging_face.hf_hub_download"
Expand All @@ -93,7 +98,7 @@ def test_download_private_repo_no_token(self):

with pytest.raises(Exception):
download_huggingface_dataset(
test_repo, test_filename, test_version
test_repo, test_filename, test_version, test_dir
)
mock_download.assert_not_called()

Expand Down

0 comments on commit d4e1ebe

Please sign in to comment.