Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix HuggingFace uploads and downloads #316

Merged
merged 2 commits into from
Dec 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions changelog_entry.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
- bump: patch
changes:
changed:
- Replaced coexistent standard/HuggingFace URL with standalone URL parameter
- Fixed bugs in download_from_huggingface() method
- Created utility function for pulling HuggingFace env var
52 changes: 34 additions & 18 deletions policyengine_core/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import os
import tempfile
from policyengine_core.tools.hugging_face import *
import sys


def atomic_write(file: Path, content: bytes) -> None:
Expand Down Expand Up @@ -54,8 +55,6 @@ class Dataset:
"""The time period of the dataset. This is used to automatically enter the values in the correct time period if the data type is `Dataset.ARRAYS`."""
url: str = None
"""The URL to download the dataset from. This is used to download the dataset if it does not exist."""
huggingface_url: str = None
"""The HuggingFace URL to download the dataset from. This is used to download the dataset if it does not exist."""

# Data formats
TABLES = "tables"
Expand Down Expand Up @@ -317,7 +316,7 @@ def download(self, url: str = None, version: str = None) -> None:
"""

if url is None:
url = self.url or self.huggingface_url
url = self.url

if "POLICYENGINE_GITHUB_MICRODATA_AUTH_TOKEN" not in os.environ:
auth_headers = {}
Expand Down Expand Up @@ -349,8 +348,10 @@ def download(self, url: str = None, version: str = None) -> None:
f"File {file_path} not found in release {release_tag} of {org}/{repo}."
)
elif url.startswith("hf://"):
owner_name, model_name = url.split("/")[2:]
self.download_from_huggingface(owner_name, model_name, version)
owner_name, model_name, file_name = url.split("/")[2:]
self.download_from_huggingface(
owner_name, model_name, file_name, version
)
return
else:
url = url
Expand All @@ -377,11 +378,11 @@ def upload(self, url: str = None):
url (str): The url to upload.
"""
if url is None:
url = self.huggingface_url or self.url
url = self.url

if url.startswith("hf://"):
owner_name, model_name = url.split("/")[2:]
self.upload_to_huggingface(owner_name, model_name)
owner_name, model_name, file_name = url.split("/")[2:]
self.upload_to_huggingface(owner_name, model_name, file_name)

def remove(self):
"""Removes the dataset from disk."""
Expand Down Expand Up @@ -451,43 +452,58 @@ def from_dataframe(dataframe: pd.DataFrame, time_period: str = None):

return dataset

def upload_to_huggingface(self, owner_name: str, model_name: str):
"""Uploads the dataset to Hugging Face.
def upload_to_huggingface(
self, owner_name: str, model_name: str, file_name: str
):
"""Uploads the dataset to HuggingFace.

Args:
owner_name (str): The owner name.
model_name (str): The model name.
"""
token = os.environ.get(
"HUGGING_FACE_TOKEN",

print(
f"Uploading to HuggingFace {owner_name}/{model_name}/{file_name}",
file=sys.stderr,
)

token = get_or_prompt_hf_token()
api = HfApi()

api.upload_file(
path_or_fileobj=self.file_path,
path_in_repo=self.file_path.name,
path_in_repo=file_name,
repo_id=f"{owner_name}/{model_name}",
repo_type="model",
token=token,
)

def download_from_huggingface(
self, owner_name: str, model_name: str, version: str = None
self,
owner_name: str,
model_name: str,
file_name: str,
version: str = None,
):
"""Downloads the dataset from Hugging Face.
"""Downloads the dataset from HuggingFace.

Args:
owner_name (str): The owner name.
model_name (str): The model name.
"""
token = os.environ.get(
"HUGGING_FACE_TOKEN",

print(
f"Downloading from HuggingFace {owner_name}/{model_name}/{file_name}",
file=sys.stderr,
)

token = get_or_prompt_hf_token()

hf_hub_download(
repo_id=f"{owner_name}/{model_name}",
repo_type="model",
path=self.file_path,
filename=file_name,
local_dir=self.file_path.parent,
revision=version,
token=token,
)
21 changes: 21 additions & 0 deletions policyengine_core/tools/hugging_face.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from huggingface_hub import hf_hub_download, login, HfApi
from getpass import getpass
import os
import warnings

Expand All @@ -18,3 +19,23 @@ def download(repo: str, repo_filename: str, version: str = None):
revision=version,
token=token,
)


def get_or_prompt_hf_token() -> str:
    """Return the Hugging Face token, prompting for it if it is not set.

    The token is read from the ``HUGGING_FACE_TOKEN`` environment variable.
    When that variable is absent, the user is prompted with ``getpass`` and
    whatever is entered (including an empty string) is written back to
    ``os.environ`` under the same name.

    Returns:
        str: The Hugging Face token.
    """
    token = os.environ.get("HUGGING_FACE_TOKEN")
    if token is None:
        token = getpass(
            "Enter your Hugging Face token (or set HUGGING_FACE_TOKEN environment variable): "
        )
        # Cache the answer in the process environment so that later calls
        # in the same session find it there and do not prompt again.
        os.environ["HUGGING_FACE_TOKEN"] = token

    return token
61 changes: 61 additions & 0 deletions tests/core/tools/test_hugging_face.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import os
import pytest
from unittest.mock import patch
from policyengine_core.tools.hugging_face import get_or_prompt_hf_token


def test_get_token_from_environment():
    """Test retrieving token when it exists in environment variables"""
    test_token = "test_token_123"
    # patch.dict with clear=True gives the call an environment containing
    # only the token, and restores the real environment on exit.
    with patch.dict(
        os.environ, {"HUGGING_FACE_TOKEN": test_token}, clear=True
    ):
        result = get_or_prompt_hf_token()
        assert result == test_token


def test_get_token_from_user_input():
    """Test retrieving token via user input when not in environment"""
    test_token = "user_input_token_456"

    # Mock both empty environment and user input
    with patch.dict(os.environ, {}, clear=True):
        with patch(
            "policyengine_core.tools.hugging_face.getpass",
            return_value=test_token,
        ):
            result = get_or_prompt_hf_token()
            assert result == test_token

            # Verify token was stored in environment. This must run while
            # patch.dict is still active, because leaving that context
            # restores the original environment.
            assert os.environ.get("HUGGING_FACE_TOKEN") == test_token


def test_empty_user_input():
    """Test handling of empty user input"""
    with patch.dict(os.environ, {}, clear=True):
        with patch(
            "policyengine_core.tools.hugging_face.getpass", return_value=""
        ):
            result = get_or_prompt_hf_token()
            # An empty answer is returned as-is and is also what gets
            # stored in the environment.
            assert result == ""
            assert os.environ.get("HUGGING_FACE_TOKEN") == ""


def test_environment_variable_persistence():
    """Test that environment variable persists across multiple calls"""
    test_token = "persistence_test_token"

    # First call with no environment variable
    with patch.dict(os.environ, {}, clear=True):
        with patch(
            "policyengine_core.tools.hugging_face.getpass",
            return_value=test_token,
        ):
            first_result = get_or_prompt_hf_token()

        # Second call should use environment variable. It runs after the
        # getpass patch has been removed but while patch.dict is still
        # active, so the value stored by the first call is still in
        # os.environ.
        second_result = get_or_prompt_hf_token()

        assert first_result == second_result == test_token
        assert os.environ.get("HUGGING_FACE_TOKEN") == test_token