diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 2c4db005c..173321cd2 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -46,6 +46,7 @@ jobs: env: # Use the CPU only version of torch when building/running the code PIP_EXTRA_INDEX_URL: https://download.pytorch.org/whl/cpu + HUGGINGFACE_TOKEN_METATRAIN: ${{ secrets.HUGGINGFACE_TOKEN }} - name: upload to codecov.io uses: codecov/codecov-action@v4 diff --git a/docs/src/getting-started/checkpoints.rst b/docs/src/getting-started/checkpoints.rst index 05cd5f3e8..e5c2dfaae 100644 --- a/docs/src/getting-started/checkpoints.rst +++ b/docs/src/getting-started/checkpoints.rst @@ -46,6 +46,9 @@ a file path. mtt export experimental.soap_bpnn https://my.url.com/model.ckpt --output model.pt +Downloading private HuggingFace models is also supported, by specifying the +corresponding API token with the ``--huggingface_api_token`` flag. + Keep in mind that a checkpoint (``.ckpt``) is only a temporary file, which can have several dependencies and may become unusable if the corresponding architecture is updated. In constrast, exported models (``.pt``) act as standalone files. 
For long-term diff --git a/src/metatrain/cli/export.py b/src/metatrain/cli/export.py index a375a5e13..a27b617e4 100644 --- a/src/metatrain/cli/export.py +++ b/src/metatrain/cli/export.py @@ -53,14 +53,30 @@ def _add_export_model_parser(subparser: argparse._SubParsersAction) -> None: default="exported-model.pt", help="Filename of the exported model (default: %(default)s).", ) + parser.add_argument( + "--huggingface_api_token", + dest="huggingface_api_token", + type=str, + required=False, + default="", + help="API token to download a private model from HuggingFace.", + ) def _prepare_export_model_args(args: argparse.Namespace) -> None: """Prepare arguments for export_model.""" + path = args.__dict__.pop("path") + architecture_name = args.__dict__.pop("architecture_name") args.model = load_model( - path=args.__dict__.pop("path"), - architecture_name=args.__dict__.pop("architecture_name"), + path=path, + architecture_name=architecture_name, + **args.__dict__, ) + keys_to_keep = ["model", "output"] # only these are needed for `export_model`` + original_keys = list(args.__dict__.keys()) + for key in original_keys: + if key not in keys_to_keep: + args.__dict__.pop(key) def export_model(model: Any, output: Union[Path, str] = "exported-model.pt") -> None: diff --git a/src/metatrain/utils/io.py b/src/metatrain/utils/io.py index bb3c046fe..440bfdc1c 100644 --- a/src/metatrain/utils/io.py +++ b/src/metatrain/utils/io.py @@ -1,3 +1,5 @@ +import logging +import shutil import warnings from pathlib import Path from typing import Any, Optional, Union @@ -9,6 +11,9 @@ from .architectures import import_architecture +logger = logging.getLogger(__name__) + + def check_file_extension( filename: Union[str, Path], extension: str ) -> Union[str, Path]: @@ -65,6 +70,7 @@ def load_model( path: Union[str, Path], extensions_directory: Optional[Union[str, Path]] = None, architecture_name: Optional[str] = None, + **kwargs, ) -> Any: """Load checkpoints and exported models from an URL or a local 
file. @@ -99,15 +105,64 @@ def load_model( ) if Path(path).suffix in [".yaml", ".yml"]: - raise ValueError(f"path '{path}' seems to be a YAML option file and no model") + raise ValueError( + f"path '{path}' seems to be a YAML option file and not a model" + ) - if urlparse(str(path)).scheme: + # Download from HuggingFace with a private token + if kwargs.get("huggingface_api_token"): + try: + from huggingface_hub import hf_hub_download + except ImportError: + raise ImportError( + "To download a model from HuggingFace, please install the " + "`huggingface_hub` package with pip (`pip install " + "huggingface_hub`)." + ) + path = str(path) + if not path.startswith("https://huggingface.co/"): + raise ValueError( + f"Invalid URL '{path}'. HuggingFace models should start with " + "'https://huggingface.co/'." + ) + # get repo_id and filename + split_path = path.split("/") + repo_id = f"{split_path[3]}/{split_path[4]}" # org/repo + filename = "" + for i in range(5, len(split_path)): + filename += split_path[i] + "/" + filename = filename[:-1] + if filename.startswith("resolve"): + if not filename[8:].startswith("main/"): + raise ValueError( + f"Invalid URL '{path}'. metatrain only supports models from the " + "'main' branch." + ) + filename = filename[13:] + if filename.startswith("blob/"): + if not filename[5:].startswith("main/"): + raise ValueError( + f"Invalid URL '{path}'. metatrain only supports models from the " + "'main' branch." 
+ ) + filename = filename[10:] + path = hf_hub_download(repo_id, filename, token=kwargs["huggingface_api_token"]) + # make sure to copy the checkpoint to the current directory + shutil.copy(path, Path.cwd() / filename) + logger.info(f"Downloaded model from HuggingFace to {filename}") + + elif urlparse(str(path)).scheme: path, _ = urlretrieve(str(path)) + # make sure to copy the checkpoint to the current directory + shutil.copy(path, Path.cwd() / str(path).split("/")[-1]) + logger.info(f"Downloaded model to {str(path).split('/')[-1]}") - if is_exported_file(str(path)): - return load_atomistic_model( - str(path), extensions_directory=extensions_directory - ) + else: + pass + + path = str(path) + if is_exported_file(path): + return load_atomistic_model(path, extensions_directory=extensions_directory) else: # model is a checkpoint if architecture_name is None: raise ValueError( @@ -117,7 +172,7 @@ load_model( architecture = import_architecture(architecture_name) try: - return architecture.__model__.load_checkpoint(str(path)) + return architecture.__model__.load_checkpoint(path) except Exception as err: raise ValueError( f"path '{path}' is not a valid model file for the {architecture_name} " diff --git a/tests/cli/test_export_model.py b/tests/cli/test_export_model.py index 10feab713..8c005b5f2 100644 --- a/tests/cli/test_export_model.py +++ b/tests/cli/test_export_model.py @@ -5,9 +5,11 @@ import glob import logging +import os import subprocess from pathlib import Path +import huggingface_hub import pytest import torch @@ -104,3 +106,42 @@ def test_reexport(monkeypatch, tmp_path): export_model(model_loaded, "exported_new.pt") assert Path("exported_new.pt").is_file() + + +def test_private_huggingface(monkeypatch, tmp_path): + """Test that the export cli succeeds when exporting a private + model from HuggingFace.""" + monkeypatch.chdir(tmp_path) + + HF_TOKEN = os.getenv("HUGGINGFACE_TOKEN_METATRAIN") + if HF_TOKEN is None: + pytest.skip("HuggingFace token not found in 
environment.") + assert len(HF_TOKEN) > 0 + + huggingface_hub.upload_file( + path_or_fileobj=str(RESOURCES_PATH / "model-32-bit.ckpt"), + path_in_repo="model.ckpt", + repo_id="metatensor/metatrain-test", + commit_message="Overwrite test model with new version", + token=HF_TOKEN, + ) + + command = [ + "mtt", + "export", + "experimental.soap_bpnn", + "https://huggingface.co/metatensor/metatrain-test/resolve/main/model.ckpt", + f"--huggingface_api_token={HF_TOKEN}", + ] + + output = "exported-model.pt" + + subprocess.check_call(command) + assert Path(output).is_file() + + # Test if extensions are saved + extensions_glob = glob.glob("extensions/") + assert len(extensions_glob) == 1 + + # Test that the model can be loaded + load_model(output, extensions_directory="extensions/") diff --git a/tests/utils/test_io.py b/tests/utils/test_io.py index e50a425a9..a864bf7e3 100644 --- a/tests/utils/test_io.py +++ b/tests/utils/test_io.py @@ -47,6 +47,9 @@ def test_is_exported_file(): def test_load_model_checkpoint(path): model = load_model(path, architecture_name="experimental.soap_bpnn") assert type(model) is SoapBpnn + if str(path).startswith("file:"): + # test that the checkpoint is also copied to the current directory + assert Path("model-32-bit.ckpt").exists() @pytest.mark.parametrize( @@ -64,7 +67,7 @@ def test_load_model_exported(path): @pytest.mark.parametrize("suffix", [".yml", ".yaml"]) def test_load_model_yaml(suffix): - match = f"path 'foo{suffix}' seems to be a YAML option file and no model" + match = f"path 'foo{suffix}' seems to be a YAML option file and not a model" with pytest.raises(ValueError, match=match): load_model( f"foo{suffix}", diff --git a/tox.ini b/tox.ini index 8302cbc6b..9694d7fba 100644 --- a/tox.ini +++ b/tox.ini @@ -58,6 +58,7 @@ deps = pytest pytest-cov pytest-xdist + huggingface_hub changedir = tests extras = # architectures used in the package tests soap-bpnn