From 926c3b536d06b7be101afcd1fa4532b0e3bb5b29 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Thu, 19 Dec 2024 16:01:57 +0100 Subject: [PATCH] fix: Set local directory when downloading datasets from Hugging Face --- changelog_entry.yaml | 4 ++++ policyengine_core/data/dataset.py | 1 + policyengine_core/tools/hugging_face.py | 7 ++++++- 3 files changed, 11 insertions(+), 1 deletion(-) diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29b..531d0bb9 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,4 @@ +- bump: patch + changes: + changed: + - Explicitly set local directory when downloading datasets from Hugging Face \ No newline at end of file diff --git a/policyengine_core/data/dataset.py b/policyengine_core/data/dataset.py index 566a45a8..c6cec09f 100644 --- a/policyengine_core/data/dataset.py +++ b/policyengine_core/data/dataset.py @@ -501,4 +501,5 @@ def download_from_huggingface( repo=f"{owner_name}/{model_name}", repo_filename=file_name, version=version, + local_dir=self.file_path.parent, ) diff --git a/policyengine_core/tools/hugging_face.py b/policyengine_core/tools/hugging_face.py index 2c18585a..d92f4447 100644 --- a/policyengine_core/tools/hugging_face.py +++ b/policyengine_core/tools/hugging_face.py @@ -13,7 +13,10 @@ def download_huggingface_dataset( - repo: str, repo_filename: str, version: str = None + repo: str, + repo_filename: str, + version: str = None, + local_dir: str | None = None, ): """ Download a dataset from the Hugging Face Hub. @@ -22,6 +25,7 @@ def download_huggingface_dataset( repo (str): The Hugging Face repo name, in format "{org}/{repo}". repo_filename (str): The filename of the dataset. version (str, optional): The version of the dataset. Defaults to None. + local_dir (str, optional): The local directory to save the dataset to. Defaults to None. """ # Attempt connection to Hugging Face model_info endpoint # (https://huggingface.co/docs/huggingface_hub/v0.26.5/en/package_reference/hf_api#huggingface_hub.HfApi.model_info) @@ -52,6 +56,7 @@ def download_huggingface_dataset( filename=repo_filename, revision=version, token=authentication_token, + local_dir=local_dir, )