Pass last-layer features from PET to metatrain-PET #407

Merged
merged 5 commits on Nov 26, 2024
10 changes: 5 additions & 5 deletions docs/src/advanced-concepts/auxiliary-outputs.rst
@@ -16,11 +16,11 @@ by one or more architectures in the library:
The following table shows the architectures that support each of the
auxiliary outputs:

+------------------------------------------+-----------+------------------+-----+
| Auxiliary output | SOAP-BPNN | Alchemical Model | PET |
+------------------------------------------+-----------+------------------+-----+
| ``mtt::aux::last_layer_features`` | Yes | No | No |
+------------------------------------------+-----------+------------------+-----+
+------------------------------------------+-----------+------------------+-----+-----+
| Auxiliary output | SOAP-BPNN | Alchemical Model | PET | GAP |
+------------------------------------------+-----------+------------------+-----+-----+
| ``mtt::aux::last_layer_features`` | Yes | No | Yes | No |
+------------------------------------------+-----------+------------------+-----+-----+

The following tables show the metadata that is expected for each of the
auxiliary outputs:
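With this change, the last-layer features can be requested from PET like any other output. A minimal sketch of a direct call (mirroring the test added further down; the model instance and a system with the requested neighbor lists attached are assumed to already exist):

from metatensor.torch.atomistic import ModelOutput

# request per-atom last-layer features alongside the energy
outputs = model(
    [system],
    {
        "energy": ModelOutput(quantity="energy", unit="eV", per_atom=True),
        "mtt::aux::last_layer_features": ModelOutput(unit="unitless", per_atom=True),
    },
)
features = outputs["mtt::aux::last_layer_features"].block().values  # (n_atoms, n_features)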
57 changes: 46 additions & 11 deletions src/metatrain/experimental/pet/model.py
@@ -61,6 +61,15 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo) -> None:
self.is_lora_applied = False
self.checkpoint_path: Optional[str] = None

# last-layer feature size (for LLPR module)
self.last_layer_feature_size = (
self.hypers["N_GNN_LAYERS"]
* self.hypers["HEAD_N_NEURONS"]
* (1 + self.hypers["USE_BOND_ENERGIES"])
)
# if they are enabled, the edge features are concatenated
# to the node features
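# e.g. with the hypers exercised by the tests below (assumed values:
# N_GNN_LAYERS=3, HEAD_N_NEURONS=128, USE_BOND_ENERGIES=True),
# this gives 3 * 128 * (1 + 1) = 768 features per atom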

# additive models: these are handled by the trainer at training
# time, and they are added to the output at evaluation time
additive_models = []
@@ -123,20 +132,43 @@ def forward(
output = self.pet(batch) # type: ignore
predictions = output["prediction"]
output_quantities: Dict[str, TensorMap] = {}

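# build per-atom sample labels shared by the output blocks:
# one (system, atom) row per atom in the batch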
structure_index = batch["batch"]
_, counts = torch.unique(batch["batch"], return_counts=True)
atom_index = torch.cat(
[torch.arange(count, device=predictions.device) for count in counts]
)
samples_values = torch.stack([structure_index, atom_index], dim=1)
samples = Labels(names=["system", "atom"], values=samples_values)
empty_labels = Labels(
names=["_"], values=torch.tensor([[0]], device=predictions.device)
)

if "mtt::aux::last_layer_features" in outputs:
ll_features = output["last_layer_features"]
block = TensorBlock(
values=ll_features,
samples=samples,
components=[],
properties=Labels(
names=["properties"],
values=torch.arange(ll_features.shape[1]).reshape(-1, 1),
),
)
output_tmap = TensorMap(
keys=empty_labels,
blocks=[block],
)
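# if per-structure (rather than per-atom) features were requested,
# sum the per-atom rows over the "atom" sample dimension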
if not outputs["mtt::aux::last_layer_features"].per_atom:
output_tmap = metatensor.torch.sum_over_samples(output_tmap, "atom")
output_quantities["mtt::aux::last_layer_features"] = output_tmap

for output_name in outputs:
if output_name.startswith("mtt::aux::"):
continue # skip auxiliary outputs (not targets)
energy_labels = Labels(
names=["energy"], values=torch.tensor([[0]], device=predictions.device)
)
empty_labels = Labels(
names=["_"], values=torch.tensor([[0]], device=predictions.device)
)
structure_index = batch["batch"]
_, counts = torch.unique(batch["batch"], return_counts=True)
atom_index = torch.cat(
[torch.arange(count, device=predictions.device) for count in counts]
)
samples_values = torch.stack([structure_index, atom_index], dim=1)
samples = Labels(names=["system", "atom"], values=samples_values)
block = TensorBlock(
samples=samples,
components=[],
@@ -216,7 +248,10 @@ def export(self) -> MetatensorAtomisticModel:
quantity=self.dataset_info.targets[self.target_name].quantity,
unit=self.dataset_info.targets[self.target_name].unit,
per_atom=False,
)
),
"mtt::aux::last_layer_features": ModelOutput(
unit="unitless", per_atom=True
),
},
atomic_types=self.atomic_types,
interaction_range=interaction_range,
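Once exported, the model also advertises mtt::aux::last_layer_features in its capabilities, so the features can be requested through the generic metatensor atomistic interface. A minimal sketch, assuming the metatensor-torch atomistic API used by this release and a system that already carries the requested neighbor lists:

from metatensor.torch.atomistic import ModelEvaluationOptions, ModelOutput

exported = model.export()
options = ModelEvaluationOptions(
    length_unit="angstrom",
    outputs={
        "mtt::aux::last_layer_features": ModelOutput(unit="unitless", per_atom=True)
    },
)
# returns a Dict[str, TensorMap] containing the per-atom last-layer features
results = exported([system], options, check_consistency=True)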
78 changes: 78 additions & 0 deletions src/metatrain/experimental/pet/tests/test_functionality.py
@@ -251,3 +251,81 @@ def test_vector_output(per_atom):

with pytest.raises(ValueError, match="PET only supports total-energy-like outputs"):
WrappedPET(DEFAULT_HYPERS["model"], dataset_info)


def test_output_last_layer_features():
"""Tests that the model can output its last layer features."""
dataset_info = DatasetInfo(
length_unit="Angstrom",
atomic_types=[1, 6, 7, 8],
targets={"energy": get_energy_target_info({"unit": "eV"})},
)

model = WrappedPET(DEFAULT_HYPERS["model"], dataset_info)
ARCHITECTURAL_HYPERS = Hypers(model.hypers)
raw_pet = PET(ARCHITECTURAL_HYPERS, 0.0, len(model.atomic_types))
model.set_trained_model(raw_pet)

system = System(
types=torch.tensor([6, 1, 8, 7]),
positions=torch.tensor(
[[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 0.0, 2.0], [0.0, 0.0, 3.0]],
),
cell=torch.zeros(3, 3),
pbc=torch.tensor([False, False, False]),
)

requested_neighbor_lists = get_requested_neighbor_lists(model)
system = get_system_with_neighbor_lists(system, requested_neighbor_lists)

# last-layer features per atom:
ll_output_options = ModelOutput(
quantity="",
unit="unitless",
per_atom=True,
)
outputs = model(
[system],
{
"energy": ModelOutput(quantity="energy", unit="eV", per_atom=True),
"mtt::aux::last_layer_features": ll_output_options,
},
)
assert "energy" in outputs
assert "mtt::aux::last_layer_features" in outputs
last_layer_features = outputs["mtt::aux::last_layer_features"].block()
assert last_layer_features.samples.names == [
"system",
"atom",
]
assert last_layer_features.values.shape == (
4,
768, # 768 = 3 (gnn layers) * 256 (128 for edge repr, 128 for node repr)
)
assert last_layer_features.properties.names == [
"properties",
]

# last-layer features per system:
ll_output_options = ModelOutput(
quantity="",
unit="unitless",
per_atom=False,
)
outputs = model(
[system],
{
"energy": ModelOutput(quantity="energy", unit="eV", per_atom=True),
"mtt::aux::last_layer_features": ll_output_options,
},
)
assert "energy" in outputs
assert "mtt::aux::last_layer_features" in outputs
assert outputs["mtt::aux::last_layer_features"].block().samples.names == ["system"]
assert outputs["mtt::aux::last_layer_features"].block().values.shape == (
1,
768, # 768 = 3 (gnn layers) * 256 (128 for edge repr, 128 for node repr)
)
assert outputs["mtt::aux::last_layer_features"].block().properties.names == [
"properties",
]
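
As a sanity check of the relationship between the two calls (a sketch only, not part of the test; per_atom_tmap and per_system_tmap stand for the two mtt::aux::last_layer_features TensorMaps obtained above), the per-system block is simply the per-atom block summed over atoms:

import metatensor.torch
import torch

# summing the per-atom features over the "atom" sample dimension should
# reproduce the per-system features
summed = metatensor.torch.sum_over_samples(per_atom_tmap, "atom")
assert torch.allclose(summed.block().values, per_system_tmap.block().values)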