Skip to content

Commit

Permalink
Correctly inherit properties from targets in SOAP-BPNN
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed Nov 20, 2024
1 parent 0c207fc commit e158719
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 13 deletions.
16 changes: 4 additions & 12 deletions src/metatrain/experimental/soap_bpnn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,14 +432,10 @@ def _add_output(self, target_name: str, target: TargetInfo) -> None:
values=torch.tensor(self.atomic_types).reshape(-1, 1),
),
in_features=self.n_inputs_last_layer,
out_features=1,
out_features=len(target.layout.block().properties.values),
bias=False,
out_properties=[
Labels(
names=["energy"],
values=torch.tensor([[0]]),
)
for _ in self.atomic_types
target.layout.block().properties for _ in self.atomic_types
],
)
else:
Expand All @@ -456,14 +452,10 @@ def _add_output(self, target_name: str, target: TargetInfo) -> None:
),
),
in_features=self.n_inputs_last_layer,
out_features=1,
out_features=len(target.layout.block().properties.values),
bias=False,
out_properties=[
Labels(
names=["properties"],
values=torch.tensor([[0]]),
)
for _ in self.atomic_types
target.layout.block().properties for _ in self.atomic_types
],
)

Expand Down
3 changes: 2 additions & 1 deletion src/metatrain/utils/additive/composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,9 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo):
per_atom=True,
)
for key, target_info in dataset_info.targets.items()
if target_info.is_scalar
if target_info.is_scalar and len(target_info.layout.block().properties) == 1
# important: only scalars can have composition contributions
# for now, we also require that only one property is present
}

n_types = len(self.atomic_types)
Expand Down

0 comments on commit e158719

Please sign in to comment.