Skip to content

Commit

Permalink
Remove the need to specify the architecture when exporting
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed Nov 25, 2024
1 parent 09e2550 commit fee73df
Show file tree
Hide file tree
Showing 9 changed files with 24 additions and 37 deletions.
4 changes: 3 additions & 1 deletion docs/src/dev-docs/new-architecture.rst
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,9 @@ methods for ``train()``, ``save_checkpoint()`` and ``load_checkpoint()``.
The format of checkpoints is not defined by ``metatrain`` and can be any format that
can be loaded by the trainer (to restart training) and by the model (to export the
checkpoint).
checkpoint). The only requirements are that the checkpoint must be loadable with
``torch.load()``, it must be a dictionary, and it must contain the name of the
architecture under the ``architecture_name`` key.

Init file (``__init__.py``)
---------------------------
Expand Down
6 changes: 3 additions & 3 deletions docs/src/getting-started/checkpoints.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,21 +30,21 @@ positional arguments

.. code-block:: bash
mtt export experimental.soap_bpnn model.ckpt -o model.pt
mtt export model.ckpt -o model.pt
or

.. code-block:: bash
mtt export experimental.soap_bpnn model.ckpt --output model.pt
mtt export model.ckpt --output model.pt
For easy distribution of models, the ``export`` command also supports parsing
models from remote locations. To export a remote model you can provide a URL instead of
a file path.

.. code-block:: bash
mtt export experimental.soap_bpnn https://my.url.com/model.ckpt --output model.pt
mtt export https://my.url.com/model.ckpt --output model.pt
Downloading private HuggingFace models is also supported, by specifying the
corresponding API token with the ``--huggingface_api_token`` flag.
Expand Down
9 changes: 0 additions & 9 deletions src/metatrain/cli/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from metatensor.torch.atomistic import is_atomistic_model

from ..utils.architectures import find_all_architectures
from ..utils.io import check_file_extension, load_model
from .formatter import CustomHelpFormatter

Expand All @@ -30,12 +29,6 @@ def _add_export_model_parser(subparser: argparse._SubParsersAction) -> None:
)
parser.set_defaults(callable="export_model")

