diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29b..d8cfc7df 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,6 @@ +- bump: patch + changes: + changed: + - Replaced coexistent standard/HuggingFace URL with standalone URL parameter + - Fixed bugs in download_from_huggingface() method + - Create utility function for pulling HuggingFace env var \ No newline at end of file diff --git a/policyengine_core/data/dataset.py b/policyengine_core/data/dataset.py index d648417e..235884c6 100644 --- a/policyengine_core/data/dataset.py +++ b/policyengine_core/data/dataset.py @@ -8,6 +8,7 @@ import os import tempfile from policyengine_core.tools.hugging_face import * +import sys def atomic_write(file: Path, content: bytes) -> None: @@ -54,8 +55,6 @@ class Dataset: """The time period of the dataset. This is used to automatically enter the values in the correct time period if the data type is `Dataset.ARRAYS`.""" url: str = None """The URL to download the dataset from. This is used to download the dataset if it does not exist.""" - huggingface_url: str = None - """The HuggingFace URL to download the dataset from. This is used to download the dataset if it does not exist.""" # Data formats TABLES = "tables" @@ -317,7 +316,7 @@ def download(self, url: str = None, version: str = None) -> None: """ if url is None: - url = self.url or self.huggingface_url + url = self.url if "POLICYENGINE_GITHUB_MICRODATA_AUTH_TOKEN" not in os.environ: auth_headers = {} @@ -349,8 +348,10 @@ def download(self, url: str = None, version: str = None) -> None: f"File {file_path} not found in release {release_tag} of {org}/{repo}." ) elif url.startswith("hf://"): - owner_name, model_name = url.split("/")[2:] - self.download_from_huggingface(owner_name, model_name, version) + owner_name, model_name, file_name = url.split("/")[2:] + self.download_from_huggingface( + owner_name, model_name, file_name, version + ) return else: url = url @@ -377,11 +378,11 @@ def upload(self, url: str = None): url (str): The url to upload. """ if url is None: - url = self.huggingface_url or self.url + url = self.url if url.startswith("hf://"): - owner_name, model_name = url.split("/")[2:] - self.upload_to_huggingface(owner_name, model_name) + owner_name, model_name, file_name = url.split("/")[2:] + self.upload_to_huggingface(owner_name, model_name, file_name) def remove(self): """Removes the dataset from disk.""" @@ -451,43 +452,58 @@ def from_dataframe(dataframe: pd.DataFrame, time_period: str = None): return dataset - def upload_to_huggingface(self, owner_name: str, model_name: str): - """Uploads the dataset to Hugging Face. + def upload_to_huggingface( + self, owner_name: str, model_name: str, file_name: str + ): + """Uploads the dataset to HuggingFace. Args: owner_name (str): The owner name. model_name (str): The model name. """ - token = os.environ.get( - "HUGGING_FACE_TOKEN", + + print( + f"Uploading to HuggingFace {owner_name}/{model_name}/{file_name}", + file=sys.stderr, ) + + token = get_or_prompt_hf_token() api = HfApi() api.upload_file( path_or_fileobj=self.file_path, - path_in_repo=self.file_path.name, + path_in_repo=file_name, repo_id=f"{owner_name}/{model_name}", repo_type="model", token=token, ) def download_from_huggingface( - self, owner_name: str, model_name: str, version: str = None + self, + owner_name: str, + model_name: str, + file_name: str, + version: str = None, ): - """Downloads the dataset from Hugging Face. + """Downloads the dataset from HuggingFace. Args: owner_name (str): The owner name. model_name (str): The model name. """ - token = os.environ.get( - "HUGGING_FACE_TOKEN", + + print( + f"Downloading from HuggingFace {owner_name}/{model_name}/{file_name}", + file=sys.stderr, ) + token = get_or_prompt_hf_token() + hf_hub_download( repo_id=f"{owner_name}/{model_name}", repo_type="model", - path=self.file_path, + filename=file_name, + local_dir=self.file_path.parent, revision=version, token=token, ) diff --git a/policyengine_core/tools/hugging_face.py b/policyengine_core/tools/hugging_face.py index b43df592..a5433e7f 100644 --- a/policyengine_core/tools/hugging_face.py +++ b/policyengine_core/tools/hugging_face.py @@ -1,4 +1,5 @@ from huggingface_hub import hf_hub_download, login, HfApi +from getpass import getpass import os import warnings @@ -18,3 +19,23 @@ def download(repo: str, repo_filename: str, version: str = None): revision=version, token=token, ) + + +def get_or_prompt_hf_token() -> str: + """ + Either get the Hugging Face token from the environment, + or prompt the user for it and store it in the environment. + + Returns: + str: The Hugging Face token. + """ + + token = os.environ.get("HUGGING_FACE_TOKEN") + if token is None: + token = getpass( + "Enter your Hugging Face token (or set HUGGING_FACE_TOKEN environment variable): " + ) + # Optionally store in env for subsequent calls in same session + os.environ["HUGGING_FACE_TOKEN"] = token + + return token diff --git a/tests/core/tools/test_hugging_face.py b/tests/core/tools/test_hugging_face.py new file mode 100644 index 00000000..acaceb30 --- /dev/null +++ b/tests/core/tools/test_hugging_face.py @@ -0,0 +1,61 @@ +import os +import pytest +from unittest.mock import patch +from policyengine_core.tools.hugging_face import get_or_prompt_hf_token + + +def test_get_token_from_environment(): + """Test retrieving token when it exists in environment variables""" + test_token = "test_token_123" + with patch.dict( + os.environ, {"HUGGING_FACE_TOKEN": test_token}, clear=True + ): + result = get_or_prompt_hf_token() + assert result == test_token + + +def test_get_token_from_user_input(): + """Test retrieving token via user input when not in environment""" + test_token = "user_input_token_456" + + # Mock both empty environment and user input + with patch.dict(os.environ, {}, clear=True): + with patch( + "policyengine_core.tools.hugging_face.getpass", + return_value=test_token, + ): + result = get_or_prompt_hf_token() + assert result == test_token + + # Verify token was stored in environment + assert os.environ.get("HUGGING_FACE_TOKEN") == test_token + + +def test_empty_user_input(): + """Test handling of empty user input""" + with patch.dict(os.environ, {}, clear=True): + with patch( + "policyengine_core.tools.hugging_face.getpass", return_value="" + ): + result = get_or_prompt_hf_token() + assert result == "" + assert os.environ.get("HUGGING_FACE_TOKEN") == "" + + +def test_environment_variable_persistence(): + """Test that environment variable persists across multiple calls""" + test_token = "persistence_test_token" + + # First call with no environment variable + with patch.dict(os.environ, {}, clear=True): + with patch( + "policyengine_core.tools.hugging_face.getpass", + return_value=test_token, + ): + first_result = get_or_prompt_hf_token() + + # Second call should use environment variable + second_result = get_or_prompt_hf_token() + + assert first_result == second_result == test_token + assert os.environ.get("HUGGING_FACE_TOKEN") == test_token