Skip to content

Commit

Permalink
Merge pull request #316 from anth-volk/fix/refactor-dataset
Browse files Browse the repository at this point in the history
Fix HuggingFace uploads and downloads
  • Loading branch information
anth-volk authored Dec 2, 2024
2 parents 7410356 + f728390 commit 155b01a
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 18 deletions.
6 changes: 6 additions & 0 deletions changelog_entry.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
- bump: patch
changes:
changed:
- Replaced coexistent standard/HuggingFace URL with standalone URL parameter
- Fixed bugs in download_from_huggingface() method
- Created utility function for pulling HuggingFace env var
52 changes: 34 additions & 18 deletions policyengine_core/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import os
import tempfile
from policyengine_core.tools.hugging_face import *
import sys


def atomic_write(file: Path, content: bytes) -> None:
Expand Down Expand Up @@ -54,8 +55,6 @@ class Dataset:
"""The time period of the dataset. This is used to automatically enter the values in the correct time period if the data type is `Dataset.ARRAYS`."""
url: str = None
"""The URL to download the dataset from. This is used to download the dataset if it does not exist."""
huggingface_url: str = None
"""The HuggingFace URL to download the dataset from. This is used to download the dataset if it does not exist."""

# Data formats
TABLES = "tables"
Expand Down Expand Up @@ -317,7 +316,7 @@ def download(self, url: str = None, version: str = None) -> None:
"""

if url is None:
url = self.url or self.huggingface_url
url = self.url

if "POLICYENGINE_GITHUB_MICRODATA_AUTH_TOKEN" not in os.environ:
auth_headers = {}
Expand Down Expand Up @@ -349,8 +348,10 @@ def download(self, url: str = None, version: str = None) -> None:
f"File {file_path} not found in release {release_tag} of {org}/{repo}."
)
elif url.startswith("hf://"):
owner_name, model_name = url.split("/")[2:]
self.download_from_huggingface(owner_name, model_name, version)
owner_name, model_name, file_name = url.split("/")[2:]
self.download_from_huggingface(
owner_name, model_name, file_name, version
)
return
else:
url = url
Expand All @@ -377,11 +378,11 @@ def upload(self, url: str = None):
url (str): The url to upload.
"""
if url is None:
url = self.huggingface_url or self.url
url = self.url

if url.startswith("hf://"):
owner_name, model_name = url.split("/")[2:]
self.upload_to_huggingface(owner_name, model_name)
owner_name, model_name, file_name = url.split("/")[2:]
self.upload_to_huggingface(owner_name, model_name, file_name)

def remove(self):
"""Removes the dataset from disk."""
Expand Down Expand Up @@ -451,43 +452,58 @@ def from_dataframe(dataframe: pd.DataFrame, time_period: str = None):

return dataset

def upload_to_huggingface(self, owner_name: str, model_name: str):
"""Uploads the dataset to Hugging Face.
def upload_to_huggingface(
self, owner_name: str, model_name: str, file_name: str
):
"""Uploads the dataset to HuggingFace.
Args:
owner_name (str): The owner name.
model_name (str): The model name.
"""
token = os.environ.get(
"HUGGING_FACE_TOKEN",

print(
f"Uploading to HuggingFace {owner_name}/{model_name}/{file_name}",
file=sys.stderr,
)

token = get_or_prompt_hf_token()
api = HfApi()

api.upload_file(
path_or_fileobj=self.file_path,
path_in_repo=self.file_path.name,
path_in_repo=file_name,
repo_id=f"{owner_name}/{model_name}",
repo_type="model",
token=token,
)

def download_from_huggingface(
self, owner_name: str, model_name: str, version: str = None
self,
owner_name: str,
model_name: str,
file_name: str,
version: str = None,
):
"""Downloads the dataset from Hugging Face.
"""Downloads the dataset from HuggingFace.
Args:
owner_name (str): The owner name.
model_name (str): The model name.
"""
token = os.environ.get(
"HUGGING_FACE_TOKEN",

print(
f"Downloading from HuggingFace {owner_name}/{model_name}/{file_name}",
file=sys.stderr,
)

token = get_or_prompt_hf_token()

hf_hub_download(
repo_id=f"{owner_name}/{model_name}",
repo_type="model",
path=self.file_path,
filename=file_name,
local_dir=self.file_path.parent,
revision=version,
token=token,
)
21 changes: 21 additions & 0 deletions policyengine_core/tools/hugging_face.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from huggingface_hub import hf_hub_download, login, HfApi
from getpass import getpass
import os
import warnings

Expand All @@ -18,3 +19,23 @@ def download(repo: str, repo_filename: str, version: str = None):
revision=version,
token=token,
)


