Skip to content

Commit

Permalink
Fix NL extraction
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed Dec 4, 2024
1 parent 47ac52a commit 94aa9b3
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 11 deletions.
16 changes: 8 additions & 8 deletions src/metatrain/experimental/nanopet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,12 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo) -> None:
self.new_outputs = list(dataset_info.targets.keys())
self.atomic_types = dataset_info.atomic_types

self.requested_nl = NeighborListOptions(
cutoff=self.hypers["cutoff"],
full_list=True,
strict=True,
)

self.cutoff = self.hypers["cutoff"]
self.cutoff_width = self.hypers["cutoff_width"]

Expand Down Expand Up @@ -250,7 +256,7 @@ def forward(
species,
cells,
cell_shifts,
) = concatenate_structures(systems)
) = concatenate_structures(systems, self.requested_nl)

# somehow the backward of this operation is very slow at evaluation,
# where there is only one cell, therefore we simplify the calculation
Expand Down Expand Up @@ -420,13 +426,7 @@ def forward(
def requested_neighbor_lists(
self,
) -> List[NeighborListOptions]:
return [
NeighborListOptions(
cutoff=self.hypers["cutoff"],
full_list=True,
strict=True,
)
]
return [self.requested_nl]

@classmethod
def load_checkpoint(cls, path: Union[str, Path]) -> "NanoPET":
Expand Down
8 changes: 5 additions & 3 deletions src/metatrain/experimental/nanopet/modules/structures.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from typing import List

import torch
from metatensor.torch.atomistic import System
from metatensor.torch.atomistic import NeighborListOptions, System


def concatenate_structures(systems: List[System]):
def concatenate_structures(
systems: List[System], neighbor_list_options: NeighborListOptions
):

positions = []
centers = []
Expand All @@ -19,7 +21,7 @@ def concatenate_structures(systems: List[System]):
species.append(system.types)

assert len(system.known_neighbor_lists()) >= 1, "no neighbor list found"
neighbor_list = system.get_neighbor_list(system.known_neighbor_lists()[0])
neighbor_list = system.get_neighbor_list(neighbor_list_options)
nl_values = neighbor_list.samples.values

centers.append(nl_values[:, 0] + node_counter)
Expand Down

0 comments on commit 94aa9b3

Please sign in to comment.