diff --git a/docs/src/advanced-concepts/auxiliary-outputs.rst b/docs/src/advanced-concepts/auxiliary-outputs.rst index 7707aa0e0..3645a8b6c 100644 --- a/docs/src/advanced-concepts/auxiliary-outputs.rst +++ b/docs/src/advanced-concepts/auxiliary-outputs.rst @@ -34,9 +34,9 @@ auxiliary outputs: +--------------------------------------------+-----------+------------------+-----+-----+---------+ | Auxiliary output | SOAP-BPNN | Alchemical Model | PET | GAP | NanoPET | +--------------------------------------------+-----------+------------------+-----+-----+---------+ -| ``mtt::aux::{target}_last_layer_features`` | Yes | No | Yes | No | Yes | +| ``mtt::aux::{target}_last_layer_features`` | Yes | No | No | No | Yes | +--------------------------------------------+-----------+------------------+-----+-----+---------+ -| ``features`` | Yes | No | Yes | No | Yes | +| ``features`` | Yes | No | No | No | Yes | +--------------------------------------------+-----------+------------------+-----+-----+---------+ The following tables show the metadata that will be provided for each of the diff --git a/src/metatrain/experimental/pet/model.py b/src/metatrain/experimental/pet/model.py index d30bf3c9c..7600ef915 100644 --- a/src/metatrain/experimental/pet/model.py +++ b/src/metatrain/experimental/pet/model.py @@ -16,7 +16,6 @@ from pet.pet import PET as RawPET from metatrain.utils.data import DatasetInfo -from metatrain.utils.data.target_info import is_auxiliary_output from ...utils.additive import ZBL from ...utils.dtype import dtype_to_str @@ -145,18 +144,7 @@ def forward( names=["_"], values=torch.tensor([[0]], device=predictions.device) ) - # output the last-layer features for the outputs, if requested: - if ( - f"mtt::aux::{self.target_name}_last_layer_features" in outputs - or "features" in outputs - ): - ll_output_name = f"mtt::aux::{self.target_name}_last_layer_features" - base_name = self.target_name - if ll_output_name in outputs and base_name not in outputs: - raise ValueError( - f"Features {ll_output_name} can only be requested " - f"if the corresponding output {base_name} is also requested." - ) + if "mtt::aux::last_layer_features" in outputs: ll_features = output["last_layer_features"] block = TensorBlock( values=ll_features, @@ -164,36 +152,19 @@ def forward( components=[], properties=Labels( names=["properties"], - values=torch.arange( - ll_features.shape[1], device=predictions.device - ).reshape(-1, 1), + values=torch.arange(ll_features.shape[1]).reshape(-1, 1), ), ) output_tmap = TensorMap( keys=empty_labels, blocks=[block], ) - if ll_output_name in outputs: - ll_features_options = outputs[ll_output_name] - if not ll_features_options.per_atom: - processed_output_tmap = metatensor.torch.sum_over_samples( - output_tmap, "atom" - ) - else: - processed_output_tmap = output_tmap - output_quantities[ll_output_name] = processed_output_tmap - if "features" in outputs: - features_options = outputs["features"] - if not features_options.per_atom: - processed_output_tmap = metatensor.torch.sum_over_samples( - output_tmap, "atom" - ) - else: - processed_output_tmap = output_tmap - output_quantities["features"] = processed_output_tmap + if not outputs["mtt::aux::last_layer_features"].per_atom: + output_tmap = metatensor.torch.sum_over_samples(output_tmap, "atom") + output_quantities["mtt::aux::last_layer_features"] = output_tmap for output_name in outputs: - if is_auxiliary_output(output_name): + if output_name.startswith("mtt::aux::"): continue # skip auxiliary outputs (not targets) energy_labels = Labels( names=["energy"], values=torch.tensor([[0]], device=predictions.device) @@ -218,7 +189,7 @@ def forward( systems, outputs, selected_atoms ) for output_name in output_quantities: - if is_auxiliary_output(output_name): + if output_name.startswith("mtt::aux::"): continue # skip auxiliary outputs (not targets) output_quantities[output_name] = metatensor.torch.add( output_quantities[output_name], @@ -278,7 +249,7 @@ def export(self) -> MetatensorAtomisticModel: unit=self.dataset_info.targets[self.target_name].unit, per_atom=False, ), - f"mtt::aux::{self.target_name.replace('mtt::', '')}_last_layer_features": ModelOutput( # noqa: E501 + "mtt::aux::last_layer_features": ModelOutput( unit="unitless", per_atom=True ), }, diff --git a/src/metatrain/experimental/pet/tests/test_functionality.py b/src/metatrain/experimental/pet/tests/test_functionality.py index 8fbafb5e2..b32287bb5 100644 --- a/src/metatrain/experimental/pet/tests/test_functionality.py +++ b/src/metatrain/experimental/pet/tests/test_functionality.py @@ -253,8 +253,8 @@ def test_vector_output(per_atom): WrappedPET(DEFAULT_HYPERS["model"], dataset_info) -def test_output_features(): - """Tests that the model can output its features and last-layer features.""" +def test_output_last_layer_features(): + """Tests that the model can output its last layer features.""" dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], @@ -288,27 +288,23 @@ def test_output_features(): [system], { "energy": ModelOutput(quantity="energy", unit="eV", per_atom=True), - "mtt::aux::energy_last_layer_features": ll_output_options, - "features": ll_output_options, + "mtt::aux::last_layer_features": ll_output_options, }, ) assert "energy" in outputs - assert "mtt::aux::energy_last_layer_features" in outputs - assert "features" in outputs - last_layer_features = outputs["mtt::aux::energy_last_layer_features"].block() - features = outputs["features"].block() - assert last_layer_features.samples.names == ["system", "atom"] + assert "mtt::aux::last_layer_features" in outputs + last_layer_features = outputs["mtt::aux::last_layer_features"].block() + assert last_layer_features.samples.names == [ + "system", + "atom", + ] assert last_layer_features.values.shape == ( 4, 768, # 768 = 3 (gnn layers) * 256 (128 for edge repr, 128 for node repr) ) - assert last_layer_features.properties.names == ["properties"] - assert features.samples.names == ["system", "atom"] - assert features.values.shape == ( - 4, - 768, # 768 = 3 (gnn layers) * 256 (128 for edge repr, 128 for node repr) - ) - assert features.properties.names == ["properties"] + assert last_layer_features.properties.names == [ + "properties", + ] # last-layer features per system: ll_output_options = ModelOutput( @@ -320,26 +316,16 @@ def test_output_features(): [system], { "energy": ModelOutput(quantity="energy", unit="eV", per_atom=True), - "mtt::aux::energy_last_layer_features": ll_output_options, - "features": ll_output_options, + "mtt::aux::last_layer_features": ll_output_options, }, ) assert "energy" in outputs - assert "mtt::aux::energy_last_layer_features" in outputs - assert "features" in outputs - assert outputs["mtt::aux::energy_last_layer_features"].block().samples.names == [ - "system" - ] - assert outputs["mtt::aux::energy_last_layer_features"].block().values.shape == ( + assert "mtt::aux::last_layer_features" in outputs + assert outputs["mtt::aux::last_layer_features"].block().samples.names == ["system"] + assert outputs["mtt::aux::last_layer_features"].block().values.shape == ( 1, 768, # 768 = 3 (gnn layers) * 256 (128 for edge repr, 128 for node repr) ) - assert outputs["mtt::aux::energy_last_layer_features"].block().properties.names == [ + assert outputs["mtt::aux::last_layer_features"].block().properties.names == [ "properties", ] - assert outputs["features"].block().samples.names == ["system"] - assert outputs["features"].block().values.shape == ( - 1, - 768, # 768 = 3 (gnn layers) * 256 (128 for edge repr, 128 for node repr) - ) - assert outputs["features"].block().properties.names == ["properties"]