From 926c3b536d06b7be101afcd1fa4532b0e3bb5b29 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Thu, 19 Dec 2024 16:01:57 +0100 Subject: [PATCH 1/2] fix: Set local directory when downloading datasets from Hugging Face --- changelog_entry.yaml | 4 ++++ policyengine_core/data/dataset.py | 1 + policyengine_core/tools/hugging_face.py | 7 ++++++- 3 files changed, 11 insertions(+), 1 deletion(-) diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29b..531d0bb9 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,4 @@ +- bump: patch + changes: + changed: + - Explicitly set local directory when downloading datasets from Hugging Face \ No newline at end of file diff --git a/policyengine_core/data/dataset.py b/policyengine_core/data/dataset.py index 566a45a8..c6cec09f 100644 --- a/policyengine_core/data/dataset.py +++ b/policyengine_core/data/dataset.py @@ -501,4 +501,5 @@ def download_from_huggingface( repo=f"{owner_name}/{model_name}", repo_filename=file_name, version=version, + local_dir=self.file_path.parent, ) diff --git a/policyengine_core/tools/hugging_face.py b/policyengine_core/tools/hugging_face.py index 2c18585a..d92f4447 100644 --- a/policyengine_core/tools/hugging_face.py +++ b/policyengine_core/tools/hugging_face.py @@ -13,7 +13,10 @@ def download_huggingface_dataset( - repo: str, repo_filename: str, version: str = None + repo: str, + repo_filename: str, + version: str = None, + local_dir: str | None = None, ): """ Download a dataset from the Hugging Face Hub. @@ -22,6 +25,7 @@ def download_huggingface_dataset( repo (str): The Hugging Face repo name, in format "{org}/{repo}". repo_filename (str): The filename of the dataset. version (str, optional): The version of the dataset. Defaults to None. + local_dir (str, optional): The local directory to save the dataset to. Defaults to None. """ # Attempt connection to Hugging Face model_info endpoint # (https://huggingface.co/docs/huggingface_hub/v0.26.5/en/package_reference/hf_api#huggingface_hub.HfApi.model_info) @@ -52,6 +56,7 @@ def download_huggingface_dataset( filename=repo_filename, revision=version, token=authentication_token, + local_dir=local_dir, ) From 0b8dc7174290290111f699c5295d4a4e897eba0b Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Thu, 19 Dec 2024 20:59:25 +0100 Subject: [PATCH 2/2] fix: Update tests --- tests/core/tools/test_hugging_face.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/tests/core/tools/test_hugging_face.py b/tests/core/tools/test_hugging_face.py index 813a1278..54f04bd8 100644 --- a/tests/core/tools/test_hugging_face.py +++ b/tests/core/tools/test_hugging_face.py @@ -15,6 +15,7 @@ def test_download_public_repo(self): test_repo = "test_repo" test_filename = "test_filename" test_version = "test_version" + test_dir = "test_dir" with patch( "policyengine_core.tools.hugging_face.hf_hub_download" @@ -29,7 +30,7 @@ def test_download_public_repo(self): ) download_huggingface_dataset( - test_repo, test_filename, test_version + test_repo, test_filename, test_version, test_dir ) mock_download.assert_called_with( @@ -37,6 +38,7 @@ def test_download_public_repo(self): repo_type="model", filename=test_filename, revision=test_version, + local_dir=test_dir, token=None, ) @@ -45,6 +47,7 @@ def test_download_private_repo(self): test_repo = "test_repo" test_filename = "test_filename" test_version = "test_version" + test_dir = "test_dir" with patch( "policyengine_core.tools.hugging_face.hf_hub_download" @@ -61,7 +64,7 @@ def test_download_private_repo(self): mock_token.return_value = "test_token" download_huggingface_dataset( - test_repo, test_filename, test_version + test_repo, test_filename, test_version, test_dir ) mock_download.assert_called_with( repo_id=test_repo, @@ -69,6 +72,7 @@ def test_download_private_repo(self): filename=test_filename, revision=test_version, token=mock_token.return_value, + local_dir=test_dir, ) def test_download_private_repo_no_token(self): @@ -76,6 +80,7 @@ def test_download_private_repo_no_token(self): test_repo = "test_repo" test_filename = "test_filename" test_version = "test_version" + test_dir = "test_dir" with patch( "policyengine_core.tools.hugging_face.hf_hub_download" @@ -93,7 +98,7 @@ def test_download_private_repo_no_token(self): with pytest.raises(Exception): download_huggingface_dataset( - test_repo, test_filename, test_version + test_repo, test_filename, test_version, test_dir ) mock_download.assert_not_called()