Conditionally pass Hugging Face tokens for download when repo is private #321

Merged: 6 commits into PolicyEngine:master on Dec 18, 2024

@anth-volk anth-volk commented Dec 7, 2024

Fixes #320.

This code does three things:

  • Unifies all downloads from Hugging Face into a single function, which the Dataset class now calls under the hood, to avoid code duplication (DRY)
  • Modifies said function to first request model_info from Hugging Face. If the model info indicates that the repo is public, the code merely downloads relevant files, whereas if it's private, the function asks the user to input a Hugging Face token before proceeding.
  • Adds tests for this new functionality.
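
The flow in the second bullet can be sketched roughly as follows. This is a minimal illustration, not the code in this PR: `RepoNotFound`, `download_file`, and the injected callables `fetch_info`, `fetch_file`, and `get_token` are hypothetical stand-ins (in the real code they would presumably correspond to huggingface_hub's `RepositoryNotFoundError`, `model_info`, `hf_hub_download`, and a token prompt), injected so the sketch runs without network access.

```python
from types import SimpleNamespace
from typing import Any, Callable, Optional


class RepoNotFound(Exception):
    """Hypothetical stand-in for huggingface_hub's RepositoryNotFoundError."""


def download_file(
    repo: str,
    filename: str,
    version: Optional[str] = None,
    *,
    fetch_info: Callable[[str], Any],
    fetch_file: Callable[[str, str, Optional[str], Optional[str]], str],
    get_token: Callable[[], str],
) -> str:
    """Download a file, asking for a token only when the repo looks private."""
    try:
        is_private = bool(fetch_info(repo).private)
    except RepoNotFound:
        # An anonymous lookup fails for private repos (and for missing ones).
        is_private = True

    token = get_token() if is_private else None
    return fetch_file(repo, filename, version, token)


# Usage with fake callables: a public repo never triggers the token prompt.
def fake_fetch_file(repo, filename, version, token):
    return f"{repo}/{filename}@{version} token={token}"


public_result = download_file(
    "owner/public-repo",
    "data.h5",
    "1.0.0",
    fetch_info=lambda repo: SimpleNamespace(private=False),
    fetch_file=fake_fetch_file,
    get_token=lambda: "secret",
)
print(public_result)
```

Passing the hub calls in as parameters is only a convenience for the sketch; it keeps the public and private branches easy to exercise with fakes.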

A couple things here:

  • Hugging Face seems to have no direct way to check whether a repo is private or public, and the model_info endpoint raises an error if the repo is private. This code capitalizes on that, but it's inherently messy: we're treating a raised error as the signal for "private," even though at times it means "repo doesn't exist."
  • I'd love code quality commentary. Something just doesn't feel concise about this code.
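
One possible way to narrow the ambiguity in the first bullet, sketched under assumptions: retry the lookup with a token, and treat a second failure as "missing or inaccessible". Nothing here is taken from the PR; `RepoNotFound`, `classify_repo`, and `fake_info` are hypothetical names, and the fake lookup stands in for an authenticated model-info call.

```python
from types import SimpleNamespace
from typing import Any, Callable, Optional


class RepoNotFound(Exception):
    """Hypothetical stand-in for the hub's "repository not found" error."""


def classify_repo(
    repo: str,
    fetch_info: Callable[[str, Optional[str]], Any],
    get_token: Callable[[], str],
) -> str:
    """Return "public" or "private"; raise if the repo cannot be found at all."""
    try:
        info = fetch_info(repo, None)
        return "private" if info.private else "public"
    except RepoNotFound:
        pass  # Either private or nonexistent; an anonymous call cannot tell.

    try:
        fetch_info(repo, get_token())
    except RepoNotFound as err:
        # Still not found with a token: the repo is missing or the token lacks access.
        raise LookupError(f"{repo!r} does not exist or the token lacks access") from err
    return "private"


# Fake hub for illustration: one public repo, one private repo that needs "tok".
def fake_info(repo: str, token: Optional[str]) -> SimpleNamespace:
    if repo == "owner/public":
        return SimpleNamespace(private=False)
    if repo == "owner/private" and token == "tok":
        return SimpleNamespace(private=True)
    raise RepoNotFound(repo)


public_kind = classify_repo("owner/public", fake_info, lambda: "tok")
private_kind = classify_repo("owner/private", fake_info, lambda: "tok")
print(public_kind, private_kind)
```

The trade-off is an extra request, plus a token prompt, whenever the anonymous lookup fails.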

Requesting review from both Max and Nikhil: Nikhil has touched this code, and Max asked about the functionality.

@MaxGhenis MaxGhenis left a comment

Thanks for doing this! Here are suggestions from https://pr-improver.streamlit.app that I haven't reviewed:

Here are 6 specific suggestions to improve this PR:

  1. File: policyengine_core/tools/hugging_face.py
    Lines: 17-18
    Change:

    def hf_download(repo: str, repo_filename: str, version: str = None):

    to:

    def download_huggingface_dataset(repository: str, file_name: str, version: str = None):

    Explanation: More descriptive function name and parameter names improve clarity, especially for non-native English speakers.

  2. File: policyengine_core/tools/hugging_face.py
    Lines: 31-34
    Add comment:

    # Attempt to fetch model info to determine if repo is private
    # A RepositoryNotFoundError likely means the repo is private and requires authentication
    try:
        fetched_model_info: ModelInfo = model_info(repository)
        is_repo_private: bool = fetched_model_info.private
    except RepositoryNotFoundError:
        is_repo_private = True

    Explanation: This comment clarifies the purpose of the try-except block and explains the logic behind assuming a private repository.

  3. File: policyengine_core/tools/hugging_face.py
    Lines: 46-50
    Change:

    token: str = None
    if is_repo_private:
        token: str = get_or_prompt_hf_token()

    to:

    authentication_token: str = None
    if is_repo_private:
        authentication_token: str = get_or_prompt_huggingface_token()

    Explanation: More descriptive variable names and function name improve clarity.

  4. File: policyengine_core/simulations/simulation.py
    Lines: 161-163
    Change:

    dataset = hf_download(
        owner + "/" + repo, filename, version
    )

    to:

    dataset = download_huggingface_dataset(
        repository=f"{owner}/{repo}",
        file_name=filename,
        version=version
    )

    Explanation: Use the new function name and provide named arguments for better readability.

  5. File: tests/core/tools/test_hugging_face.py
    Lines: 13-14
    Change class name:

    class TestHfDownload:

    to:

    class TestHuggingFaceDownload:

    Explanation: More descriptive class name improves clarity of test purpose.

  6. File: tests/core/tools/test_hugging_face.py
    Lines: 70-92
    Add a new test case:

    def test_download_nonexistent_repo(self):
        """Test handling of a nonexistent repository"""
        test_repo = "nonexistent_repo"
        test_filename = "test_filename"
    
        with patch("policyengine_core.tools.hugging_face.model_info") as mock_model_info:
            mock_model_info.side_effect = Exception("Repository not found")
    
            with pytest.raises(Exception) as exc_info:
                download_huggingface_dataset(test_repo, test_filename)
    
            assert "Unable to download dataset" in str(exc_info.value)

    Explanation: This new test case improves coverage by checking the handling of nonexistent repositories, which is different from private repositories.

These suggestions focus on improving code clarity, adding helpful comments, and enhancing test coverage, which should make the code more maintainable and easier to understand for all contributors.

@anth-volk

Wow, this is interesting! Thanks for sending these suggestions. Will incorporate and re-request review.

@anth-volk anth-volk requested a review from MaxGhenis December 7, 2024 09:22
@anth-volk

Review re-requested. I haven't added the new test from above, as Hugging Face raises the same error for both a nonexistent repo and a private repo without a passed token, and thus the test would be the exact same.
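
A small mock-based sketch of why the two tests would coincide. It assumes, as stated above, that an unauthenticated lookup raises the same error type in both cases; `RepoNotFound` and `needs_token` are hypothetical stand-ins rather than names from this PR.

```python
from unittest.mock import Mock


class RepoNotFound(Exception):
    """Hypothetical stand-in for huggingface_hub's RepositoryNotFoundError."""


def needs_token(repo: str, fetch_info) -> bool:
    """Apply the heuristic discussed in this PR: a failed anonymous lookup means "ask for a token"."""
    try:
        return bool(fetch_info(repo).private)
    except RepoNotFound:
        return True


# Without a token, both cases surface as the same error type...
private_lookup = Mock(side_effect=RepoNotFound("owner/private-repo"))
missing_lookup = Mock(side_effect=RepoNotFound("owner/no-such-repo"))

# ...so the calling code takes the same branch, and the two tests would be identical.
private_outcome = needs_token("owner/private-repo", private_lookup)
missing_outcome = needs_token("owner/no-such-repo", missing_lookup)
print(private_outcome, missing_outcome)
```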

@MaxGhenis MaxGhenis merged commit 54b1f18 into PolicyEngine:master Dec 18, 2024
3 checks passed
@anth-volk anth-volk deleted the fix/publicize-huggingface branch December 19, 2024 12:29
Development

Successfully merging this pull request may close these issues.

Optionally pass token to hf_hub_download
3 participants