Skip to content

Commit

Permalink
Remove the need to specify the architecture when exporting (#405)
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster authored Nov 26, 2024
1 parent 8861eaf commit 63a3ddb
Show file tree
Hide file tree
Showing 11 changed files with 54 additions and 64 deletions.
4 changes: 3 additions & 1 deletion docs/src/dev-docs/new-architecture.rst
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,9 @@ methods for ``train()``, ``save_checkpoint()`` and ``load_checkpoint()``.
The format of checkpoints is not defined by ``metatrain`` and can be any format that
can be loaded by the trainer (to restart training) and by the model (to export the
checkpoint).
checkpoint). The only requirements are that the checkpoint must be loadable with
``torch.load()``, it must be a dictionary, and it must contain the name of the
architecture under the ``architecture_name`` key.

Init file (``__init__.py``)
---------------------------
Expand Down
6 changes: 3 additions & 3 deletions docs/src/getting-started/checkpoints.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,21 +30,21 @@ positional arguments

.. code-block:: bash
mtt export experimental.soap_bpnn model.ckpt -o model.pt
mtt export model.ckpt -o model.pt
or

.. code-block:: bash
mtt export experimental.soap_bpnn model.ckpt --output model.pt
mtt export model.ckpt --output model.pt
For ease of distribution of models, the ``export`` command also supports parsing
models from remote locations. To export a remote model, you can provide a URL instead
of a file path.

.. code-block:: bash
mtt export experimental.soap_bpnn https://my.url.com/model.ckpt --output model.pt
mtt export https://my.url.com/model.ckpt --output model.pt
Downloading private HuggingFace models is also supported, by specifying the
corresponding API token with the ``--huggingface_api_token`` flag.
Expand Down
9 changes: 0 additions & 9 deletions src/metatrain/cli/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from metatensor.torch.atomistic import is_atomistic_model

from ..utils.architectures import find_all_architectures
from ..utils.io import check_file_extension, load_model
from .formatter import CustomHelpFormatter

Expand All @@ -30,12 +29,6 @@ def _add_export_model_parser(subparser: argparse._SubParsersAction) -> None:
)
parser.set_defaults(callable="export_model")