parser.add_argument(
"architecture_name",
type=str,
choices=find_all_architectures(),
help="name of the model's architecture",
)
parser.add_argument(
"path",
type=str,
Expand Down Expand Up @@ -66,10 +59,8 @@ def _add_export_model_parser(subparser: argparse._SubParsersAction) -> None:
def _prepare_export_model_args(args: argparse.Namespace) -> None:
"""Prepare arguments for export_model."""
path = args.__dict__.pop("path")
architecture_name = args.__dict__.pop("architecture_name")
args.model = load_model(
path=path,
architecture_name=architecture_name,
**args.__dict__,
)
keys_to_keep = ["model", "output"] # only these are needed for `export_model``
Expand Down
1 change: 1 addition & 0 deletions src/metatrain/experimental/alchemical_model/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,7 @@ def train(

def save_checkpoint(self, model, path: Union[str, Path]):
checkpoint = {
"architecture_name": "experimental.alchemical_model",
"model_hypers": {
"model_hypers": model.hypers,
"dataset_info": model.dataset_info,
Expand Down
2 changes: 2 additions & 0 deletions src/metatrain/experimental/pet/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -754,6 +754,7 @@ def save_checkpoint(self, model, path: Union[str, Path]):
else:
lora_state_dict = None
last_model_checkpoint = {
"architecture_name": "experimental.pet",
"trainer_state_dict": trainer_state_dict,
"model_state_dict": last_model_state_dict,
"best_model_state_dict": self.best_model_state_dict,
Expand All @@ -765,6 +766,7 @@ def save_checkpoint(self, model, path: Union[str, Path]):
"lora_state_dict": lora_state_dict,
}
best_model_checkpoint = {
"architecture_name": "experimental.pet",
"trainer_state_dict": None,
"model_state_dict": self.best_model_state_dict,
"best_model_state_dict": None,
Expand Down
1 change: 1 addition & 0 deletions src/metatrain/experimental/soap_bpnn/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,7 @@ def train(

def save_checkpoint(self, model, path: Union[str, Path]):
checkpoint = {
"architecture_name": "experimental.soap_bpnn",
"model_hypers": {
"model_hypers": model.hypers,
"dataset_info": model.dataset_info,
Expand Down
34 changes: 13 additions & 21 deletions src/metatrain/utils/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
from urllib.parse import urlparse
from urllib.request import urlretrieve

import torch
from metatensor.torch.atomistic import check_atomistic_model, load_atomistic_model

from ..utils.architectures import find_all_architectures
from .architectures import import_architecture


Expand Down Expand Up @@ -69,40 +71,26 @@ def is_exported_file(path: str) -> bool:
def load_model(
path: Union[str, Path],
extensions_directory: Optional[Union[str, Path]] = None,
architecture_name: Optional[str] = None,
**kwargs,
) -> Any:
"""Load checkpoints and exported models from an URL or a local file.
If an exported model should be loaded and requires compiled extensions, their
location should be passed using the ``extensions_directory`` parameter.
Loading checkpoints requires the ``architecture_name`` parameter, which can be
omitted for loading an exported model. After reading a checkpoint, the returned
After reading a checkpoint, the returned
model can be exported with the model's own ``export()`` method.
:param path: local or remote path to a model. For supported URL schemes see
:py:class`urllib.request`
:param extensions_directory: path to a directory containing all extensions required
by an *exported* model
:param architecture_name: name of the architecture required for loading from a
*checkpoint*.
:raises ValueError: if both an ``extensions_directory`` and ``architecture_name``
are given
:raises ValueError: if ``path`` is a YAML option file and no model
:raises ValueError: if no ``architecture_name`` is given for loading a checkpoint
:raises ValueError: if the checkpoint saved in ``path`` does not match the given
``architecture_name``
:raises ValueError: if no ``architecture_name`` is found in the checkpoint
:raises ValueError: if the ``architecture_name`` is not found in the available
architectures
"""
if extensions_directory is not None and architecture_name is not None:
raise ValueError(
f"Both ``extensions_directory`` ('{str(extensions_directory)}') and "
f"``architecture_name`` ('{architecture_name}') are given which are "
"mutually exclusive. An ``extensions_directory`` is only required for "
"*exported* models while an ``architecture_name`` is only needed for model "
"*checkpoints*."
)

if Path(path).suffix in [".yaml", ".yml"]:
raise ValueError(
Expand Down Expand Up @@ -164,10 +152,14 @@ def load_model(
if is_exported_file(path):
return load_atomistic_model(path, extensions_directory=extensions_directory)
else: # model is a checkpoint
if architecture_name is None:
checkpoint = torch.load(path, weights_only=False, map_location="cpu")
if "architecture_name" not in checkpoint:
raise ValueError("No architecture name found in the checkpoint")
architecture_name = checkpoint["architecture_name"]
if architecture_name not in find_all_architectures():
raise ValueError(
f"path '{path}' seems to be a checkpointed model but no "
"`architecture_name` was given"
f"Checkpoint architecture '{architecture_name}' not found "
"in the available architectures"
)
architecture = import_architecture(architecture_name)

Expand Down
2 changes: 1 addition & 1 deletion tests/cli/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def test_run_information(capfd, monkeypatch, tmp_path):
monkeypatch.chdir(tmp_path)

with pytest.raises(CalledProcessError):
subprocess.check_call(["mtt", "export", "experimental.soap_bpnn", "model.ckpt"])
subprocess.check_call(["mtt", "export", "model.ckpt"])

stdout_log = capfd.readouterr().out

Expand Down
2 changes: 0 additions & 2 deletions tests/cli/test_export_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ def test_export_cli(monkeypatch, tmp_path, output, dtype):
command = [
"mtt",
"export",
"experimental.soap_bpnn",
str(RESOURCES_PATH / f"model-{dtype_string}-bit.ckpt"),
]

Expand Down Expand Up @@ -129,7 +128,6 @@ def test_private_huggingface(monkeypatch, tmp_path):
command = [
"mtt",
"export",
"experimental.soap_bpnn",
"https://huggingface.co/metatensor/metatrain-test/resolve/main/model.ckpt",
f"--huggingface_api_token={HF_TOKEN}",
]
Expand Down

0 comments on commit fee73df

Please sign in to comment.