Skip to content

Commit

Permalink
Add huggingface URLs
Browse files Browse the repository at this point in the history
Fixes #309
  • Loading branch information
nikhilwoodruff committed Nov 27, 2024
1 parent 9fbe198 commit 42f6ce2
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 3 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [3.13.0] - 2024-11-27 13:12:44

### Added

- HuggingFace upload/download functionality.

## [3.12.5] - 2024-11-20 13:13:13

### Changed
Expand Down Expand Up @@ -932,6 +938,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0



[3.13.0]: https://github.com/PolicyEngine/policyengine-core/compare/3.12.5...3.13.0
[3.12.5]: https://github.com/PolicyEngine/policyengine-core/compare/3.12.4...3.12.5
[3.12.4]: https://github.com/PolicyEngine/policyengine-core/compare/3.12.3...3.12.4
[3.12.3]: https://github.com/PolicyEngine/policyengine-core/compare/3.12.2...3.12.3
Expand Down
5 changes: 5 additions & 0 deletions changelog.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -755,3 +755,8 @@
- update the furo requirment to <2025
- update the markupsafe requirement to <3
date: 2024-11-20 13:13:13
- bump: minor
changes:
added:
- HuggingFace upload/download functionality.
date: 2024-11-27 13:12:44
81 changes: 79 additions & 2 deletions policyengine_core/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import requests
import os
import tempfile
from huggingface_hub import HfApi, login, hf_hub_download
import pkg_resources


def atomic_write(file: Path, content: bytes) -> None:
Expand Down Expand Up @@ -53,6 +55,8 @@ class Dataset:
"""The time period of the dataset. This is used to automatically enter the values in the correct time period if the data type is `Dataset.ARRAYS`."""
url: str = None
"""The URL to download the dataset from. This is used to download the dataset if it does not exist."""
huggingface_url: str = None
"""The HuggingFace URL to download the dataset from. This is used to download the dataset if it does not exist."""

# Data formats
TABLES = "tables"
Expand Down Expand Up @@ -306,15 +310,15 @@ def store_file(self, file_path: str):
raise FileNotFoundError(f"File {file_path} does not exist.")
shutil.move(file_path, self.file_path)

def download(self, url: str = None) -> None:
def download(self, url: str = None, version: str = None) -> None:
"""Downloads a file to the dataset's file path.
Args:
url (str): The url to download.
"""

if url is None:
url = self.url
url = self.huggingface_url or self.url

if "POLICYENGINE_GITHUB_MICRODATA_AUTH_TOKEN" not in os.environ:
auth_headers = {}
Expand Down Expand Up @@ -345,6 +349,10 @@ def download(self, url: str = None) -> None:
raise ValueError(
f"File {file_path} not found in release {release_tag} of {org}/{repo}."
)
elif url.startswith("hf://"):
owner_name, model_name = url.split("/")[2:]
self.download_from_huggingface(owner_name, model_name, version)
return
else:
url = url

Expand All @@ -363,6 +371,19 @@ def download(self, url: str = None) -> None:

atomic_write(self.file_path, response.content)

def upload(self, url: str = None):
"""Uploads the dataset to a URL.
Args:
url (str): The url to upload.
"""
if url is None:
url = self.huggingface_url or self.url

if url.startswith("hf://"):
owner_name, model_name = url.split("/")[2:]
self.upload_to_huggingface(owner_name, model_name)

def remove(self):
"""Removes the dataset from disk."""
if self.exists:
Expand Down Expand Up @@ -414,3 +435,59 @@ def from_dataframe(dataframe: pd.DataFrame, time_period: str = None):
)()

return dataset

def upload_to_huggingface(self, owner_name: str, model_name: str):
"""Uploads the dataset to Hugging Face.
Args:
owner_name (str): The owner name.
model_name (str): The model name.
"""
token = os.environ.get(
"HUGGING_FACE_TOKEN",
)
login(token=token)
api = HfApi()

# Add the policyengine-uk-data version and policyengine-uk version to the h5 metadata.
uk_data_version = get_package_version("policyengine-uk-data")
uk_version = get_package_version("policyengine-uk")
with h5py.File(self.file_path, "a") as f:
f.attrs["policyengine-uk-data"] = uk_data_version
f.attrs["policyengine-uk"] = uk_version

api.upload_file(
path_or_fileobj=self.file_path,
path_in_repo=self.file_path.name,
repo_id=f"{owner_name}/{model_name}",
repo_type="model",
)

def download_from_huggingface(
self, owner_name: str, model_name: str, version: str = None
):
"""Downloads the dataset from Hugging Face.
Args:
owner_name (str): The owner name.
model_name (str): The model name.
"""
token = os.environ.get(
"HUGGING_FACE_TOKEN",
)
login(token=token)

hf_hub_download(
repo_id=f"{owner_name}/{model_name}",
repo_type="model",
path=self.file_path,
revision=version,
)


def get_package_version(package_name: str) -> str:
"""Get the installed version of a package."""
try:
return pkg_resources.get_distribution(package_name).version
except pkg_resources.DistributionNotFound:
return "not installed"
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
"ipython>=8,<9",
"pyvis>=0.3.2",
"microdf_python>=0.4.3",
"huggingface_hub>=0.25.1",
]

dev_requirements = [
Expand All @@ -48,7 +49,7 @@

setup(
name="policyengine-core",
version="3.12.5",
version="3.13.0",
author="PolicyEngine",
author_email="[email protected]",
classifiers=[
Expand Down

0 comments on commit 42f6ce2

Please sign in to comment.