Skip to content

Commit

Permalink
Write tests, fix nanoPET
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed Dec 9, 2024
1 parent 3a89558 commit 6d9a5b2
Show file tree
Hide file tree
Showing 6 changed files with 180 additions and 2 deletions.
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import copy

import torch

from metatrain.experimental.alchemical_model import AlchemicalModel
Expand Down Expand Up @@ -38,3 +40,21 @@ def test_torchscript_save_load():
)

torch.jit.load("alchemical_model.pt")


def test_torchscript_integers():
"""Tests that the model can be jitted when some float
parameters are instead supplied as integers."""

dataset_info = DatasetInfo(
length_unit="Angstrom",
atomic_types=[1, 6, 7, 8],
targets={"energy": get_energy_target_info({"unit": "eV"})},
)

new_hypers = copy.deepcopy(MODEL_HYPERS)
new_hypers["soap"]["cutoff"] = 5
new_hypers["soap"]["basis_scale"] = 3

model = AlchemicalModel(MODEL_HYPERS, dataset_info)
torch.jit.script(model, {"energy": model.outputs["energy"]})
64 changes: 64 additions & 0 deletions src/metatrain/experimental/gap/tests/test_torchscript.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import copy

import torch
from omegaconf import OmegaConf

Expand Down Expand Up @@ -78,3 +80,65 @@ def test_torchscript_save():
"gap.pt",
)
torch.jit.load("gap.pt")


def test_torchscript_integers():
"""Tests that the model can be jitted when some float
parameters are instead supplied as integers."""
new_hypers = copy.deepcopy(DEFAULT_HYPERS["model"])
new_hypers["soap"]["cutoff"] = 5
new_hypers["soap"]["atomic_gaussian_width"] = 1
new_hypers["soap"]["center_atom_weight"] = 1
new_hypers["soap"]["cutoff_function"]["ShiftedCosine"]["width"] = 1
new_hypers["soap"]["radial_scaling"]["Willatt2018"]["rate"] = 1
new_hypers["soap"]["radial_scaling"]["Willatt2018"]["scale"] = 2
new_hypers["soap"]["radial_scaling"]["Willatt2018"]["exponent"] = 7

target_info_dict = {}
target_info_dict["mtt::U0"] = get_energy_target_info({"unit": "eV"})

dataset_info = DatasetInfo(
length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=target_info_dict
)
conf = {
"mtt::U0": {
"quantity": "energy",
"read_from": DATASET_PATH,
"reader": "ase",
"key": "U0",
"unit": "kcal/mol",
"type": "scalar",
"per_atom": False,
"num_subtargets": 1,
"forces": False,
"stress": False,
"virial": False,
}
}
targets, _ = read_targets(OmegaConf.create(conf))
systems = read_systems(DATASET_PATH)

dataset = Dataset.from_dict({"system": systems, "mtt::U0": targets["mtt::U0"]})

hypers = DEFAULT_HYPERS.copy()
gap = GAP(new_hypers, dataset_info)
trainer = Trainer(hypers["training"])
trainer.train(
model=gap,
dtype=torch.float64,
devices=[torch.device("cpu")],
train_datasets=[dataset],
val_datasets=[dataset],
checkpoint_dir=".",
)
scripted_gap = torch.jit.script(gap)

ref_output = gap.forward(systems[:5], {"mtt::U0": gap.outputs["mtt::U0"]})
scripted_output = scripted_gap.forward(
systems[:5], {"mtt::U0": gap.outputs["mtt::U0"]}
)

assert torch.allclose(
ref_output["mtt::U0"].block().values,
scripted_output["mtt::U0"].block().values,
)
4 changes: 2 additions & 2 deletions src/metatrain/experimental/nanopet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo) -> None:
strict=True,
)

self.cutoff = self.hypers["cutoff"]
self.cutoff_width = self.hypers["cutoff_width"]
self.cutoff = float(self.hypers["cutoff"])
self.cutoff_width = float(self.hypers["cutoff_width"])

self.encoder = Encoder(len(self.atomic_types), self.hypers["d_pet"])

