From 6d9a5b230502c4e79ee8563f508eaed7ce6847d0 Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Mon, 9 Dec 2024 08:21:33 +0100 Subject: [PATCH] Write tests, fix nanoPET --- .../tests/test_torchscript.py | 20 ++++++ .../gap/tests/test_torchscript.py | 64 +++++++++++++++++++ src/metatrain/experimental/nanopet/model.py | 4 +- .../nanopet/tests/test_torchscript.py | 36 +++++++++++ .../pet/tests/test_torchscript.py | 23 +++++++ .../soap_bpnn/tests/test_torchscript.py | 35 ++++++++++ 6 files changed, 180 insertions(+), 2 deletions(-) diff --git a/src/metatrain/experimental/alchemical_model/tests/test_torchscript.py b/src/metatrain/experimental/alchemical_model/tests/test_torchscript.py index fbac2fe6b..af4718b6e 100644 --- a/src/metatrain/experimental/alchemical_model/tests/test_torchscript.py +++ b/src/metatrain/experimental/alchemical_model/tests/test_torchscript.py @@ -1,3 +1,5 @@ +import copy + import torch from metatrain.experimental.alchemical_model import AlchemicalModel @@ -38,3 +40,21 @@ def test_torchscript_save_load(): ) torch.jit.load("alchemical_model.pt") + + +def test_torchscript_integers(): + """Tests that the model can be jitted when some float + parameters are instead supplied as integers.""" + + dataset_info = DatasetInfo( + length_unit="Angstrom", + atomic_types=[1, 6, 7, 8], + targets={"energy": get_energy_target_info({"unit": "eV"})}, + ) + + new_hypers = copy.deepcopy(MODEL_HYPERS) + new_hypers["soap"]["cutoff"] = 5 + new_hypers["soap"]["basis_scale"] = 3 + + model = AlchemicalModel(MODEL_HYPERS, dataset_info) + torch.jit.script(model, {"energy": model.outputs["energy"]}) diff --git a/src/metatrain/experimental/gap/tests/test_torchscript.py b/src/metatrain/experimental/gap/tests/test_torchscript.py index 04465366d..3293b6f41 100644 --- a/src/metatrain/experimental/gap/tests/test_torchscript.py +++ b/src/metatrain/experimental/gap/tests/test_torchscript.py @@ -1,3 +1,5 @@ +import copy + import torch from omegaconf import OmegaConf @@ -78,3 +80,65 @@ def test_torchscript_save(): "gap.pt", ) torch.jit.load("gap.pt") + + +def test_torchscript_integers(): + """Tests that the model can be jitted when some float + parameters are instead supplied as integers.""" + new_hypers = copy.deepcopy(DEFAULT_HYPERS["model"]) + new_hypers["soap"]["cutoff"] = 5 + new_hypers["soap"]["atomic_gaussian_width"] = 1 + new_hypers["soap"]["center_atom_weight"] = 1 + new_hypers["soap"]["cutoff_function"]["ShiftedCosine"]["width"] = 1 + new_hypers["soap"]["radial_scaling"]["Willatt2018"]["rate"] = 1 + new_hypers["soap"]["radial_scaling"]["Willatt2018"]["scale"] = 2 + new_hypers["soap"]["radial_scaling"]["Willatt2018"]["exponent"] = 7 + + target_info_dict = {} + target_info_dict["mtt::U0"] = get_energy_target_info({"unit": "eV"}) + + dataset_info = DatasetInfo( + length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=target_info_dict + ) + conf = { + "mtt::U0": { + "quantity": "energy", + "read_from": DATASET_PATH, + "reader": "ase", + "key": "U0", + "unit": "kcal/mol", + "type": "scalar", + "per_atom": False, + "num_subtargets": 1, + "forces": False, + "stress": False, + "virial": False, + } + } + targets, _ = read_targets(OmegaConf.create(conf)) + systems = read_systems(DATASET_PATH) + + dataset = Dataset.from_dict({"system": systems, "mtt::U0": targets["mtt::U0"]}) + + hypers = DEFAULT_HYPERS.copy() + gap = GAP(new_hypers, dataset_info) + trainer = Trainer(hypers["training"]) + trainer.train( + model=gap, + dtype=torch.float64, + devices=[torch.device("cpu")], + train_datasets=[dataset], + val_datasets=[dataset], + checkpoint_dir=".", + ) + scripted_gap = torch.jit.script(gap) + + ref_output = gap.forward(systems[:5], {"mtt::U0": gap.outputs["mtt::U0"]}) + scripted_output = scripted_gap.forward( + systems[:5], {"mtt::U0": gap.outputs["mtt::U0"]} + ) + + assert torch.allclose( + ref_output["mtt::U0"].block().values, + scripted_output["mtt::U0"].block().values, + ) diff --git a/src/metatrain/experimental/nanopet/model.py b/src/metatrain/experimental/nanopet/model.py index 8559aafdf..16e50a981 100644 --- a/src/metatrain/experimental/nanopet/model.py +++ b/src/metatrain/experimental/nanopet/model.py @@ -73,8 +73,8 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo) -> None: strict=True, ) - self.cutoff = self.hypers["cutoff"] - self.cutoff_width = self.hypers["cutoff_width"] + self.cutoff = float(self.hypers["cutoff"]) + self.cutoff_width = float(self.hypers["cutoff_width"]) self.encoder = Encoder(len(self.atomic_types), self.hypers["d_pet"]) diff --git a/src/metatrain/experimental/nanopet/tests/test_torchscript.py b/src/metatrain/experimental/nanopet/tests/test_torchscript.py index 05a05824a..501f56551 100644 --- a/src/metatrain/experimental/nanopet/tests/test_torchscript.py +++ b/src/metatrain/experimental/nanopet/tests/test_torchscript.py @@ -1,3 +1,5 @@ +import copy + import torch from metatensor.torch.atomistic import System @@ -54,3 +56,37 @@ def test_torchscript_save_load(): "model.pt", ) torch.jit.load("model.pt") + + +def test_torchscript_integers(): + """Tests that the model can be jitted when some float + parameters are instead supplied as integers.""" + + new_hypers = copy.deepcopy(MODEL_HYPERS) + new_hypers["cutoff"] = 5 + new_hypers["cutoff_width"] = 1 + + dataset_info = DatasetInfo( + length_unit="Angstrom", + atomic_types=[1, 6, 7, 8], + targets={ + "energy": get_energy_target_info({"quantity": "energy", "unit": "eV"}) + }, + ) + model = NanoPET(new_hypers, dataset_info) + + system = System( + types=torch.tensor([6, 1, 8, 7]), + positions=torch.tensor( + [[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 0.0, 2.0], [0.0, 0.0, 3.0]] + ), + cell=torch.zeros(3, 3), + pbc=torch.tensor([False, False, False]), + ) + system = get_system_with_neighbor_lists(system, model.requested_neighbor_lists()) + + model = torch.jit.script(model) + model( + [system], + {"energy": model.outputs["energy"]}, + ) diff --git a/src/metatrain/experimental/pet/tests/test_torchscript.py b/src/metatrain/experimental/pet/tests/test_torchscript.py index 6f5623932..101a37633 100644 --- a/src/metatrain/experimental/pet/tests/test_torchscript.py +++ b/src/metatrain/experimental/pet/tests/test_torchscript.py @@ -1,3 +1,5 @@ +import copy + import torch from pet.hypers import Hypers from pet.pet import PET @@ -44,3 +46,24 @@ def test_torchscript_save_load(): "pet.pt", ) torch.jit.load("pet.pt") + + +def test_torchscript_integers(): + """Tests that the model can be jitted when some float + parameters are instead supplied as integers.""" + + new_hypers = copy.deepcopy(DEFAULT_HYPERS["model"]) + new_hypers["R_CUT"] = 5 + new_hypers["CUTOFF_DELTA"] = 1 + new_hypers["RESIDUAL_FACTOR"] = 1 + + dataset_info = DatasetInfo( + length_unit="Angstrom", + atomic_types=[1, 6, 7, 8], + targets={"energy": get_energy_target_info({"unit": "eV"})}, + ) + model = WrappedPET(new_hypers, dataset_info) + ARCHITECTURAL_HYPERS = Hypers(model.hypers) + raw_pet = PET(ARCHITECTURAL_HYPERS, 0.0, len(model.atomic_types)) + model.set_trained_model(raw_pet) + torch.jit.script(model) diff --git a/src/metatrain/experimental/soap_bpnn/tests/test_torchscript.py b/src/metatrain/experimental/soap_bpnn/tests/test_torchscript.py index 2e16ba26a..d9edddc8a 100644 --- a/src/metatrain/experimental/soap_bpnn/tests/test_torchscript.py +++ b/src/metatrain/experimental/soap_bpnn/tests/test_torchscript.py @@ -76,3 +76,38 @@ def test_torchscript_save_load(): "model.pt", ) torch.jit.load("model.pt") + + +def test_torchscript_integers(): + """Tests that the model can be jitted when some float + parameters are instead supplied as integers.""" + + new_hypers = copy.deepcopy(MODEL_HYPERS) + new_hypers["soap"]["cutoff"] = 5 + new_hypers["soap"]["atomic_gaussian_width"] = 1 + new_hypers["soap"]["center_atom_weight"] = 1 + new_hypers["soap"]["cutoff_function"]["ShiftedCosine"]["width"] = 1 + new_hypers["soap"]["radial_scaling"]["Willatt2018"]["rate"] = 1 + new_hypers["soap"]["radial_scaling"]["Willatt2018"]["scale"] = 2 + new_hypers["soap"]["radial_scaling"]["Willatt2018"]["exponent"] = 7 + + dataset_info = DatasetInfo( + length_unit="Angstrom", + atomic_types=[1, 6, 7, 8], + targets={"energy": get_energy_target_info({"unit": "eV"})}, + ) + model = SoapBpnn(new_hypers, dataset_info) + model = torch.jit.script(model) + + system = System( + types=torch.tensor([6, 1, 8, 7]), + positions=torch.tensor( + [[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 0.0, 2.0], [0.0, 0.0, 3.0]] + ), + cell=torch.zeros(3, 3), + pbc=torch.tensor([False, False, False]), + ) + model( + [system], + {"energy": model.outputs["energy"]}, + )