From 48a50583fc9c0067be9e64fa7ae622f0ada51ca3 Mon Sep 17 00:00:00 2001 From: Arslan Mazitov Date: Thu, 25 Jul 2024 21:39:13 +0200 Subject: [PATCH] Updated the composition weights calculator --- src/metatrain/utils/composition.py | 19 ++++++++++++------- src/metatrain/utils/data/dataset.py | 6 +++--- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/src/metatrain/utils/composition.py b/src/metatrain/utils/composition.py index c3780c281..be2cca5a3 100644 --- a/src/metatrain/utils/composition.py +++ b/src/metatrain/utils/composition.py @@ -29,15 +29,20 @@ def calculate_composition_weights( ) targets = targets.squeeze(dim=(1, 2)) # remove component and property dimensions - structure_list = [sample["system"] for dataset in datasets for sample in dataset] - - dtype = structure_list[0].positions.dtype + total_num_structures = sum([len(dataset) for dataset in datasets]) + dtype = datasets[0][0]["system"].positions.dtype composition_features = torch.empty( - (len(structure_list), len(atomic_types)), dtype=dtype + (total_num_structures, len(atomic_types)), dtype=dtype ) - for i, structure in enumerate(structure_list): - for j, s in enumerate(atomic_types): - composition_features[i, j] = torch.sum(structure.types == s) + structure_index = 0 + for dataset in datasets: + for sample in dataset: + structure = sample["system"] + for j, s in enumerate(atomic_types): + composition_features[structure_index, j] = torch.sum( + structure.types == s + ) + structure_index += 1 regularizer = 1e-20 while regularizer: diff --git a/src/metatrain/utils/data/dataset.py b/src/metatrain/utils/data/dataset.py index f5744a1a5..a4225c02c 100644 --- a/src/metatrain/utils/data/dataset.py +++ b/src/metatrain/utils/data/dataset.py @@ -381,13 +381,13 @@ def get_atomic_types(datasets: Union[Dataset, List[Dataset]]) -> List[int]: if not isinstance(datasets, list): datasets = [datasets] - types = [] + types = set() for dataset in datasets: for index in range(len(dataset)): system = dataset[index]["system"] - types += system.types.tolist() + types.update(set(system.types.tolist())) - return sorted(set(types)) + return sorted(types) def get_all_targets(datasets: Union[Dataset, List[Dataset]]) -> List[str]: