Skip to content

Commit

Permalink
Add tests for all architectures
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed Nov 25, 2024
1 parent 23a99fc commit c8c53e1
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 3 deletions.
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import pytest
import torch
from metatensor.torch.atomistic import ModelEvaluationOptions, System

from metatrain.experimental.alchemical_model import AlchemicalModel
from metatrain.utils.data import DatasetInfo
from metatrain.utils.data.target_info import get_energy_target_info
from metatrain.utils.data.target_info import (
get_energy_target_info,
get_generic_target_info,
)
from metatrain.utils.neighbor_lists import (
get_requested_neighbor_lists,
get_system_with_neighbor_lists,
Expand Down Expand Up @@ -40,3 +44,31 @@ def test_prediction_subset_elements():

exported = model.export()
exported([system], evaluation_options, check_consistency=True)


@pytest.mark.parametrize("per_atom", [True, False])
def test_vector_output(per_atom):
"""Tests that the model can predict a (spherical) vector output."""

dataset_info = DatasetInfo(
length_unit="Angstrom",
atomic_types=[1, 6, 7, 8],
targets={
"forces": get_generic_target_info(
{
"quantity": "forces",
"unit": "",
"type": {
"spherical": {"irreps": [{"o3_lambda": 1, "o3_sigma": 1}]}
},
"num_subtargets": 100,
"per_atom": per_atom,
}
)
},
)

with pytest.raises(
ValueError, match="The Alchemical Model only supports total-energy-like outputs"
):
AlchemicalModel(MODEL_HYPERS, dataset_info)
31 changes: 30 additions & 1 deletion src/metatrain/experimental/gap/tests/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@

from metatrain.experimental.gap import GAP, Trainer
from metatrain.utils.data import Dataset, DatasetInfo, read_systems, read_targets
from metatrain.utils.data.target_info import get_energy_target_info
from metatrain.utils.data.target_info import (
get_energy_target_info,
get_generic_target_info,
)

from . import DATASET_ETHANOL_PATH, DEFAULT_HYPERS

Expand Down Expand Up @@ -85,3 +88,29 @@ def test_ethanol_regression_train_and_invariance():
val_datasets=[dataset],
checkpoint_dir=".",
)


@pytest.mark.parametrize("per_atom", [True, False])
def test_vector_output(per_atom):
"""Tests that the model can predict a (spherical) vector output."""

dataset_info = DatasetInfo(
length_unit="Angstrom",
atomic_types=[1, 6, 7, 8],
targets={
"forces": get_generic_target_info(
{
"quantity": "forces",
"unit": "",
"type": {
"spherical": {"irreps": [{"o3_lambda": 1, "o3_sigma": 1}]}
},
"num_subtargets": 100,
"per_atom": per_atom,
}
)
},
)

with pytest.raises(ValueError, match="GAP only supports total-energy-like outputs"):
GAP(DEFAULT_HYPERS["model"], dataset_info)
31 changes: 30 additions & 1 deletion src/metatrain/experimental/pet/tests/test_functionality.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@
from metatrain.experimental.pet import PET as WrappedPET
from metatrain.utils.architectures import get_default_hypers
from metatrain.utils.data import DatasetInfo
from metatrain.utils.data.target_info import get_energy_target_info
from metatrain.utils.data.target_info import (
get_energy_target_info,
get_generic_target_info,
)
from metatrain.utils.jsonschema import validate
from metatrain.utils.neighbor_lists import (
get_requested_neighbor_lists,
Expand Down Expand Up @@ -222,3 +225,29 @@ def test_selected_atoms_functionality():
evaluation_options,
check_consistency=True,
)


@pytest.mark.parametrize("per_atom", [True, False])
def test_vector_output(per_atom):
"""Tests that the model can predict a (spherical) vector output."""

dataset_info = DatasetInfo(
length_unit="Angstrom",
atomic_types=[1, 6, 7, 8],
targets={
"forces": get_generic_target_info(
{
"quantity": "forces",
"unit": "",
"type": {
"spherical": {"irreps": [{"o3_lambda": 1, "o3_sigma": 1}]}
},
"num_subtargets": 100,
"per_atom": per_atom,
}
)
},
)

with pytest.raises(ValueError, match="PET only supports total-energy-like outputs"):
WrappedPET(DEFAULT_HYPERS["model"], dataset_info)

0 comments on commit c8c53e1

Please sign in to comment.