Skip to content

Commit

Permalink
Fix stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed Dec 13, 2024
1 parent 0d91421 commit 2f8ed7e
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 45 deletions.
2 changes: 2 additions & 0 deletions docs/src/architectures/nanopet.rst
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ The hyperparameters for training are
:param scheduler_factor: Factor to reduce the learning rate by
:param log_interval: Interval at which to log training metrics
:param checkpoint_interval: Interval at which to save model checkpoints
:param scale_targets: Whether to scale the targets to have unit standard deviation
across the training set during training.
:param fixed_composition_weights: Weights for fixed atomic contributions to scalar
targets
:param per_structure_targets: Targets to calculate per-structure losses for
Expand Down
16 changes: 6 additions & 10 deletions src/metatrain/experimental/nanopet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
System,
)

from metatrain.utils.data.target_info import is_auxiliary_output

from ...utils.additive import ZBL, CompositionModel
from ...utils.data import DatasetInfo, TargetInfo
from ...utils.dtype import dtype_to_str
Expand Down Expand Up @@ -463,18 +461,16 @@ def forward(
# at evaluation, we also introduce the scaler and additive contributions
return_dict = self.scaler(return_dict)
for additive_model in self.additive_models:
# some of the outputs might not be present in the additive model
# (e.g. the composition model only provides outputs for scalar targets)
outputs_for_additive_model: Dict[str, ModelOutput] = {}
for output_name in outputs:
if output_name in additive_model.outputs:
outputs_for_additive_model[output_name] = outputs[output_name]
for name, output in outputs.items():
if name in additive_model.outputs:
outputs_for_additive_model[name] = output
additive_contributions = additive_model(
systems, outputs_for_additive_model, selected_atoms
systems,
outputs_for_additive_model,
selected_atoms,
)
for name in additive_contributions:
if is_auxiliary_output(name):
continue
return_dict[name] = metatensor.torch.add(
return_dict[name],
additive_contributions[name],
Expand Down
48 changes: 17 additions & 31 deletions src/metatrain/experimental/pet/tests/test_functionality.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,8 +253,8 @@ def test_vector_output(per_atom):
WrappedPET(DEFAULT_HYPERS["model"], dataset_info)


def test_output_features():
"""Tests that the model can output its features and last-layer features."""
def test_output_last_layer_features():
"""Tests that the model can output its last layer features."""
dataset_info = DatasetInfo(
length_unit="Angstrom",
atomic_types=[1, 6, 7, 8],
Expand Down Expand Up @@ -288,27 +288,23 @@ def test_output_features():
[system],
{
"energy": ModelOutput(quantity="energy", unit="eV", per_atom=True),
"mtt::aux::energy_last_layer_features": ll_output_options,
"features": ll_output_options,
"mtt::aux::last_layer_features": ll_output_options,
},
)
assert "energy" in outputs
assert "mtt::aux::energy_last_layer_features" in outputs
assert "features" in outputs
last_layer_features = outputs["mtt::aux::energy_last_layer_features"].block()
features = outputs["features"].block()
assert last_layer_features.samples.names == ["system", "atom"]
assert "mtt::aux::last_layer_features" in outputs
last_layer_features = outputs["mtt::aux::last_layer_features"].block()
assert last_layer_features.samples.names == [
"system",
"atom",
]
assert last_layer_features.values.shape == (
4,
768, # 768 = 3 (gnn layers) * 256 (128 for edge repr, 128 for node repr)
)
assert last_layer_features.properties.names == ["properties"]
assert features.samples.names == ["system", "atom"]
assert features.values.shape == (
4,
768, # 768 = 3 (gnn layers) * 256 (128 for edge repr, 128 for node repr)
)
assert features.properties.names == ["properties"]
assert last_layer_features.properties.names == [
"properties",
]

# last-layer features per system:
ll_output_options = ModelOutput(
Expand All @@ -320,26 +316,16 @@ def test_output_features():
[system],
{
"energy": ModelOutput(quantity="energy", unit="eV", per_atom=True),
"mtt::aux::energy_last_layer_features": ll_output_options,
"features": ll_output_options,
"mtt::aux::last_layer_features": ll_output_options,
},
)
assert "energy" in outputs
assert "mtt::aux::energy_last_layer_features" in outputs
assert "features" in outputs
assert outputs["mtt::aux::energy_last_layer_features"].block().samples.names == [
"system"
]
assert outputs["mtt::aux::energy_last_layer_features"].block().values.shape == (
assert "mtt::aux::last_layer_features" in outputs
assert outputs["mtt::aux::last_layer_features"].block().samples.names == ["system"]
assert outputs["mtt::aux::last_layer_features"].block().values.shape == (
1,
768, # 768 = 3 (gnn layers) * 256 (128 for edge repr, 128 for node repr)
)
assert outputs["mtt::aux::energy_last_layer_features"].block().properties.names == [
assert outputs["mtt::aux::last_layer_features"].block().properties.names == [
"properties",
]
assert outputs["features"].block().samples.names == ["system"]
assert outputs["features"].block().values.shape == (
1,
768, # 768 = 3 (gnn layers) * 256 (128 for edge repr, 128 for node repr)
)
assert outputs["features"].block().properties.names == ["properties"]
1 change: 0 additions & 1 deletion src/metatrain/utils/additive/composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo):
self.new_targets = {
target_name: target_info
for target_name, target_info in dataset_info.targets.items()
if target_name not in self.dataset_info.targets
}

self.register_buffer(
Expand Down
3 changes: 0 additions & 3 deletions tests/utils/test_additive.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,9 +384,6 @@ def test_composition_model_wrong_target():
)
},
),
types=torch.tensor([1, 1, 8]),
cell=torch.eye(3, dtype=torch.float64),
pbc=torch.tensor([True, True, True]),
)


Expand Down

0 comments on commit 2f8ed7e

Please sign in to comment.