Skip to content

Commit

Permalink
Remove PET changes
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed Dec 10, 2024
1 parent 0f0bd35 commit 3bf3a60
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 70 deletions.
4 changes: 2 additions & 2 deletions docs/src/advanced-concepts/auxiliary-outputs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ auxiliary outputs:
+--------------------------------------------+-----------+------------------+-----+-----+---------+
| Auxiliary output | SOAP-BPNN | Alchemical Model | PET | GAP | NanoPET |
+--------------------------------------------+-----------+------------------+-----+-----+---------+
| ``mtt::aux::{target}_last_layer_features`` | Yes | No | Yes | No | Yes |
| ``mtt::aux::{target}_last_layer_features`` | Yes | No | No | No | Yes |
+--------------------------------------------+-----------+------------------+-----+-----+---------+
| ``features`` | Yes | No | Yes | No | Yes |
| ``features`` | Yes | No | No | No | Yes |
+--------------------------------------------+-----------+------------------+-----+-----+---------+

The following tables show the metadata that will be provided for each of the
Expand Down
45 changes: 8 additions & 37 deletions src/metatrain/experimental/pet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from pet.pet import PET as RawPET

from metatrain.utils.data import DatasetInfo
from metatrain.utils.data.target_info import is_auxiliary_output

from ...utils.additive import ZBL
from ...utils.dtype import dtype_to_str
Expand Down Expand Up @@ -145,55 +144,27 @@ def forward(
names=["_"], values=torch.tensor([[0]], device=predictions.device)
)

# output the last-layer features for the outputs, if requested:
if (
f"mtt::aux::{self.target_name}_last_layer_features" in outputs
or "features" in outputs
):
ll_output_name = f"mtt::aux::{self.target_name}_last_layer_features"
base_name = self.target_name
if ll_output_name in outputs and base_name not in outputs:
raise ValueError(
f"Features {ll_output_name} can only be requested "
f"if the corresponding output {base_name} is also requested."
)
if "mtt::aux::last_layer_features" in outputs:
ll_features = output["last_layer_features"]
block = TensorBlock(
values=ll_features,
samples=samples,
components=[],
properties=Labels(
names=["properties"],
values=torch.arange(
ll_features.shape[1], device=predictions.device
).reshape(-1, 1),
values=torch.arange(ll_features.shape[1]).reshape(-1, 1),
),
)
output_tmap = TensorMap(
keys=empty_labels,
blocks=[block],
)
if ll_output_name in outputs:
ll_features_options = outputs[ll_output_name]
if not ll_features_options.per_atom:
processed_output_tmap = metatensor.torch.sum_over_samples(
output_tmap, "atom"
)
else:
processed_output_tmap = output_tmap
output_quantities[ll_output_name] = processed_output_tmap
if "features" in outputs:
features_options = outputs["features"]
if not features_options.per_atom:
processed_output_tmap = metatensor.torch.sum_over_samples(
output_tmap, "atom"
)
else:
processed_output_tmap = output_tmap
output_quantities["features"] = processed_output_tmap
if not outputs["mtt::aux::last_layer_features"].per_atom:
output_tmap = metatensor.torch.sum_over_samples(output_tmap, "atom")
output_quantities["mtt::aux::last_layer_features"] = output_tmap

for output_name in outputs:
if is_auxiliary_output(output_name):
if output_name.startswith("mtt::aux::"):
continue # skip auxiliary outputs (not targets)
energy_labels = Labels(
names=["energy"], values=torch.tensor([[0]], device=predictions.device)
Expand All @@ -218,7 +189,7 @@ def forward(
systems, outputs, selected_atoms
)
for output_name in output_quantities:
if is_auxiliary_output(output_name):
if output_name.startswith("mtt::aux::"):
continue # skip auxiliary outputs (not targets)
output_quantities[output_name] = metatensor.torch.add(
output_quantities[output_name],
Expand Down Expand Up @@ -278,7 +249,7 @@ def export(self) -> MetatensorAtomisticModel:
unit=self.dataset_info.targets[self.target_name].unit,
per_atom=False,
),
f"mtt::aux::{self.target_name.replace('mtt::', '')}_last_layer_features": ModelOutput( # noqa: E501
"mtt::aux::last_layer_features": ModelOutput(
unit="unitless", per_atom=True
),
},
Expand Down
48 changes: 17 additions & 31 deletions src/metatrain/experimental/pet/tests/test_functionality.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,8 +253,8 @@ def test_vector_output(per_atom):
WrappedPET(DEFAULT_HYPERS["model"], dataset_info)


def test_output_features():
"""Tests that the model can output its features and last-layer features."""
def test_output_last_layer_features():
"""Tests that the model can output its last layer features."""
dataset_info = DatasetInfo(
length_unit="Angstrom",
atomic_types=[1, 6, 7, 8],
Expand Down Expand Up @@ -288,27 +288,23 @@ def test_output_features():
[system],
{
"energy": ModelOutput(quantity="energy", unit="eV", per_atom=True),
"mtt::aux::energy_last_layer_features": ll_output_options,
"features": ll_output_options,
"mtt::aux::last_layer_features": ll_output_options,
},
)
assert "energy" in outputs
assert "mtt::aux::energy_last_layer_features" in outputs
assert "features" in outputs
last_layer_features = outputs["mtt::aux::energy_last_layer_features"].block()
features = outputs["features"].block()
assert last_layer_features.samples.names == ["system", "atom"]
assert "mtt::aux::last_layer_features" in outputs
last_layer_features = outputs["mtt::aux::last_layer_features"].block()
assert last_layer_features.samples.names == [
"system",
"atom",
]
assert last_layer_features.values.shape == (
4,
768, # 768 = 3 (gnn layers) * 256 (128 for edge repr, 128 for node repr)
)
assert last_layer_features.properties.names == ["properties"]
assert features.samples.names == ["system", "atom"]
assert features.values.shape == (
4,
768, # 768 = 3 (gnn layers) * 256 (128 for edge repr, 128 for node repr)
)
assert features.properties.names == ["properties"]
assert last_layer_features.properties.names == [
"properties",
]

# last-layer features per system:
ll_output_options = ModelOutput(
Expand All @@ -320,26 +316,16 @@ def test_output_features():
[system],
{
"energy": ModelOutput(quantity="energy", unit="eV", per_atom=True),
"mtt::aux::energy_last_layer_features": ll_output_options,
"features": ll_output_options,
"mtt::aux::last_layer_features": ll_output_options,
},
)
assert "energy" in outputs
assert "mtt::aux::energy_last_layer_features" in outputs
assert "features" in outputs
assert outputs["mtt::aux::energy_last_layer_features"].block().samples.names == [
"system"
]
assert outputs["mtt::aux::energy_last_layer_features"].block().values.shape == (
assert "mtt::aux::last_layer_features" in outputs
assert outputs["mtt::aux::last_layer_features"].block().samples.names == ["system"]
assert outputs["mtt::aux::last_layer_features"].block().values.shape == (
1,
768, # 768 = 3 (gnn layers) * 256 (128 for edge repr, 128 for node repr)
)
assert outputs["mtt::aux::energy_last_layer_features"].block().properties.names == [
assert outputs["mtt::aux::last_layer_features"].block().properties.names == [
"properties",
]
assert outputs["features"].block().samples.names == ["system"]
assert outputs["features"].block().values.shape == (
1,
768, # 768 = 3 (gnn layers) * 256 (128 for edge repr, 128 for node repr)
)
assert outputs["features"].block().properties.names == ["properties"]

0 comments on commit 3bf3a60

Please sign in to comment.