Skip to content

Commit

Permalink
fix: Change dataset download func name
Browse files Browse the repository at this point in the history
  • Loading branch information
anth-volk committed Dec 7, 2024
1 parent a7b1e51 commit d9a418b
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 7 deletions.
2 changes: 1 addition & 1 deletion policyengine_core/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,7 @@ def download_from_huggingface(
file=sys.stderr,
)

hf_download(
download_huggingface_dataset(
repo=f"{owner_name}/{model_name}",
repo_filename=file_name,
version=version,
Expand Down
2 changes: 1 addition & 1 deletion policyengine_core/simulations/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def __init__(
filename = filename.split("@")[0]
else:
version = None
dataset = hf_download(
dataset = download_huggingface_dataset(
owner + "/" + repo, filename, version
)
datasets_by_name = {
Expand Down
2 changes: 1 addition & 1 deletion policyengine_core/tools/hugging_face.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
warnings.simplefilter("ignore")


def hf_download(repo: str, repo_filename: str, version: str = None):
def download_huggingface_dataset(repo: str, repo_filename: str, version: str = None):
"""
Download a dataset from the Hugging Face Hub.
Expand Down
8 changes: 4 additions & 4 deletions tests/core/tools/test_hugging_face.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from huggingface_hub.errors import RepositoryNotFoundError
from policyengine_core.tools.hugging_face import (
get_or_prompt_hf_token,
hf_download,
download_huggingface_dataset,
)


Expand All @@ -28,7 +28,7 @@ def test_download_public_repo(self):
id=test_id, private=False
)

hf_download(test_repo, test_filename, test_version)
download_huggingface_dataset(test_repo, test_filename, test_version)

mock_download.assert_called_with(
repo_id=test_repo,
Expand Down Expand Up @@ -58,7 +58,7 @@ def test_download_private_repo(self):
) as mock_token:
mock_token.return_value = "test_token"

hf_download(test_repo, test_filename, test_version)
download_huggingface_dataset(test_repo, test_filename, test_version)
mock_download.assert_called_with(
repo_id=test_repo,
repo_type="model",
Expand Down Expand Up @@ -88,7 +88,7 @@ def test_download_private_repo_no_token(self):
mock_token.return_value = ""

with pytest.raises(Exception):
hf_download(test_repo, test_filename, test_version)
download_huggingface_dataset(test_repo, test_filename, test_version)
mock_download.assert_not_called()


Expand Down

0 comments on commit d9a418b

Please sign in to comment.