
Download from HuggingFace with private token #390

Merged
merged 16 commits on Nov 25, 2024
1 change: 1 addition & 0 deletions .github/workflows/tests.yml
@@ -46,6 +46,7 @@ jobs:
env:
# Use the CPU only version of torch when building/running the code
PIP_EXTRA_INDEX_URL: https://download.pytorch.org/whl/cpu
HUGGINGFACE_TOKEN_METATRAIN: ${{ secrets.HUGGINGFACE_TOKEN }}

- name: upload to codecov.io
uses: codecov/codecov-action@v4
3 changes: 3 additions & 0 deletions docs/src/getting-started/checkpoints.rst
@@ -46,6 +46,9 @@ a file path.

mtt export experimental.soap_bpnn https://my.url.com/model.ckpt --output model.pt

Downloading private HuggingFace models is also supported by specifying the
corresponding API token with the ``--huggingface_api_token`` flag.
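
For example (the repository URL and token below are placeholders for a private model, not real values):

mtt export experimental.soap_bpnn https://huggingface.co/my-org/my-model/resolve/main/model.ckpt --huggingface_api_token=hf_xxxxxxxx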

Keep in mind that a checkpoint (``.ckpt``) is only a temporary file, which can have
several dependencies and may become unusable if the corresponding architecture is
updated. In contrast, exported models (``.pt``) act as standalone files. For long-term
20 changes: 18 additions & 2 deletions src/metatrain/cli/export.py
@@ -53,14 +53,30 @@ def _add_export_model_parser(subparser: argparse._SubParsersAction) -> None:
default="exported-model.pt",
help="Filename of the exported model (default: %(default)s).",
)
parser.add_argument(
"--huggingface_api_token",
dest="huggingface_api_token",
type=str,
required=False,
default="",
help="API token to download a private model from HuggingFace.",
)


def _prepare_export_model_args(args: argparse.Namespace) -> None:
"""Prepare arguments for export_model."""
path = args.__dict__.pop("path")
architecture_name = args.__dict__.pop("architecture_name")
args.model = load_model(
path=args.__dict__.pop("path"),
architecture_name=args.__dict__.pop("architecture_name"),
path=path,
architecture_name=architecture_name,
**args.__dict__,
)
keys_to_keep = ["model", "output"] # only these are needed for `export_model``
original_keys = list(args.__dict__.keys())
for key in original_keys:
if key not in keys_to_keep:
args.__dict__.pop(key)


def export_model(model: Any, output: Union[Path, str] = "exported-model.pt") -> None:
69 changes: 62 additions & 7 deletions src/metatrain/utils/io.py
@@ -1,3 +1,5 @@
import logging
import shutil
import warnings
from pathlib import Path
from typing import Any, Optional, Union
@@ -9,6 +11,9 @@
from .architectures import import_architecture


logger = logging.getLogger(__name__)


def check_file_extension(
filename: Union[str, Path], extension: str
) -> Union[str, Path]:
@@ -65,6 +70,7 @@
path: Union[str, Path],
extensions_directory: Optional[Union[str, Path]] = None,
architecture_name: Optional[str] = None,
**kwargs,
) -> Any:
"""Load checkpoints and exported models from an URL or a local file.

@@ -99,15 +105,64 @@
)

if Path(path).suffix in [".yaml", ".yml"]:
raise ValueError(f"path '{path}' seems to be a YAML option file and no model")
raise ValueError(
f"path '{path}' seems to be a YAML option file and not a model"
)

if urlparse(str(path)).scheme:
# Download from HuggingFace with a private token
if kwargs.get("huggingface_api_token"):
try:
from huggingface_hub import hf_hub_download
except ImportError:
raise ImportError(

"To download a model from HuggingFace, please install the "
"`huggingface_hub` package with pip (`pip install "
"huggingface_hub`)."
)
path = str(path)
if not path.startswith("https://huggingface.co/"):
raise ValueError(

f"Invalid URL '{path}'. HuggingFace models should start with "
"'https://huggingface.co/'."
)
# get repo_id and filename
split_path = path.split("/")
repo_id = f"{split_path[3]}/{split_path[4]}" # org/repo
filename = ""
Comment on lines +128 to +131
Member: Instead of doing string manipulation, should we pass more structured information to this code? Copied from my comment on the overall PR: maybe a better interface here would look like mtt export --mode huggingface --hf-api-token=.... --hf-owner=who/what --hf-file=file-path (very bad syntax, just for the idea; this could also be huggingface://who/what/file-path or huggingface://token@who/what/file-path or whatever). I.e. instead of trying to guess the info we need, ask the user about it.
Collaborator (author): I don't think so, because HuggingFace gives you the download link, so that's what most users will use.
Collaborator (author): Asking for the "split information" on the command line is just more work for the user.

for i in range(5, len(split_path)):
filename += split_path[i] + "/"
filename = filename[:-1]
if filename.startswith("resolve"):
if not filename[8:].startswith("main/"):
raise ValueError(

f"Invalid URL '{path}'. metatrain only supports models from the "
"'main' branch."
)
filename = filename[13:]
if filename.startswith("blob/"):
if not filename[5:].startswith("main/"):
raise ValueError(

f"Invalid URL '{path}'. metatrain only supports models from the "
"'main' branch."
)
filename = filename[10:]

path = hf_hub_download(repo_id, filename, token=kwargs["huggingface_api_token"])
# make sure to copy the checkpoint to the current directory
shutil.copy(path, Path.cwd() / filename)
logger.info(f"Downloaded model from HuggingFace to {filename}")
Comment on lines +151 to +152
Member: If I understand the code properly, this will download repo/owner/.../main/some/path/to/file.ckpt to <cwd>/some/path/to/file.ckpt. Is this intended? Otherwise we should be using os.path.basename(filename) as the destination.
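
A minimal sketch of the alternative the reviewer suggests (a hypothetical helper, not code from this PR), which keeps only the file's base name when copying into the working directory:

import os
import shutil
from pathlib import Path


def copy_checkpoint_to_cwd(downloaded_path: str, filename: str) -> Path:
    # Keep only the base name, e.g. "some/path/to/file.ckpt" -> "file.ckpt",
    # and copy the downloaded file into the current working directory.
    destination = Path.cwd() / os.path.basename(filename)
    shutil.copy(downloaded_path, destination)
    return destination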


elif urlparse(str(path)).scheme:
path, _ = urlretrieve(str(path))
# make sure to copy the checkpoint to the current directory
shutil.copy(path, Path.cwd() / str(path).split("/")[-1])
Comment on lines +156 to +157
Member: I don't understand why this is required?
Contributor: And if you want the file to be directly downloaded to the current directory you can do

path, _ = urlretrieve(
    url=str(path),
    filename=str(Path.cwd() / str(path).split("/")[-1]))

See https://docs.python.org/3/library/urllib.request.html#urllib.request.urlretrieve.
Collaborator (author): This is required for fine-tuning. mtt export only makes the .pt available in the current directory, but that's useless for fine-tuning.
Member: Can you change this to use filename in urlretrieve, and add the same INFO message about adding a file to cwd?

logger.info(f"Downloaded model to {str(path).split('/')[-1]}")

if is_exported_file(str(path)):
return load_atomistic_model(
str(path), extensions_directory=extensions_directory
)
else:
pass

path = str(path)
if is_exported_file(path):
return load_atomistic_model(path, extensions_directory=extensions_directory)
else: # model is a checkpoint
if architecture_name is None:
raise ValueError(
@@ -117,7 +172,7 @@
architecture = import_architecture(architecture_name)

try:
return architecture.__model__.load_checkpoint(str(path))
return architecture.__model__.load_checkpoint(path)
except Exception as err:
raise ValueError(
f"path '{path}' is not a valid model file for the {architecture_name} "
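
To make the URL handling above concrete, here is a minimal sketch (illustration only, not the code added by this PR) of how a HuggingFace download URL decomposes into the repo_id and filename arguments expected by hf_hub_download:

from urllib.parse import urlparse

url = "https://huggingface.co/metatensor/metatrain-test/resolve/main/model.ckpt"
parts = urlparse(url).path.strip("/").split("/")

repo_id = "/".join(parts[:2])              # "metatensor/metatrain-test"
assert parts[2:4] == ["resolve", "main"]   # only the 'main' branch is supported
filename = "/".join(parts[4:])             # "model.ckpt"

# hf_hub_download(repo_id, filename, token=...) would then fetch this file into
# the local HuggingFace cache and return the path to the cached copy.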
41 changes: 41 additions & 0 deletions tests/cli/test_export_model.py
@@ -5,9 +5,11 @@

import glob
import logging
import os
import subprocess
from pathlib import Path

import huggingface_hub
import pytest
import torch

@@ -104,3 +106,42 @@ def test_reexport(monkeypatch, tmp_path):
export_model(model_loaded, "exported_new.pt")

assert Path("exported_new.pt").is_file()


def test_private_huggingface(monkeypatch, tmp_path):
Member: This will fail for people running the tests locally, unless they have access to the repo & a token set, right?
Could we skip the test if HUGGINGFACE_TOKEN_METATRAIN is unset? And set it on CI/locally as needed (using a different name in case people are already using HUGGINGFACE_TOKEN for some other reason).

"""Test that the export cli succeeds when exporting a private
model from HuggingFace."""
monkeypatch.chdir(tmp_path)

HF_TOKEN = os.getenv("HUGGINGFACE_TOKEN_METATRAIN")
if HF_TOKEN is None:
pytest.skip("HuggingFace token not found in environment.")
assert len(HF_TOKEN) > 0

huggingface_hub.upload_file(
path_or_fileobj=str(RESOURCES_PATH / "model-32-bit.ckpt"),
path_in_repo="model.ckpt",
repo_id="metatensor/metatrain-test",
commit_message="Overwrite test model with new version",
token=HF_TOKEN,
)

command = [
"mtt",
"export",
"experimental.soap_bpnn",
"https://huggingface.co/metatensor/metatrain-test/resolve/main/model.ckpt",
f"--huggingface_api_token={HF_TOKEN}",
]

output = "exported-model.pt"

subprocess.check_call(command)
assert Path(output).is_file()

# Test if extensions are saved
extensions_glob = glob.glob("extensions/")
assert len(extensions_glob) == 1

# Test that the model can be loaded
load_model(output, extensions_directory="extensions/")
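
As the review thread above notes, the test skips itself when the token is unset. Running it locally would look roughly like this (the token value is a placeholder, and the exact invocation may differ depending on how the test suite is driven):

export HUGGINGFACE_TOKEN_METATRAIN=hf_xxxxxxxx
pytest tests/cli/test_export_model.py -k private_huggingface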
5 changes: 4 additions & 1 deletion tests/utils/test_io.py
@@ -47,6 +47,9 @@ def test_is_exported_file():
def test_load_model_checkpoint(path):
model = load_model(path, architecture_name="experimental.soap_bpnn")
assert type(model) is SoapBpnn
if str(path).startswith("file:"):
# test that the checkpoint is also copied to the current directory
assert Path("model-32-bit.ckpt").exists()


@pytest.mark.parametrize(
@@ -64,7 +67,7 @@ def test_load_model_exported(path):

@pytest.mark.parametrize("suffix", [".yml", ".yaml"])
def test_load_model_yaml(suffix):
match = f"path 'foo{suffix}' seems to be a YAML option file and no model"
match = f"path 'foo{suffix}' seems to be a YAML option file and not a model"
with pytest.raises(ValueError, match=match):
load_model(
f"foo{suffix}",
1 change: 1 addition & 0 deletions tox.ini
@@ -58,6 +58,7 @@ deps =
pytest
pytest-cov
pytest-xdist
huggingface_hub
changedir = tests
extras = # architectures used in the package tests
soap-bpnn