From 63a3ddba902ca2a33c55c0d464a580ac1e5e5ad8 Mon Sep 17 00:00:00 2001 From: Filippo Bigi <98903385+frostedoyster@users.noreply.github.com> Date: Tue, 26 Nov 2024 09:00:04 +0100 Subject: [PATCH] Remove the need to specify the architecture when exporting (#405) --- docs/src/dev-docs/new-architecture.rst | 4 +- docs/src/getting-started/checkpoints.rst | 6 +-- src/metatrain/cli/export.py | 9 ----- .../experimental/alchemical_model/trainer.py | 1 + src/metatrain/experimental/pet/trainer.py | 2 + .../experimental/soap_bpnn/trainer.py | 1 + src/metatrain/utils/io.py | 37 ++++++++----------- tests/cli/test_cli.py | 6 +-- tests/cli/test_export_model.py | 15 +++++--- tests/utils/test_io.py | 36 +++++++++--------- tox.ini | 1 - 11 files changed, 54 insertions(+), 64 deletions(-) diff --git a/docs/src/dev-docs/new-architecture.rst b/docs/src/dev-docs/new-architecture.rst index 68b9df0d9..062846582 100644 --- a/docs/src/dev-docs/new-architecture.rst +++ b/docs/src/dev-docs/new-architecture.rst @@ -160,7 +160,9 @@ methods for ``train()``, ``save_checkpoint()`` and ``load_checkpoint()``. The format of checkpoints is not defined by ``metatrain`` and can be any format that can be loaded by the trainer (to restart training) and by the model (to export the -checkpoint). +checkpoint). The only requirements are that the checkpoint must be loadable with +``torch.load()``, it must be a dictionary, and it must contain the name of the +architecture under the ``architecture_name`` key. Init file (``__init__.py``) --------------------------- diff --git a/docs/src/getting-started/checkpoints.rst b/docs/src/getting-started/checkpoints.rst index e5c2dfaae..050753277 100644 --- a/docs/src/getting-started/checkpoints.rst +++ b/docs/src/getting-started/checkpoints.rst @@ -30,13 +30,13 @@ positional arguments .. code-block:: bash - mtt export experimental.soap_bpnn model.ckpt -o model.pt + mtt export model.ckpt -o model.pt or .. code-block:: bash - mtt export experimental.soap_bpnn model.ckpt --output model.pt + mtt export model.ckpt --output model.pt For a export of distribution of models the ``export`` command also supports parsing models from remote locations. To export a remote model you can provide a URL instead of @@ -44,7 +44,7 @@ a file path. .. code-block:: bash - mtt export experimental.soap_bpnn https://my.url.com/model.ckpt --output model.pt + mtt export https://my.url.com/model.ckpt --output model.pt Downloading private HuggingFace models is also supported, by specifying the corresponding API token with the ``--huggingface_api_token`` flag. diff --git a/src/metatrain/cli/export.py b/src/metatrain/cli/export.py index a27b617e4..c025c6989 100644 --- a/src/metatrain/cli/export.py +++ b/src/metatrain/cli/export.py @@ -5,7 +5,6 @@ from metatensor.torch.atomistic import is_atomistic_model -from ..utils.architectures import find_all_architectures from ..utils.io import check_file_extension, load_model from .formatter import CustomHelpFormatter @@ -30,12 +29,6 @@ def _add_export_model_parser(subparser: argparse._SubParsersAction) -> None: ) parser.set_defaults(callable="export_model") - parser.add_argument( - "architecture_name", - type=str, - choices=find_all_architectures(), - help="name of the model's architecture", - ) parser.add_argument( "path", type=str, @@ -66,10 +59,8 @@ def _add_export_model_parser(subparser: argparse._SubParsersAction) -> None: def _prepare_export_model_args(args: argparse.Namespace) -> None: """Prepare arguments for export_model.""" path = args.__dict__.pop("path") - architecture_name = args.__dict__.pop("architecture_name") args.model = load_model( path=path, - architecture_name=architecture_name, **args.__dict__, ) keys_to_keep = ["model", "output"] # only these are needed for `export_model`` diff --git a/src/metatrain/experimental/alchemical_model/trainer.py b/src/metatrain/experimental/alchemical_model/trainer.py index 611c1c2e4..db1caa825 100644 --- a/src/metatrain/experimental/alchemical_model/trainer.py +++ b/src/metatrain/experimental/alchemical_model/trainer.py @@ -371,6 +371,7 @@ def train( def save_checkpoint(self, model, path: Union[str, Path]): checkpoint = { + "architecture_name": "experimental.alchemical_model", "model_hypers": { "model_hypers": model.hypers, "dataset_info": model.dataset_info, diff --git a/src/metatrain/experimental/pet/trainer.py b/src/metatrain/experimental/pet/trainer.py index a74f5050f..69fe9576e 100644 --- a/src/metatrain/experimental/pet/trainer.py +++ b/src/metatrain/experimental/pet/trainer.py @@ -754,6 +754,7 @@ def save_checkpoint(self, model, path: Union[str, Path]): else: lora_state_dict = None last_model_checkpoint = { + "architecture_name": "experimental.pet", "trainer_state_dict": trainer_state_dict, "model_state_dict": last_model_state_dict, "best_model_state_dict": self.best_model_state_dict, @@ -765,6 +766,7 @@ def save_checkpoint(self, model, path: Union[str, Path]): "lora_state_dict": lora_state_dict, } best_model_checkpoint = { + "architecture_name": "experimental.pet", "trainer_state_dict": None, "model_state_dict": self.best_model_state_dict, "best_model_state_dict": None, diff --git a/src/metatrain/experimental/soap_bpnn/trainer.py b/src/metatrain/experimental/soap_bpnn/trainer.py index 563fc00d6..151e46316 100644 --- a/src/metatrain/experimental/soap_bpnn/trainer.py +++ b/src/metatrain/experimental/soap_bpnn/trainer.py @@ -411,6 +411,7 @@ def train( def save_checkpoint(self, model, path: Union[str, Path]): checkpoint = { + "architecture_name": "experimental.soap_bpnn", "model_hypers": { "model_hypers": model.hypers, "dataset_info": model.dataset_info, diff --git a/src/metatrain/utils/io.py b/src/metatrain/utils/io.py index fea744079..d68247fac 100644 --- a/src/metatrain/utils/io.py +++ b/src/metatrain/utils/io.py @@ -7,8 +7,10 @@ from urllib.parse import urlparse from urllib.request import urlretrieve +import torch from metatensor.torch.atomistic import check_atomistic_model, load_atomistic_model +from ..utils.architectures import find_all_architectures from .architectures import import_architecture @@ -70,7 +72,6 @@ def is_exported_file(path: str) -> bool: def load_model( path: Union[str, Path], extensions_directory: Optional[Union[str, Path]] = None, - architecture_name: Optional[str] = None, **kwargs, ) -> Any: """Load checkpoints and exported models from an URL or a local file. @@ -78,32 +79,19 @@ def load_model( If an exported model should be loaded and requires compiled extensions, their location should be passed using the ``extensions_directory`` parameter. - Loading checkpoints requires the ``architecture_name`` parameter, which can be - ommited for loading an exported model. After reading a checkpoint, the returned + After reading a checkpoint, the returned model can be exported with the model's own ``export()`` method. :param path: local or remote path to a model. For supported URL schemes see :py:class`urllib.request` :param extensions_directory: path to a directory containing all extensions required by an *exported* model - :param architecture_name: name of the architecture required for loading from a - *checkpoint*. - :raises ValueError: if both an ``extensions_directory`` and ``architecture_name`` - are given :raises ValueError: if ``path`` is a YAML option file and no model - :raises ValueError: if no ``archietcture_name`` is given for loading a checkpoint - :raises ValueError: if the checkpoint saved in ``path`` does not math the given - ``architecture_name`` + :raises ValueError: if no ``archietcture_name`` is found in the checkpoint + :raises ValueError: if the ``architecture_name`` is not found in the available + architectures """ - if extensions_directory is not None and architecture_name is not None: - raise ValueError( - f"Both ``extensions_directory`` ('{str(extensions_directory)}') and " - f"``architecture_name`` ('{architecture_name}') are given which are " - "mutually exclusive. An ``extensions_directory`` is only required for " - "*exported* models while an ``architecture_name`` is only needed for model " - "*checkpoints*." - ) if Path(path).suffix in [".yaml", ".yml"]: raise ValueError( @@ -167,10 +155,15 @@ def load_model( if is_exported_file(path): return load_atomistic_model(path, extensions_directory=extensions_directory) else: # model is a checkpoint - if architecture_name is None: + checkpoint = torch.load(path, weights_only=False, map_location="cpu") + if "architecture_name" not in checkpoint: + raise ValueError("No architecture name found in the checkpoint") + architecture_name = checkpoint["architecture_name"] + if architecture_name not in find_all_architectures(): raise ValueError( - f"path '{path}' seems to be a checkpointed model but no " - "`architecture_name` was given" + f"Checkpoint architecture '{architecture_name}' not found " + "in the available architectures. Available architectures are: " + f"{find_all_architectures()}" ) architecture = import_architecture(architecture_name) @@ -178,6 +171,6 @@ def load_model( return architecture.__model__.load_checkpoint(path) except Exception as err: raise ValueError( - f"path '{path}' is not a valid model file for the {architecture_name} " + f"path '{path}' is not a valid checkpoint for the {architecture_name} " "architecture" ) from err diff --git a/tests/cli/test_cli.py b/tests/cli/test_cli.py index 112497b7f..990e07421 100644 --- a/tests/cli/test_cli.py +++ b/tests/cli/test_cli.py @@ -149,13 +149,11 @@ def test_run_information(capfd, monkeypatch, tmp_path): monkeypatch.chdir(tmp_path) with pytest.raises(CalledProcessError): - subprocess.check_call(["mtt", "export", "experimental.soap_bpnn", "model.ckpt"]) + subprocess.check_call(["mtt", "export", "model.ckpt"]) stdout_log = capfd.readouterr().out assert f"Package directory: {PACKAGE_ROOT}" in stdout_log assert f"Working directory: {Path('.').absolute()}" in stdout_log assert f"Metatrain version: {__version__}" in stdout_log - assert ( - "Executed command: mtt export experimental.soap_bpnn model.ckpt" in stdout_log - ) + assert "Executed command: mtt export model.ckpt" in stdout_log diff --git a/tests/cli/test_export_model.py b/tests/cli/test_export_model.py index 8c005b5f2..9c57ed2f8 100644 --- a/tests/cli/test_export_model.py +++ b/tests/cli/test_export_model.py @@ -57,7 +57,6 @@ def test_export_cli(monkeypatch, tmp_path, output, dtype): command = [ "mtt", "export", - "experimental.soap_bpnn", str(RESOURCES_PATH / f"model-{dtype_string}-bit.ckpt"), ] @@ -81,12 +80,17 @@ def test_export_cli(monkeypatch, tmp_path, output, dtype): assert next(model.parameters()).device.type == "cpu" -def test_export_cli_architecture_names_choices(): - stderr = str(subprocess.run(["mtt", "export", "foo"], capture_output=True).stderr) +def test_export_cli_unknown_architecture(monkeypatch, tmp_path): + monkeypatch.chdir(tmp_path) + torch.save({"architecture_name": "foo"}, "fake.ckpt") + + stdout = str( + subprocess.run(["mtt", "export", "fake.ckpt"], capture_output=True).stdout + ) - assert "invalid choice: 'foo'" in stderr + assert "architecture 'foo' not found in the available architectures" in stdout for architecture_name in find_all_architectures(): - assert architecture_name in stderr + assert architecture_name in stdout def test_reexport(monkeypatch, tmp_path): @@ -129,7 +133,6 @@ def test_private_huggingface(monkeypatch, tmp_path): command = [ "mtt", "export", - "experimental.soap_bpnn", "https://huggingface.co/metatensor/metatrain-test/resolve/main/model.ckpt", f"--huggingface_api_token={HF_TOKEN}", ] diff --git a/tests/utils/test_io.py b/tests/utils/test_io.py index a864bf7e3..8294dad1c 100644 --- a/tests/utils/test_io.py +++ b/tests/utils/test_io.py @@ -1,6 +1,7 @@ from pathlib import Path import pytest +import torch from metatensor.torch.atomistic import MetatensorAtomisticModel from metatrain.experimental.soap_bpnn.model import SoapBpnn @@ -45,7 +46,7 @@ def test_is_exported_file(): ], ) def test_load_model_checkpoint(path): - model = load_model(path, architecture_name="experimental.soap_bpnn") + model = load_model(path) assert type(model) is SoapBpnn if str(path).startswith("file:"): # test that the checkpoint is also copied to the current directory @@ -61,7 +62,7 @@ def test_load_model_checkpoint(path): ], ) def test_load_model_exported(path): - model = load_model(path, architecture_name="experimental.soap_bpnn") + model = load_model(path) assert type(model) is MetatensorAtomisticModel @@ -69,30 +70,29 @@ def test_load_model_exported(path): def test_load_model_yaml(suffix): match = f"path 'foo{suffix}' seems to be a YAML option file and not a model" with pytest.raises(ValueError, match=match): - load_model( - f"foo{suffix}", - architecture_name="experimental.soap_bpnn", - ) + load_model(f"foo{suffix}") -def test_load_model_unknown_model(): - architecture_name = "experimental.pet" - path = RESOURCES_PATH / "model-32-bit.ckpt" +def test_load_model_unknown_model(monkeypatch, tmpdir): + monkeypatch.chdir(tmpdir) + architecture_name = "experimental.soap_bpnn" + path = "fake.ckpt" + torch.save({"architecture_name": architecture_name}, path) match = ( - f"path '{path}' is not a valid model file for the {architecture_name} " + f"path '{path}' is not a valid checkpoint for the {architecture_name} " "architecture" ) with pytest.raises(ValueError, match=match): load_model(path, architecture_name=architecture_name) -def test_extensions_directory_and_architecture_name(): - match = ( - r"Both ``extensions_directory`` \('.'\) and ``architecture_name`` \('foo'\) " - r"are given which are mutually exclusive. An ``extensions_directory`` is only " - r"required for \*exported\* models while an ``architecture_name`` is only " - r"needed for model \*checkpoints\*." - ) +def test_load_model_no_architecture_name(monkeypatch, tmpdir): + monkeypatch.chdir(tmpdir) + architecture_name = "experimental.soap_bpnn" + path = "fake.ckpt" + torch.save({"not_architecture_name": architecture_name}, path) + + match = "No architecture name found in the checkpoint" with pytest.raises(ValueError, match=match): - load_model("model.pt", extensions_directory=".", architecture_name="foo") + load_model(path, architecture_name=architecture_name) diff --git a/tox.ini b/tox.ini index 9694d7fba..b8aa89680 100644 --- a/tox.ini +++ b/tox.ini @@ -62,7 +62,6 @@ deps = changedir = tests extras = # architectures used in the package tests soap-bpnn - pet allowlist_externals = bash commands_pre = bash {toxinidir}/tests/resources/generate-outputs.sh commands =