parser.add_argument(
"architecture_name",
type=str,
choices=find_all_architectures(),
help="name of the model's architecture",
)
parser.add_argument(
"path",
type=str,
Expand Down Expand Up @@ -66,10 +59,8 @@ def _add_export_model_parser(subparser: argparse._SubParsersAction) -> None:
def _prepare_export_model_args(args: argparse.Namespace) -> None:
"""Prepare arguments for export_model."""
path = args.__dict__.pop("path")
architecture_name = args.__dict__.pop("architecture_name")
args.model = load_model(
path=path,
architecture_name=architecture_name,
**args.__dict__,
)
keys_to_keep = ["model", "output"] # only these are needed for `export_model``
Expand Down
1 change: 1 addition & 0 deletions src/metatrain/experimental/alchemical_model/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,7 @@ def train(

def save_checkpoint(self, model, path: Union[str, Path]):
checkpoint = {
"architecture_name": "experimental.alchemical_model",
"model_hypers": {
"model_hypers": model.hypers,
"dataset_info": model.dataset_info,
Expand Down
2 changes: 2 additions & 0 deletions src/metatrain/experimental/pet/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -754,6 +754,7 @@ def save_checkpoint(self, model, path: Union[str, Path]):
else:
lora_state_dict = None
last_model_checkpoint = {
"architecture_name": "experimental.pet",
"trainer_state_dict": trainer_state_dict,
"model_state_dict": last_model_state_dict,
"best_model_state_dict": self.best_model_state_dict,
Expand All @@ -765,6 +766,7 @@ def save_checkpoint(self, model, path: Union[str, Path]):
"lora_state_dict": lora_state_dict,
}
best_model_checkpoint = {
"architecture_name": "experimental.pet",
"trainer_state_dict": None,
"model_state_dict": self.best_model_state_dict,
"best_model_state_dict": None,
Expand Down
1 change: 1 addition & 0 deletions src/metatrain/experimental/soap_bpnn/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,7 @@ def train(

def save_checkpoint(self, model, path: Union[str, Path]):
checkpoint = {
"architecture_name": "experimental.soap_bpnn",
"model_hypers": {
"model_hypers": model.hypers,
"dataset_info": model.dataset_info,
Expand Down
37 changes: 15 additions & 22 deletions src/metatrain/utils/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
from urllib.parse import urlparse
from urllib.request import urlretrieve

import torch
from metatensor.torch.atomistic import check_atomistic_model, load_atomistic_model

from ..utils.architectures import find_all_architectures
from .architectures import import_architecture


Expand Down Expand Up @@ -70,40 +72,26 @@ def is_exported_file(path: str) -> bool:
def load_model(
path: Union[str, Path],
extensions_directory: Optional[Union[str, Path]] = None,
architecture_name: Optional[str] = None,
**kwargs,
) -> Any:
"""Load checkpoints and exported models from an URL or a local file.
If an exported model should be loaded and requires compiled extensions, their
location should be passed using the ``extensions_directory`` parameter.
Loading checkpoints requires the ``architecture_name`` parameter, which can be
omitted for loading an exported model. After reading a checkpoint, the returned
After reading a checkpoint, the returned
model can be exported with the model's own ``export()`` method.
:param path: local or remote path to a model. For supported URL schemes see
:py:class`urllib.request`
:param extensions_directory: path to a directory containing all extensions required
by an *exported* model
:param architecture_name: name of the architecture required for loading from a
*checkpoint*.
:raises ValueError: if both an ``extensions_directory`` and ``architecture_name``
are given
:raises ValueError: if ``path`` is a YAML option file and no model
:raises ValueError: if no ``architecture_name`` is given for loading a checkpoint
:raises ValueError: if the checkpoint saved in ``path`` does not match the given
``architecture_name``
:raises ValueError: if no ``architecture_name`` is found in the checkpoint
:raises ValueError: if the ``architecture_name`` is not found in the available
architectures
"""
if extensions_directory is not None and architecture_name is not None:
raise ValueError(
f"Both ``extensions_directory`` ('{str(extensions_directory)}') and "
f"``architecture_name`` ('{architecture_name}') are given which are "
"mutually exclusive. An ``extensions_directory`` is only required for "
"*exported* models while an ``architecture_name`` is only needed for model "
"*checkpoints*."
)

if Path(path).suffix in [".yaml", ".yml"]:
raise ValueError(
Expand Down Expand Up @@ -167,17 +155,22 @@ def load_model(
if is_exported_file(path):
return load_atomistic_model(path, extensions_directory=extensions_directory)
else: # model is a checkpoint
if architecture_name is None:
checkpoint = torch.load(path, weights_only=False, map_location="cpu")
if "architecture_name" not in checkpoint:
raise ValueError("No architecture name found in the checkpoint")
architecture_name = checkpoint["architecture_name"]
if architecture_name not in find_all_architectures():
raise ValueError(
f"path '{path}' seems to be a checkpointed model but no "
"`architecture_name` was given"
f"Checkpoint architecture '{architecture_name}' not found "
"in the available architectures. Available architectures are: "
f"{find_all_architectures()}"
)
architecture = import_architecture(architecture_name)

try:
return architecture.__model__.load_checkpoint(path)
except Exception as err:
raise ValueError(
f"path '{path}' is not a valid model file for the {architecture_name} "
f"path '{path}' is not a valid checkpoint for the {architecture_name} "
"architecture"
) from err
6 changes: 2 additions & 4 deletions tests/cli/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,13 +149,11 @@ def test_run_information(capfd, monkeypatch, tmp_path):
monkeypatch.chdir(tmp_path)

with pytest.raises(CalledProcessError):
subprocess.check_call(["mtt", "export", "experimental.soap_bpnn", "model.ckpt"])
subprocess.check_call(["mtt", "export", "model.ckpt"])

stdout_log = capfd.readouterr().out

assert f"Package directory: {PACKAGE_ROOT}" in stdout_log
assert f"Working directory: {Path('.').absolute()}" in stdout_log
assert f"Metatrain version: {__version__}" in stdout_log
assert (
"Executed command: mtt export experimental.soap_bpnn model.ckpt" in stdout_log
)
assert "Executed command: mtt export model.ckpt" in stdout_log
15 changes: 9 additions & 6 deletions tests/cli/test_export_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ def test_export_cli(monkeypatch, tmp_path, output, dtype):
command = [
"mtt",
"export",
"experimental.soap_bpnn",
str(RESOURCES_PATH / f"model-{dtype_string}-bit.ckpt"),
]

Expand All @@ -81,12 +80,17 @@ def test_export_cli(monkeypatch, tmp_path, output, dtype):
assert next(model.parameters()).device.type == "cpu"


def test_export_cli_architecture_names_choices():
stderr = str(subprocess.run(["mtt", "export", "foo"], capture_output=True).stderr)
def test_export_cli_unknown_architecture(monkeypatch, tmp_path):
monkeypatch.chdir(tmp_path)
torch.save({"architecture_name": "foo"}, "fake.ckpt")

stdout = str(
subprocess.run(["mtt", "export", "fake.ckpt"], capture_output=True).stdout
)

assert "invalid choice: 'foo'" in stderr
assert "architecture 'foo' not found in the available architectures" in stdout
for architecture_name in find_all_architectures():
assert architecture_name in stderr
assert architecture_name in stdout


def test_reexport(monkeypatch, tmp_path):
Expand Down Expand Up @@ -129,7 +133,6 @@ def test_private_huggingface(monkeypatch, tmp_path):
command = [
"mtt",
"export",
"experimental.soap_bpnn",
"https://huggingface.co/metatensor/metatrain-test/resolve/main/model.ckpt",
f"--huggingface_api_token={HF_TOKEN}",
]
Expand Down
36 changes: 18 additions & 18 deletions tests/utils/test_io.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from pathlib import Path

import pytest
import torch
from metatensor.torch.atomistic import MetatensorAtomisticModel

from metatrain.experimental.soap_bpnn.model import SoapBpnn
Expand Down Expand Up @@ -45,7 +46,7 @@ def test_is_exported_file():
],
)
def test_load_model_checkpoint(path):
model = load_model(path, architecture_name="experimental.soap_bpnn")
model = load_model(path)
assert type(model) is SoapBpnn
if str(path).startswith("file:"):
# test that the checkpoint is also copied to the current directory
Expand All @@ -61,38 +62,37 @@ def test_load_model_checkpoint(path):
],
)
def test_load_model_exported(path):
model = load_model(path, architecture_name="experimental.soap_bpnn")
model = load_model(path)
assert type(model) is MetatensorAtomisticModel


@pytest.mark.parametrize("suffix", [".yml", ".yaml"])
def test_load_model_yaml(suffix):
match = f"path 'foo{suffix}' seems to be a YAML option file and not a model"
with pytest.raises(ValueError, match=match):
load_model(
f"foo{suffix}",
architecture_name="experimental.soap_bpnn",
)
load_model(f"foo{suffix}")


def test_load_model_unknown_model():
architecture_name = "experimental.pet"
path = RESOURCES_PATH / "model-32-bit.ckpt"
def test_load_model_unknown_model(monkeypatch, tmpdir):
monkeypatch.chdir(tmpdir)
architecture_name = "experimental.soap_bpnn"
path = "fake.ckpt"
torch.save({"architecture_name": architecture_name}, path)

match = (
f"path '{path}' is not a valid model file for the {architecture_name} "
f"path '{path}' is not a valid checkpoint for the {architecture_name} "
"architecture"
)
with pytest.raises(ValueError, match=match):
load_model(path, architecture_name=architecture_name)


def test_extensions_directory_and_architecture_name():
match = (
r"Both ``extensions_directory`` \('.'\) and ``architecture_name`` \('foo'\) "
r"are given which are mutually exclusive. An ``extensions_directory`` is only "
r"required for \*exported\* models while an ``architecture_name`` is only "
r"needed for model \*checkpoints\*."
)
def test_load_model_no_architecture_name(monkeypatch, tmpdir):
monkeypatch.chdir(tmpdir)
architecture_name = "experimental.soap_bpnn"
path = "fake.ckpt"
torch.save({"not_architecture_name": architecture_name}, path)

match = "No architecture name found in the checkpoint"
with pytest.raises(ValueError, match=match):
load_model("model.pt", extensions_directory=".", architecture_name="foo")
load_model(path, architecture_name=architecture_name)
1 change: 0 additions & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ deps =
changedir = tests
extras = # architectures used in the package tests
soap-bpnn
pet
allowlist_externals = bash
commands_pre = bash {toxinidir}/tests/resources/generate-outputs.sh
commands =
Expand Down

0 comments on commit 63a3ddb

Please sign in to comment.