def get_or_prompt_hf_token() -> str:
    """
    Return the Hugging Face token, asking the user for it when it is unset.

    The ``HUGGING_FACE_TOKEN`` environment variable is consulted first. If it
    is absent, the token is read interactively via ``getpass`` and written
    back to ``os.environ`` so later calls in the same process do not prompt
    again.

    Returns:
        str: The Hugging Face token (whatever the variable holds or the user
        typed, including an empty string).
    """
    env_token = os.environ.get("HUGGING_FACE_TOKEN")
    if env_token is not None:
        # Already configured: no interaction needed.
        return env_token

    entered_token = getpass(
        "Enter your Hugging Face token (or set HUGGING_FACE_TOKEN environment variable): "
    )
    # Remember the answer for subsequent calls in this session.
    os.environ["HUGGING_FACE_TOKEN"] = entered_token
    return entered_token
61 changes: 61 additions & 0 deletions tests/core/tools/test_hugging_face.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import os
import pytest
from unittest.mock import patch
from policyengine_core.tools.hugging_face import get_or_prompt_hf_token


def test_get_token_from_environment():
    """Test retrieving token when it exists in environment variables.

    ``getpass`` is patched as well, so a regression that prompts even though
    the variable is set fails an assertion rather than depending on real
    terminal input.
    """
    test_token = "test_token_123"
    with patch.dict(
        os.environ, {"HUGGING_FACE_TOKEN": test_token}, clear=True
    ):
        with patch(
            "policyengine_core.tools.hugging_face.getpass"
        ) as mock_getpass:
            result = get_or_prompt_hf_token()
            assert result == test_token
            # The environment value must short-circuit the interactive prompt.
            mock_getpass.assert_not_called()


def test_get_token_from_user_input():
    """Check that a token typed at the prompt is returned and cached."""
    typed_token = "user_input_token_456"

    # Start from an empty environment and answer the prompt with typed_token.
    empty_env = patch.dict(os.environ, {}, clear=True)
    fake_prompt = patch(
        "policyengine_core.tools.hugging_face.getpass",
        return_value=typed_token,
    )
    with empty_env, fake_prompt:
        returned = get_or_prompt_hf_token()
        assert returned == typed_token

        # The prompted value must have been written back to the environment.
        stored = os.environ.get("HUGGING_FACE_TOKEN")
        assert stored == typed_token


def test_empty_user_input():
    """Check that an empty answer at the prompt is returned and stored as-is."""
    empty_env = patch.dict(os.environ, {}, clear=True)
    blank_prompt = patch(
        "policyengine_core.tools.hugging_face.getpass", return_value=""
    )
    with empty_env, blank_prompt:
        returned = get_or_prompt_hf_token()
        assert returned == ""
        # Even an empty string is cached in the environment.
        stored = os.environ.get("HUGGING_FACE_TOKEN")
        assert stored == ""


def test_environment_variable_persistence():
    """Test that environment variable persists across multiple calls.

    The first call has to prompt (the environment starts empty); the second
    call must be served from the environment. Because the patched ``getpass``
    would return the same token on every call, equality of the results alone
    cannot tell the two paths apart, so the prompt's call count is asserted
    explicitly.
    """
    test_token = "persistence_test_token"

    # First call with no environment variable
    with patch.dict(os.environ, {}, clear=True):
        with patch(
            "policyengine_core.tools.hugging_face.getpass",
            return_value=test_token,
        ) as mock_getpass:
            first_result = get_or_prompt_hf_token()

            # Second call should use environment variable
            second_result = get_or_prompt_hf_token()

            assert first_result == second_result == test_token
            assert os.environ.get("HUGGING_FACE_TOKEN") == test_token
            # Only the first call may have reached the interactive prompt.
            mock_getpass.assert_called_once()

0 comments on commit 155b01a

Please sign in to comment.