diff --git a/docs/src/dev-docs/new-architecture.rst b/docs/src/dev-docs/new-architecture.rst index 68b9df0d9..062846582 100644 --- a/docs/src/dev-docs/new-architecture.rst +++ b/docs/src/dev-docs/new-architecture.rst @@ -160,7 +160,9 @@ methods for ``train()``, ``save_checkpoint()`` and ``load_checkpoint()``. The format of checkpoints is not defined by ``metatrain`` and can be any format that can be loaded by the trainer (to restart training) and by the model (to export the -checkpoint). +checkpoint). The only requirements are that the checkpoint must be loadable with +``torch.load()``, it must be a dictionary, and it must contain the name of the +architecture under the ``architecture_name`` key. Init file (``__init__.py``) --------------------------- diff --git a/docs/src/getting-started/checkpoints.rst b/docs/src/getting-started/checkpoints.rst index e5c2dfaae..050753277 100644 --- a/docs/src/getting-started/checkpoints.rst +++ b/docs/src/getting-started/checkpoints.rst @@ -30,13 +30,13 @@ positional arguments .. code-block:: bash - mtt export experimental.soap_bpnn model.ckpt -o model.pt + mtt export model.ckpt -o model.pt or .. code-block:: bash - mtt export experimental.soap_bpnn model.ckpt --output model.pt + mtt export model.ckpt --output model.pt For a export of distribution of models the ``export`` command also supports parsing models from remote locations. To export a remote model you can provide a URL instead of @@ -44,7 +44,7 @@ a file path. .. code-block:: bash - mtt export experimental.soap_bpnn https://my.url.com/model.ckpt --output model.pt + mtt export https://my.url.com/model.ckpt --output model.pt Downloading private HuggingFace models is also supported, by specifying the corresponding API token with the ``--huggingface_api_token`` flag. 
diff --git a/src/metatrain/cli/export.py b/src/metatrain/cli/export.py index a27b617e4..c025c6989 100644 --- a/src/metatrain/cli/export.py +++ b/src/metatrain/cli/export.py @@ -5,7 +5,6 @@ from metatensor.torch.atomistic import is_atomistic_model -from ..utils.architectures import find_all_architectures from ..utils.io import check_file_extension, load_model from .formatter import CustomHelpFormatter @@ -30,12 +29,6 @@ def _add_export_model_parser(subparser: argparse._SubParsersAction) -> None: ) parser.set_defaults(callable="export_model") - parser.add_argument( - "architecture_name", - type=str, - choices=find_all_architectures(), - help="name of the model's architecture", - ) parser.add_argument( "path", type=str, @@ -66,10 +59,8 @@ def _add_export_model_parser(subparser: argparse._SubParsersAction) -> None: def _prepare_export_model_args(args: argparse.Namespace) -> None: """Prepare arguments for export_model.""" path = args.__dict__.pop("path") - architecture_name = args.__dict__.pop("architecture_name") args.model = load_model( path=path, - architecture_name=architecture_name, **args.__dict__, ) keys_to_keep = ["model", "output"] # only these are needed for `export_model`` diff --git a/src/metatrain/experimental/alchemical_model/trainer.py b/src/metatrain/experimental/alchemical_model/trainer.py index 611c1c2e4..db1caa825 100644 --- a/src/metatrain/experimental/alchemical_model/trainer.py +++ b/src/metatrain/experimental/alchemical_model/trainer.py @@ -371,6 +371,7 @@ def train( def save_checkpoint(self, model, path: Union[str, Path]): checkpoint = { + "architecture_name": "experimental.alchemical_model", "model_hypers": { "model_hypers": model.hypers, "dataset_info": model.dataset_info, diff --git a/src/metatrain/experimental/pet/trainer.py b/src/metatrain/experimental/pet/trainer.py index a74f5050f..69fe9576e 100644 --- a/src/metatrain/experimental/pet/trainer.py +++ b/src/metatrain/experimental/pet/trainer.py @@ -754,6 +754,7 @@ def save_checkpoint(self, 
model, path: Union[str, Path]): else: lora_state_dict = None last_model_checkpoint = { + "architecture_name": "experimental.pet", "trainer_state_dict": trainer_state_dict, "model_state_dict": last_model_state_dict, "best_model_state_dict": self.best_model_state_dict, @@ -765,6 +766,7 @@ def save_checkpoint(self, model, path: Union[str, Path]): "lora_state_dict": lora_state_dict, } best_model_checkpoint = { + "architecture_name": "experimental.pet", "trainer_state_dict": None, "model_state_dict": self.best_model_state_dict, "best_model_state_dict": None, diff --git a/src/metatrain/experimental/soap_bpnn/trainer.py b/src/metatrain/experimental/soap_bpnn/trainer.py index 563fc00d6..151e46316 100644 --- a/src/metatrain/experimental/soap_bpnn/trainer.py +++ b/src/metatrain/experimental/soap_bpnn/trainer.py @@ -411,6 +411,7 @@ def train( def save_checkpoint(self, model, path: Union[str, Path]): checkpoint = { + "architecture_name": "experimental.soap_bpnn", "model_hypers": { "model_hypers": model.hypers, "dataset_info": model.dataset_info, diff --git a/src/metatrain/utils/io.py b/src/metatrain/utils/io.py index 440bfdc1c..f87e6885d 100644 --- a/src/metatrain/utils/io.py +++ b/src/metatrain/utils/io.py @@ -6,8 +6,10 @@ from urllib.parse import urlparse from urllib.request import urlretrieve +import torch from metatensor.torch.atomistic import check_atomistic_model, load_atomistic_model +from ..utils.architectures import find_all_architectures from .architectures import import_architecture @@ -69,7 +71,6 @@ def is_exported_file(path: str) -> bool: def load_model( path: Union[str, Path], extensions_directory: Optional[Union[str, Path]] = None, - architecture_name: Optional[str] = None, **kwargs, ) -> Any: """Load checkpoints and exported models from an URL or a local file. @@ -77,32 +78,19 @@ def load_model( If an exported model should be loaded and requires compiled extensions, their location should be passed using the ``extensions_directory`` parameter. 
- Loading checkpoints requires the ``architecture_name`` parameter, which can be - ommited for loading an exported model. After reading a checkpoint, the returned + After reading a checkpoint, the returned model can be exported with the model's own ``export()`` method. :param path: local or remote path to a model. For supported URL schemes see :py:class`urllib.request` :param extensions_directory: path to a directory containing all extensions required by an *exported* model - :param architecture_name: name of the architecture required for loading from a - *checkpoint*. - :raises ValueError: if both an ``extensions_directory`` and ``architecture_name`` - are given :raises ValueError: if ``path`` is a YAML option file and no model - :raises ValueError: if no ``archietcture_name`` is given for loading a checkpoint - :raises ValueError: if the checkpoint saved in ``path`` does not math the given - ``architecture_name`` + :raises ValueError: if no ``architecture_name`` is found in the checkpoint + :raises ValueError: if the ``architecture_name`` is not found in the available + architectures """ - if extensions_directory is not None and architecture_name is not None: - raise ValueError( - f"Both ``extensions_directory`` ('{str(extensions_directory)}') and " - f"``architecture_name`` ('{architecture_name}') are given which are " - "mutually exclusive. An ``extensions_directory`` is only required for " - "*exported* models while an ``architecture_name`` is only needed for model " - "*checkpoints*." 
- ) if Path(path).suffix in [".yaml", ".yml"]: raise ValueError( @@ -164,10 +152,14 @@ def load_model( if is_exported_file(path): return load_atomistic_model(path, extensions_directory=extensions_directory) else: # model is a checkpoint - if architecture_name is None: + checkpoint = torch.load(path, weights_only=False, map_location="cpu") + if "architecture_name" not in checkpoint: + raise ValueError("No architecture name found in the checkpoint") + architecture_name = checkpoint["architecture_name"] + if architecture_name not in find_all_architectures(): raise ValueError( - f"path '{path}' seems to be a checkpointed model but no " - "`architecture_name` was given" + f"Checkpoint architecture '{architecture_name}' not found " + "in the available architectures" ) architecture = import_architecture(architecture_name) diff --git a/tests/cli/test_cli.py b/tests/cli/test_cli.py index 112497b7f..3606d3d76 100644 --- a/tests/cli/test_cli.py +++ b/tests/cli/test_cli.py @@ -149,7 +149,7 @@ def test_run_information(capfd, monkeypatch, tmp_path): monkeypatch.chdir(tmp_path) with pytest.raises(CalledProcessError): - subprocess.check_call(["mtt", "export", "experimental.soap_bpnn", "model.ckpt"]) + subprocess.check_call(["mtt", "export", "model.ckpt"]) stdout_log = capfd.readouterr().out diff --git a/tests/cli/test_export_model.py b/tests/cli/test_export_model.py index 8c005b5f2..fd69ada71 100644 --- a/tests/cli/test_export_model.py +++ b/tests/cli/test_export_model.py @@ -57,7 +57,6 @@ def test_export_cli(monkeypatch, tmp_path, output, dtype): command = [ "mtt", "export", - "experimental.soap_bpnn", str(RESOURCES_PATH / f"model-{dtype_string}-bit.ckpt"), ] @@ -129,7 +128,6 @@ def test_private_huggingface(monkeypatch, tmp_path): command = [ "mtt", "export", - "experimental.soap_bpnn", "https://huggingface.co/metatensor/metatrain-test/resolve/main/model.ckpt", f"--huggingface_api_token={HF_TOKEN}", ]