diff --git a/src/metatrain/experimental/nanopet/model.py b/src/metatrain/experimental/nanopet/model.py index 9452d395c..8559aafdf 100644 --- a/src/metatrain/experimental/nanopet/model.py +++ b/src/metatrain/experimental/nanopet/model.py @@ -67,6 +67,12 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo) -> None: self.new_outputs = list(dataset_info.targets.keys()) self.atomic_types = dataset_info.atomic_types + self.requested_nl = NeighborListOptions( + cutoff=self.hypers["cutoff"], + full_list=True, + strict=True, + ) + self.cutoff = self.hypers["cutoff"] self.cutoff_width = self.hypers["cutoff_width"] @@ -250,7 +256,7 @@ def forward( species, cells, cell_shifts, - ) = concatenate_structures(systems) + ) = concatenate_structures(systems, self.requested_nl) # somehow the backward of this operation is very slow at evaluation, # where there is only one cell, therefore we simplify the calculation @@ -420,13 +426,7 @@ def forward( def requested_neighbor_lists( self, ) -> List[NeighborListOptions]: - return [ - NeighborListOptions( - cutoff=self.hypers["cutoff"], - full_list=True, - strict=True, - ) - ] + return [self.requested_nl] @classmethod def load_checkpoint(cls, path: Union[str, Path]) -> "NanoPET": diff --git a/src/metatrain/experimental/nanopet/modules/structures.py b/src/metatrain/experimental/nanopet/modules/structures.py index 8228bf693..939ecf2b3 100644 --- a/src/metatrain/experimental/nanopet/modules/structures.py +++ b/src/metatrain/experimental/nanopet/modules/structures.py @@ -1,10 +1,12 @@ from typing import List import torch -from metatensor.torch.atomistic import System +from metatensor.torch.atomistic import NeighborListOptions, System -def concatenate_structures(systems: List[System]): +def concatenate_structures( + systems: List[System], neighbor_list_options: NeighborListOptions +): positions = [] centers = [] @@ -19,7 +21,7 @@ def concatenate_structures(systems: List[System]): species.append(system.types) assert len(system.known_neighbor_lists()) >= 1, "no neighbor list found" - neighbor_list = system.get_neighbor_list(system.known_neighbor_lists()[0]) + neighbor_list = system.get_neighbor_list(neighbor_list_options) nl_values = neighbor_list.samples.values centers.append(nl_values[:, 0] + node_counter)