Expand Down
36 changes: 36 additions & 0 deletions src/metatrain/experimental/nanopet/tests/test_torchscript.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import copy

import torch
from metatensor.torch.atomistic import System

Expand Down Expand Up @@ -54,3 +56,37 @@ def test_torchscript_save_load():
"model.pt",
)
torch.jit.load("model.pt")


def test_torchscript_integers():
"""Tests that the model can be jitted when some float
parameters are instead supplied as integers."""

new_hypers = copy.deepcopy(MODEL_HYPERS)
new_hypers["cutoff"] = 5
new_hypers["cutoff_width"] = 1

dataset_info = DatasetInfo(
length_unit="Angstrom",
atomic_types=[1, 6, 7, 8],
targets={
"energy": get_energy_target_info({"quantity": "energy", "unit": "eV"})
},
)
model = NanoPET(new_hypers, dataset_info)

system = System(
types=torch.tensor([6, 1, 8, 7]),
positions=torch.tensor(
[[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 0.0, 2.0], [0.0, 0.0, 3.0]]
),
cell=torch.zeros(3, 3),
pbc=torch.tensor([False, False, False]),
)
system = get_system_with_neighbor_lists(system, model.requested_neighbor_lists())

model = torch.jit.script(model)
model(
[system],
{"energy": model.outputs["energy"]},
)
23 changes: 23 additions & 0 deletions src/metatrain/experimental/pet/tests/test_torchscript.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import copy

import torch
from pet.hypers import Hypers
from pet.pet import PET
Expand Down Expand Up @@ -44,3 +46,24 @@ def test_torchscript_save_load():
"pet.pt",
)
torch.jit.load("pet.pt")


def test_torchscript_integers():
"""Tests that the model can be jitted when some float
parameters are instead supplied as integers."""

new_hypers = copy.deepcopy(DEFAULT_HYPERS["model"])
new_hypers["R_CUT"] = 5
new_hypers["CUTOFF_DELTA"] = 1
new_hypers["RESIDUAL_FACTOR"] = 1

dataset_info = DatasetInfo(
length_unit="Angstrom",
atomic_types=[1, 6, 7, 8],
targets={"energy": get_energy_target_info({"unit": "eV"})},
)
model = WrappedPET(new_hypers, dataset_info)
ARCHITECTURAL_HYPERS = Hypers(model.hypers)
raw_pet = PET(ARCHITECTURAL_HYPERS, 0.0, len(model.atomic_types))
model.set_trained_model(raw_pet)
torch.jit.script(model)
35 changes: 35 additions & 0 deletions src/metatrain/experimental/soap_bpnn/tests/test_torchscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,38 @@ def test_torchscript_save_load():
"model.pt",
)
torch.jit.load("model.pt")


def test_torchscript_integers():
"""Tests that the model can be jitted when some float
parameters are instead supplied as integers."""

new_hypers = copy.deepcopy(MODEL_HYPERS)
new_hypers["soap"]["cutoff"] = 5
new_hypers["soap"]["atomic_gaussian_width"] = 1
new_hypers["soap"]["center_atom_weight"] = 1
new_hypers["soap"]["cutoff_function"]["ShiftedCosine"]["width"] = 1
new_hypers["soap"]["radial_scaling"]["Willatt2018"]["rate"] = 1
new_hypers["soap"]["radial_scaling"]["Willatt2018"]["scale"] = 2
new_hypers["soap"]["radial_scaling"]["Willatt2018"]["exponent"] = 7

dataset_info = DatasetInfo(
length_unit="Angstrom",
atomic_types=[1, 6, 7, 8],
targets={"energy": get_energy_target_info({"unit": "eV"})},
)
model = SoapBpnn(new_hypers, dataset_info)
model = torch.jit.script(model)

system = System(
types=torch.tensor([6, 1, 8, 7]),
positions=torch.tensor(
[[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 0.0, 2.0], [0.0, 0.0, 3.0]]
),
cell=torch.zeros(3, 3),
pbc=torch.tensor([False, False, False]),
)
model(
[system],
{"energy": model.outputs["energy"]},
)

0 comments on commit 6d9a5b2

Please sign in to comment.