From e1587196760d403edaae6c948a8b5e95984cb710 Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Wed, 20 Nov 2024 08:14:09 +0100 Subject: [PATCH] Correctly inherit properties from targets in SOAP-BPNN --- src/metatrain/experimental/soap_bpnn/model.py | 16 ++++------------ src/metatrain/utils/additive/composition.py | 3 ++- 2 files changed, 6 insertions(+), 13 deletions(-) diff --git a/src/metatrain/experimental/soap_bpnn/model.py b/src/metatrain/experimental/soap_bpnn/model.py index 8991bae5f..c74b38d8d 100644 --- a/src/metatrain/experimental/soap_bpnn/model.py +++ b/src/metatrain/experimental/soap_bpnn/model.py @@ -432,14 +432,10 @@ def _add_output(self, target_name: str, target: TargetInfo) -> None: values=torch.tensor(self.atomic_types).reshape(-1, 1), ), in_features=self.n_inputs_last_layer, - out_features=1, + out_features=len(target.layout.block().properties.values), bias=False, out_properties=[ - Labels( - names=["energy"], - values=torch.tensor([[0]]), - ) - for _ in self.atomic_types + target.layout.block().properties for _ in self.atomic_types ], ) else: @@ -456,14 +452,10 @@ def _add_output(self, target_name: str, target: TargetInfo) -> None: ), ), in_features=self.n_inputs_last_layer, - out_features=1, + out_features=len(target.layout.block().properties.values), bias=False, out_properties=[ - Labels( - names=["properties"], - values=torch.tensor([[0]]), - ) - for _ in self.atomic_types + target.layout.block().properties for _ in self.atomic_types ], ) diff --git a/src/metatrain/utils/additive/composition.py b/src/metatrain/utils/additive/composition.py index 7018a31be..2661ec639 100644 --- a/src/metatrain/utils/additive/composition.py +++ b/src/metatrain/utils/additive/composition.py @@ -44,8 +44,9 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo): per_atom=True, ) for key, target_info in dataset_info.targets.items() - if target_info.is_scalar + if target_info.is_scalar and len(target_info.layout.block().properties) == 1 # important: only scalars can have composition contributions + # for now, we also require that only one property is present } n_types = len(self.atomic_types)