diff --git a/src/metatrain/experimental/alchemical_model/model.py b/src/metatrain/experimental/alchemical_model/model.py index 35964d9dd..746259f1d 100644 --- a/src/metatrain/experimental/alchemical_model/model.py +++ b/src/metatrain/experimental/alchemical_model/model.py @@ -39,13 +39,17 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo) -> None: if not ( target.is_scalar and target.quantity == "energy" - and "atom" not in target.layout.block(0).samples.names and len(target.layout.block(0).properties) == 1 ): raise ValueError( "The Alchemical Model only supports total-energy-like outputs, " f"but a {target.quantity} was provided" ) + if target.per_atom: + raise ValueError( + "Alchemical Model only supports per-structure outputs, " + "but a per-atom output was provided" + ) self.outputs = { key: ModelOutput( diff --git a/src/metatrain/experimental/gap/model.py b/src/metatrain/experimental/gap/model.py index e4f237acd..95679e8a3 100644 --- a/src/metatrain/experimental/gap/model.py +++ b/src/metatrain/experimental/gap/model.py @@ -40,7 +40,6 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo) -> None: if not ( target.is_scalar and target.quantity == "energy" - and "atom" not in target.layout.block(0).samples.names and len(target.layout.block(0).properties) == 1 ): raise ValueError( diff --git a/src/metatrain/experimental/pet/model.py b/src/metatrain/experimental/pet/model.py index 5a835563e..6da9fc036 100644 --- a/src/metatrain/experimental/pet/model.py +++ b/src/metatrain/experimental/pet/model.py @@ -41,13 +41,17 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo) -> None: if not ( target.is_scalar and target.quantity == "energy" - and "atom" not in target.layout.block(0).samples.names and len(target.layout.block(0).properties) == 1 ): raise ValueError( "PET only supports total-energy-like outputs, " f"but a {target.quantity} was provided" ) + if target.per_atom: + raise ValueError( + "PET only supports per-structure outputs, " + "but a per-atom output was provided" + ) model_hypers["D_OUTPUT"] = 1 model_hypers["TARGET_TYPE"] = "atomic"