diff --git a/src/metatensor/models/experimental/pet/tests/__init__.py b/src/metatensor/models/experimental/pet/tests/__init__.py index b6aa045b3..a652bfbc4 100644 --- a/src/metatensor/models/experimental/pet/tests/__init__.py +++ b/src/metatensor/models/experimental/pet/tests/__init__.py @@ -2,5 +2,5 @@ DATASET_PATH = str( Path(__file__).parent.resolve() - / "../../../../../../tests/resources/qm9_reduced_100.xyz" + / "../../../../../../tests/resources/alchemical_reduced_10.xyz" ) diff --git a/src/metatensor/models/experimental/pet/tests/test_pet_compatibility.py b/src/metatensor/models/experimental/pet/tests/test_pet_compatibility.py index f37e43d69..14debf801 100644 --- a/src/metatensor/models/experimental/pet/tests/test_pet_compatibility.py +++ b/src/metatensor/models/experimental/pet/tests/test_pet_compatibility.py @@ -14,15 +14,77 @@ from pet.hypers import Hypers from metatensor.models.experimental.pet import DEFAULT_HYPERS, Model +from metatensor.models.experimental.pet.utils import systems_to_batch_dict from metatensor.models.utils.neighbor_lists import get_system_with_neighbor_lists +from . import DATASET_PATH + + +def check_batch_dict_consistency(ref_batch, trial_batch): + ref_mask = ref_batch["mask"] + trial_mask = trial_batch["mask"] + assert torch.all(ref_mask == trial_mask) + mask = ref_mask is False + + for key in ref_batch: + if key == "x": + assert torch.allclose( + ref_batch["x"].flatten().sort()[0], + trial_batch["x"].flatten().sort()[0], + atol=1e-5, + ) + elif key in ("central_species", "mask", "nums", "batch"): + assert torch.all(ref_batch[key] == trial_batch[key]) + else: + ref_unique, ref_counts = ref_batch[key][mask].unique(return_counts=True) + trial_unique, trial_counts = trial_batch[key][mask].unique( + return_counts=True + ) + assert torch.all(ref_unique == trial_unique) + assert torch.all(ref_counts == trial_counts) + + +@pytest.mark.parametrize("cutoff", [0.25, 5.0]) +def test_batch_dicts_compatibility(cutoff): + """Tests that the batch dict computed with internal MTM routines + is consitent with PET implementation.""" + + structure = ase.io.read(DATASET_PATH) + all_species = sorted(list(set(structure.numbers))) + system = systems_to_torch(structure) + options = NeighborListOptions(cutoff=cutoff, full_list=True) + system = get_system_with_neighbor_lists(system, [options]) + + ARCHITECTURAL_HYPERS = Hypers(DEFAULT_HYPERS["ARCHITECTURAL_HYPERS"]) + batch = get_pyg_graphs( + [structure], + all_species, + cutoff, + ARCHITECTURAL_HYPERS.USE_ADDITIONAL_SCALAR_ATTRIBUTES, + ARCHITECTURAL_HYPERS.USE_LONG_RANGE, + ARCHITECTURAL_HYPERS.K_CUT, + )[0] + ref_batch_dict = { + "x": batch.x, + "central_species": batch.central_species, + "neighbor_species": batch.neighbor_species, + "mask": batch.mask, + "batch": torch.tensor([0] * len(batch.central_species)), + "nums": batch.nums, + "neighbors_index": batch.neighbors_index.transpose(0, 1), + "neighbors_pos": batch.neighbors_pos, + } + trial_batch_dict = systems_to_batch_dict([system], options, all_species, None) + check_batch_dict_consistency(ref_batch_dict, trial_batch_dict) + @pytest.mark.parametrize("cutoff", [0.25, 5.0]) def test_predictions_compatibility(cutoff): """Tests that predictions of the MTM implemetation of PET are consistent with the predictions of the original PET implementation.""" - all_species = [1, 6, 7, 8] + structure = ase.io.read(DATASET_PATH) + all_species = sorted(list(set(structure.numbers))) capabilities = ModelCapabilities( length_unit="Angstrom", @@ -41,7 +103,6 @@ def test_predictions_compatibility(cutoff): hypers = DEFAULT_HYPERS["ARCHITECTURAL_HYPERS"] hypers["R_CUT"] = cutoff model = Model(capabilities, hypers) - structure = ase.Atoms("O2", positions=[[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]) system = systems_to_torch(structure) options = NeighborListOptions(cutoff=cutoff, full_list=True) @@ -88,7 +149,6 @@ def test_predictions_compatibility(cutoff): pet = model._module.pet pet_prediction = pet.forward(batch_dict) - assert torch.allclose( mtm_pet_prediction, pet_prediction.sum(dim=0), diff --git a/src/metatensor/models/experimental/pet/utils/systems_to_batch_dict.py b/src/metatensor/models/experimental/pet/utils/systems_to_batch_dict.py index 7efe1f502..0efccba21 100644 --- a/src/metatensor/models/experimental/pet/utils/systems_to_batch_dict.py +++ b/src/metatensor/models/experimental/pet/utils/systems_to_batch_dict.py @@ -5,127 +5,6 @@ from metatensor.torch.atomistic import NeighborListOptions, System -class NeighborIndexConstructor: - """From a canonical neighbor list, this function constructs neighbor - indices that are needed for internal usage in the PET model.""" - - def __init__( - self, - i_list: List[int], - j_list: List[int], - S_list: List[torch.Tensor], - species: List[int], - ) -> None: - n_atoms: int = len(species) - - self.neighbors_index: List[List[int]] = [] - for _ in range(n_atoms): - neighbors_index_now: List[int] = [] - self.neighbors_index.append(neighbors_index_now) - - self.neighbor_shift: List[List[torch.Tensor]] = [] - for _ in range(n_atoms): - neighbor_shift_now: List[torch.Tensor] = [] - self.neighbor_shift.append(neighbor_shift_now) - - for i, j, _, S in zip(i_list, j_list, range(len(i_list)), S_list): - self.neighbors_index[i].append(j) - self.neighbor_shift[i].append(S) - - self.relative_positions_raw: List[List[torch.Tensor]] = [ - [] for i in range(n_atoms) - ] - self.neighbor_species: List[List[int]] = [] - for _ in range(n_atoms): - now: List[int] = [] - self.neighbor_species.append(now) - - self.neighbors_pos: List[List[torch.Tensor]] = [[] for i in range(n_atoms)] - - for i, j, index, S in zip(i_list, j_list, range(len(i_list)), S_list): - self.relative_positions_raw[i].append(torch.LongTensor([index])) - self.neighbor_species[i].append(species[j]) - for k in range(len(self.neighbors_index[j])): - if (self.neighbors_index[j][k] == i) and torch.equal( - self.neighbor_shift[j][k], -S - ): - self.neighbors_pos[i].append(torch.LongTensor([k])) - self.relative_positions: List[torch.Tensor] = [] - for chunk in self.relative_positions_raw: - if chunk: - self.relative_positions.append(torch.cat(chunk, dim=0)) - else: - self.relative_positions.append(torch.empty(0, dtype=torch.long)) - - def get_max_num(self) -> int: - maximum: int = -1 - for chunk in self.relative_positions: - if chunk.shape[0] > maximum: - maximum = chunk.shape[0] - return maximum - - def get_neighbors_index(self, max_num: int, all_species: torch.Tensor) -> Tuple[ - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - ]: - nums_raw: List[int] = [] - mask_list: List[torch.Tensor] = [] - relative_positions: torch.Tensor = torch.zeros( - [len(self.relative_positions), max_num], dtype=torch.long - ) - neighbors_pos: torch.Tensor = torch.zeros( - [len(self.relative_positions), max_num], dtype=torch.long - ) - neighbors_index: torch.Tensor = torch.zeros( - [len(self.relative_positions), max_num], dtype=torch.long - ) - - for i in range(len(self.relative_positions)): - - now: torch.Tensor = self.relative_positions[i] - - if len(now) > 0: - relative_positions[i, : len(now)] = now - neighbors_pos[i, : len(now)] = torch.cat(self.neighbors_pos[i], dim=0) - neighbors_index[i, : len(now)] = torch.LongTensor( - self.neighbors_index[i] - ) - - nums_raw.append(len(self.relative_positions[i])) - current_mask: torch.Tensor = torch.zeros([max_num], dtype=torch.bool) - current_mask[len(self.relative_positions[i]) :] = True - mask_list.append(current_mask[None, :]) - - mask: torch.Tensor = torch.cat(mask_list, dim=0).to(dtype=torch.bool) - - nums: torch.Tensor = torch.LongTensor(nums_raw) - - neighbor_species: torch.Tensor = all_species.shape[0] * torch.ones( - [len(self.neighbor_species), max_num], dtype=torch.long - ) - for i in range(len(self.neighbor_species)): - species_now: List[int] = self.neighbor_species[i] - values_now: List[int] = [ - int(torch.where(all_species == specie)[0][0].item()) - for specie in species_now - ] - values_now_torch: torch.Tensor = torch.LongTensor(values_now) - neighbor_species[i, : len(values_now_torch)] = values_now_torch - - return ( - neighbors_pos, - neighbors_index, - nums, - mask, - neighbor_species, - relative_positions, - ) - - def collate_graph_dicts( graph_dicts: List[Dict[str, torch.Tensor]] ) -> Dict[str, torch.Tensor]: @@ -197,164 +76,84 @@ def collate_graph_dicts( return result_final -def systems_to_batch_dict( - systems: List[System], - options: NeighborListOptions, - all_species_list: List[int], - selected_atoms: Optional[Labels] = None, -) -> Dict[str, torch.Tensor]: +def get_max_num_neighbors(systems: List[System], options: NeighborListOptions): """ - Converts a standatd input data format of `metatensor-models` to a - PyTorch Geometric `Batch` object, compatible with `PET` model. + Calculates the maximum number of neighbors that atoms in a list of systems have. - :param systems: The list of systems in `metatensor.torch.atomistic.System` - format, that needs to be converted. - :param options: A `NeighborListOptions` objects specifying the parameters - for a neighbor list, which will be used during the convertation. - :param all_species: A `torch.Tensor` with all the species present in the - systems. - - :return: Batch compatible with PET. """ - device = systems[0].positions.device - all_species: torch.Tensor = torch.LongTensor(all_species_list).to(device) - neighbors_index_constructors: List[NeighborIndexConstructor] = [] - - for i, system in enumerate(systems): - known_neighbor_lists = system.known_neighbor_lists() - if not torch.any( - torch.tensor([known == options for known in known_neighbor_lists]) - ): - raise ValueError( - f"System does not have the neighbor list with the options {options}" + max_system_num_neighbors = [] + for system in systems: + nl = system.get_neighbor_list(options) + i_list = nl.samples.column("first_atom") + if len(i_list) == 0: + max_atom_num_neighbors = torch.tensor( + 0, device=i_list.device, dtype=i_list.dtype ) + else: + max_atom_num_neighbors = torch.bincount(i_list).max() + max_system_num_neighbors.append(max_atom_num_neighbors) + return int(torch.stack(max_system_num_neighbors).max().item()) - neighbor = system.get_neighbor_list(options) - i_list: torch.Tensor = neighbor.samples.column("first_atom") - j_list: torch.Tensor = neighbor.samples.column("second_atom") - unique_neighbors_index = torch.unique(torch.cat((i_list, j_list))) +def get_central_species( + system: System, all_species: torch.Tensor, unique_index: torch.Tensor +) -> torch.Tensor: + """ + Returns the indices of the species of the central atoms in the system + in a list of all species. - if selected_atoms is not None: - selected_atoms_index = selected_atoms.values[:, 1][ - selected_atoms.values[:, 0] == i - ] - unique_index = torch.unique( - torch.cat((selected_atoms_index, unique_neighbors_index)) - ) - else: - unique_index = torch.arange(len(system)) - - S_list_raw: List[torch.Tensor] = [ - neighbor.samples.column("cell_shift_a")[None], - neighbor.samples.column("cell_shift_b")[None], - neighbor.samples.column("cell_shift_c")[None], - ] - - S_list: torch.Tensor = torch.cat(S_list_raw) - S_list = S_list.transpose(0, 1) - - # unique_index = torch.unique(torch.cat((i_list, j_list))) - species: torch.Tensor = system.types[unique_index] - - # Remapping to contiguous indexing (see `remap_to_contiguous_indexing` - # docstring) - if (len(unique_neighbors_index) > 0) and ( - len(unique_neighbors_index) < i_list.max() - or len(unique_neighbors_index) < j_list.max() - ): - i_list, j_list = remap_to_contiguous_indexing(i_list, j_list, unique_index) - - i_list = i_list.cpu() - j_list = j_list.cpu() - S_list = S_list.cpu() - species = species.cpu() - - i_list_proper: List[int] = [int(el.item()) for el in i_list] - j_list_proper: List[int] = [int(el.item()) for el in j_list] - S_list_proper: List[torch.Tensor] = [el.to(dtype=torch.long) for el in S_list] - species_proper: List[int] = [int(el.item()) for el in species] - - neighbors_index_constructor: NeighborIndexConstructor = ( - NeighborIndexConstructor( - i_list_proper, j_list_proper, S_list_proper, species_proper - ) - ) - neighbors_index_constructors.append(neighbors_index_constructor) + """ + species = system.types[unique_index] + tmp_index_1, tmp_index_2 = torch.where(all_species.unsqueeze(1) == species) + index = torch.argsort(tmp_index_2) + return tmp_index_1[index] - max_nums: List[int] = [ - neighbors_index_constructor.get_max_num() - for neighbors_index_constructor in neighbors_index_constructors - ] - max_num: int = max(max_nums) - graphs: List[Dict[str, torch.Tensor]] = [] - for i, (neighbors_index_constructor, system) in enumerate( - zip(neighbors_index_constructors, systems) - ): +def write_system_data( + system: System, + options: NeighborListOptions, + selected_atoms_index: torch.Tensor, +): + nl = system.get_neighbor_list(options) + i_list = nl.samples.column("first_atom") + j_list = nl.samples.column("second_atom") + S_list = torch.cat( ( - neighbors_pos, - neighbors_index, - nums, - mask, - neighbor_species, - relative_positions_index, - ) = neighbors_index_constructor.get_neighbors_index(max_num, all_species) - - neighbor = system.get_neighbor_list(options) - displacement_vectors = neighbor.values[:, :, 0].to(torch.float32) - device = str(displacement_vectors.device) - neighbors_pos = neighbors_pos.to(device) - neighbors_index = neighbors_index.to(device) - nums = nums.to(device) - mask = mask.to(device) - neighbor_species = neighbor_species.to(device) - relative_positions_index = relative_positions_index.to(device) - if len(displacement_vectors) == 0: - shape = relative_positions_index.shape - relative_positions = torch.zeros( - size=(shape[0], shape[1], 3), - device=device, - dtype=torch.float32, - ) - else: - relative_positions = displacement_vectors[relative_positions_index] - if selected_atoms is not None: - neighbor = system.get_neighbor_list(options) - i_list = neighbor.samples.column("first_atom") - j_list = neighbor.samples.column("second_atom") - selected_atoms_index = selected_atoms.values[:, 1][ - selected_atoms.values[:, 0] == i - ] - unique_neighbors_index = torch.unique(torch.cat((i_list, j_list))) - unique_index = torch.unique( - torch.cat((selected_atoms_index, unique_neighbors_index)) - ) - else: - unique_index = torch.arange(len(system)) - species = system.types[unique_index] - central_species = [ - int(torch.where(all_species == specie)[0][0].item()) for specie in species - ] - - central_species = torch.LongTensor(central_species).to(device) - - graph_now = { - "central_species": central_species, - "x": relative_positions, - "neighbor_species": neighbor_species, - "neighbors_pos": neighbors_pos, - "neighbors_index": neighbors_index, - "nums": nums, - "mask": mask, - } - graphs.append(graph_now) - return collate_graph_dicts(graphs) + nl.samples.column("cell_shift_a")[None], + nl.samples.column("cell_shift_b")[None], + nl.samples.column("cell_shift_c")[None], + ) + ).transpose(0, 1) + D_list = nl.values[:, :, 0] + positions = system.positions + types = system.types + cell = system.cell + torch.save( + { + "i_list": i_list, + "j_list": j_list, + "S_list": S_list, + "D_list": D_list, + "positions": positions, + "types": types, + "cell": cell, + "selected_atoms_index": selected_atoms_index, + }, + "system_data.pt", + ) + + +def write_batch_dict(batch_dict: Dict[str, torch.Tensor]): + torch.save(batch_dict, "batch_dict.pt") def remap_to_contiguous_indexing( - i_list: torch.Tensor, j_list: torch.Tensor, unique_index: torch.Tensor -) -> Tuple[torch.Tensor, torch.Tensor]: + i_list: torch.Tensor, + j_list: torch.Tensor, + unique_neighbors_index: torch.Tensor, + unique_index: torch.Tensor, + device: torch.device, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ This helper function remaps the indices of center and neighbor atoms from arbitrary indexing to contgious indexing, i.e. @@ -373,15 +172,226 @@ def remap_to_contiguous_indexing( remap the indices to a contiguous format. """ - index_map: Dict[int, int] = {int(index): i for i, index in enumerate(unique_index)} - i_list = torch.tensor( - [index_map[int(index)] for index in i_list], - dtype=i_list.dtype, - device=i_list.device, + index_map = torch.empty( + int(unique_index.max().item()) + 1, dtype=torch.int64, device=device + ) + index_map[unique_index] = torch.arange(len(unique_index), device=device) + i_list = index_map[i_list] + j_list = index_map[j_list] + unique_neighbors_index = index_map[unique_neighbors_index] + return i_list, j_list, unique_neighbors_index + + +def get_system_batch_dict( + system: System, + options: NeighborListOptions, + all_species: torch.Tensor, + max_num_neighbors: int, + selected_atoms_index: torch.Tensor, + device: torch.device, + debug: bool = False, +) -> Dict[str, torch.Tensor]: + if debug: + write_system_data(system, options, selected_atoms_index) + nl = system.get_neighbor_list(options) + i_list = nl.samples.column("first_atom") + j_list = nl.samples.column("second_atom") + + # First we need to get the unique indices of the atoms in the system. + # This includes all the atoms in the system and their neighbors. + unique_neighbors_index, counts = torch.unique(i_list, return_counts=True) + unique_index = torch.unique( + torch.cat((selected_atoms_index, unique_neighbors_index)) + ) + + # We calculate the actual size of the system, which is the number of + # unique atoms in the system. + # This is required for LAMMPS interface, because by default + # it produces the system with both local and ghost atoms. + actual_system_size = len(unique_index) + + # Then we remap the indices of the atoms to a contiguous format. + # Also see the docstring of the function for more details. + i_list, j_list, unique_neighbors_index = remap_to_contiguous_indexing( + i_list, j_list, unique_neighbors_index, unique_index, device + ) + + # We get the indices of species of the central atoms in the system + # in the all_species tensor. + central_species = get_central_species(system, all_species, unique_index) + + # We sort the indices of the atoms in the system, to join the + # periodic images of the same atom together. Otherwise, the + # neighbor list may have a discontinuous indexing, like: + # >>> i_list + # tensor([0, 1, 2, 0, 1, 2]) + # instead of + # >>> i_list + # tensor([0, 0, 1, 1, 2, 2]) + # and we heavily rely on the fact that the indices of the atoms + # are contiguous below. + index = torch.argsort(i_list, stable=True) + j_list = j_list[index] + i_list = i_list[index] + S_list: torch.Tensor = torch.cat( + ( + nl.samples.column("cell_shift_a")[None], + nl.samples.column("cell_shift_b")[None], + nl.samples.column("cell_shift_c")[None], + ) + ).transpose(0, 1)[index] + + D_list: torch.Tensor = nl.values[:, :, 0][index] + + # This calculates the number of neighbors for each atom. + # By default, the number of neighbors is zero, and we update this tensor + # with the counts of the unique indices of the i_list. + number_of_neighbors = torch.zeros( + actual_system_size, device=device, dtype=torch.int64 ) - j_list = torch.tensor( - [index_map[int(index)] for index in j_list], - dtype=j_list.dtype, - device=j_list.device, + number_of_neighbors[unique_neighbors_index] = counts + + # This calculates the cumulative sum of the counts to get the + # starting and ending indices of each atoms' neighbors in the + # j_list. + cum_sum = counts.cumsum(0) + cum_sum = torch.cat((torch.tensor([0]), cum_sum)) + + # We initialize the tensors for the neighbors indices, shifts and + # displacement vectors with zeros, and then for each atom we + # fill them with the corresponding values from the j_list, S_list + # and D_list. The padding_mask is used to mask the padding values + # in the tensors. + neighbors_index = torch.zeros( + (actual_system_size, max_num_neighbors), device=device, dtype=torch.int64 ) - return i_list, j_list + neighbors_shifts = torch.zeros( + (actual_system_size, max_num_neighbors, 3), device=device, dtype=torch.int64 + ) + displacement_vectors = torch.zeros( + (actual_system_size, max_num_neighbors, 3), device=device, dtype=torch.float32 + ) + padding_mask = torch.zeros( + (actual_system_size, max_num_neighbors), device=device, dtype=torch.bool + ) + for j, count in enumerate(counts): + # For each atom, we put the neighbors species indices up + # to the number of neighbors, while the rest of the indices + # are just padded with zeros. + neighbors_index[j, :count] = j_list[cum_sum[j] : cum_sum[j + 1]] + neighbors_shifts[j, :count] = S_list[cum_sum[j] : cum_sum[j + 1]] + displacement_vectors[j, :count] = D_list[cum_sum[j] : cum_sum[j + 1]] + padding_mask[j, count:] = True # padding mask is True for the padded values + + # We get the indices of the species of the neighbors in the all_species tensor. + # The reason why this function works, is because all the neighborlists are full. + # This means that the total number of central atoms is equal to the total number of + # neighbors. Therefore, `central_species` already contains all the necessacy + # indices, and we can just index it with the `neighbors_index`. + neighbor_species = central_species[neighbors_index] + + # We get the reversed neighbors index, which is used in the PET model to + # account for edge information update not only with the central atoms data, + # but also with the neighbors data. This requires knowing the reversed indices + # of the neighbors in the neighbor list. + # + # The reversed neighbor index is basically the index of the central atom in the + # neighbor list of the neighbor atom. + # + # Example: + # >>> neighbors_index + # tensor([[25, 28, 39, ...], + # ... + # [ 2, 3, 4, ...]]) + # >>> reversed_neighbors_index + # tensor([[ 3, 4, 1, ...], + # ... + # [ 3, 8, 7, ...]]) + # + # The first atom has the neighbors with indices 25, 28, 39, etc., + # That means, in the list of neighbors of 25th atom, 4th atom will + # have the index 0 (i.e. that will be the first atom). + # + # >>> neighbors_index[25] + # tensor([45, 29, 47, 0, ...]) + # >>> neighbors_index[25][3] + # tensor(0) + # + # and this is the element [0, 0] of the `reversed_neighbors_index`. + # + # We also demand the reversed cell shift vector to be the opposite + # of the original cell shift vector. This is because sometimes the + # central atom may have two neighbors, which are the same atom, but + # different periodic images. + + reversed_neighbors_index = torch.zeros_like(neighbors_index) + tmp_reversed_index = neighbors_index[neighbors_index] + tmp_reversed_shifts = neighbors_shifts[neighbors_index] + for j in range(actual_system_size): + condition_1 = tmp_reversed_index[j] == j + condition_2 = torch.all( + tmp_reversed_shifts[j] == -neighbors_shifts[j].unsqueeze(1), dim=2 + ) + condition = condition_1 & condition_2 + tmp_index_1, tmp_index_2 = torch.where(condition) + if len(tmp_index_1) > 0: + _, counts = torch.unique(tmp_index_1, return_counts=True) + cum_sum = counts.cumsum(0) + cum_sum = torch.cat((torch.tensor([0]), cum_sum[:-1])) + reversed_neighbors_index[j, : number_of_neighbors[j]] = tmp_index_2[ + cum_sum + ][: number_of_neighbors[j]] + system_dict = { + "central_species": central_species, + "x": displacement_vectors, + "neighbor_species": neighbor_species, + "neighbors_pos": reversed_neighbors_index, + "neighbors_index": neighbors_index, + "nums": number_of_neighbors, + "mask": padding_mask, + } + if debug: + write_batch_dict(system_dict) + return system_dict + + +def systems_to_batch_dict( + systems: List[System], + options: NeighborListOptions, + all_species_list: List[int], + selected_atoms: Optional[Labels] = None, +) -> Dict[str, torch.Tensor]: + """ + Converts a standard input data format of `metatensor-models` to a + PyTorch Geometric `Batch` object, compatible with `PET` model. + + :param systems: The list of systems in `metatensor.torch.atomistic.System` + format, that needs to be converted. + :param options: A `NeighborListOptions` objects specifying the parameters + for a neighbor list, which will be used during the convertation. + :param all_species: A `torch.Tensor` with all the species present in the + systems. + + :return: Batch compatible with PET. + """ + device = systems[0].positions.device + all_species = torch.tensor(all_species_list, device=device) + batch: List[Dict[str, torch.Tensor]] = [] + max_num_neighbors = get_max_num_neighbors(systems, options) + for i, system in enumerate(systems): + if selected_atoms is not None: + selected_atoms_index = selected_atoms.values[:, 1][ + selected_atoms.values[:, 0] == i + ] + else: + selected_atoms_index = torch.arange(len(system), device=device) + system_dict = get_system_batch_dict( + system, + options, + all_species, + max_num_neighbors, + selected_atoms_index, + device, + ) + batch.append(system_dict) + return collate_graph_dicts(batch)