Skip to content

Commit

Permalink
Adapt tests
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed Nov 25, 2024
1 parent fee73df commit 6adfd2e
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 27 deletions.
7 changes: 5 additions & 2 deletions src/metatrain/utils/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ def load_model(
architectures
"""

print("LOADING MODEL")

if Path(path).suffix in [".yaml", ".yml"]:
raise ValueError(
f"path '{path}' seems to be a YAML option file and not a model"
Expand Down Expand Up @@ -159,14 +161,15 @@ def load_model(
if architecture_name not in find_all_architectures():
raise ValueError(
f"Checkpoint architecture '{architecture_name}' not found "
"in the available architectures"
"in the available architectures. Available architectures are: "
f"{find_all_architectures()}"
)
architecture = import_architecture(architecture_name)

try:
return architecture.__model__.load_checkpoint(path)
except Exception as err:
raise ValueError(
f"path '{path}' is not a valid model file for the {architecture_name} "
f"path '{path}' is not a valid checkpoint for the {architecture_name} "
"architecture"
) from err
13 changes: 9 additions & 4 deletions tests/cli/test_export_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,17 @@ def test_export_cli(monkeypatch, tmp_path, output, dtype):
assert next(model.parameters()).device.type == "cpu"


def test_export_cli_architecture_names_choices():
stderr = str(subprocess.run(["mtt", "export", "foo"], capture_output=True).stderr)
def test_export_cli_unknown_architecture(monkeypatch, tmp_path):
monkeypatch.chdir(tmp_path)
torch.save({"architecture_name": "foo"}, "fake.ckpt")

stdout = str(
subprocess.run(["mtt", "export", "fake.ckpt"], capture_output=True).stdout
)

assert "invalid choice: 'foo'" in stderr
assert "architecture 'foo' not found in the available architectures" in stdout
for architecture_name in find_all_architectures():
assert architecture_name in stderr
assert architecture_name in stdout


def test_reexport(monkeypatch, tmp_path):
Expand Down
31 changes: 10 additions & 21 deletions tests/utils/test_io.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from pathlib import Path

import pytest
import torch
from metatensor.torch.atomistic import MetatensorAtomisticModel

from metatrain.experimental.soap_bpnn.model import SoapBpnn
Expand Down Expand Up @@ -45,7 +46,7 @@ def test_is_exported_file():
],
)
def test_load_model_checkpoint(path):
model = load_model(path, architecture_name="experimental.soap_bpnn")
model = load_model(path)
assert type(model) is SoapBpnn
if str(path).startswith("file:"):
# test that the checkpoint is also copied to the current directory
Expand All @@ -61,38 +62,26 @@ def test_load_model_checkpoint(path):
],
)
def test_load_model_exported(path):
model = load_model(path, architecture_name="experimental.soap_bpnn")
model = load_model(path)
assert type(model) is MetatensorAtomisticModel


@pytest.mark.parametrize("suffix", [".yml", ".yaml"])
def test_load_model_yaml(suffix):
match = f"path 'foo{suffix}' seems to be a YAML option file and not a model"
with pytest.raises(ValueError, match=match):
load_model(
f"foo{suffix}",
architecture_name="experimental.soap_bpnn",
)
load_model(f"foo{suffix}")


def test_load_model_unknown_model():
architecture_name = "experimental.pet"
path = RESOURCES_PATH / "model-32-bit.ckpt"
def test_load_model_unknown_model(monkeypatch, tmpdir):
monkeypatch.chdir(tmpdir)
architecture_name = "experimental.soap_bpnn"
path = "fake.ckpt"
torch.save({"architecture_name": architecture_name}, path)

match = (
f"path '{path}' is not a valid model file for the {architecture_name} "
f"path '{path}' is not a valid checkpoint for the {architecture_name} "
"architecture"
)
with pytest.raises(ValueError, match=match):
load_model(path, architecture_name=architecture_name)


def test_extensions_directory_and_architecture_name():
match = (
r"Both ``extensions_directory`` \('.'\) and ``architecture_name`` \('foo'\) "
r"are given which are mutually exclusive. An ``extensions_directory`` is only "
r"required for \*exported\* models while an ``architecture_name`` is only "
r"needed for model \*checkpoints\*."
)
with pytest.raises(ValueError, match=match):
load_model("model.pt", extensions_directory=".", architecture_name="foo")

0 comments on commit 6adfd2e

Please sign in to comment.