diff --git a/docs/src/architectures/nanopet.rst b/docs/src/architectures/nanopet.rst index cc3ceb2e2..841d2b1b5 100644 --- a/docs/src/architectures/nanopet.rst +++ b/docs/src/architectures/nanopet.rst @@ -102,6 +102,8 @@ The hyperparameters for training are :param scheduler_factor: Factor to reduce the learning rate by :param log_interval: Interval at which to log training metrics :param checkpoint_interval: Interval at which to save model checkpoints +:param scale_targets: Whether to scale the targets to have unit standard deviation + across the training set during training. :param fixed_composition_weights: Weights for fixed atomic contributions to scalar targets :param per_structure_targets: Targets to calculate per-structure losses for diff --git a/src/metatrain/experimental/nanopet/model.py b/src/metatrain/experimental/nanopet/model.py index a3fb74e4b..a28e7e663 100644 --- a/src/metatrain/experimental/nanopet/model.py +++ b/src/metatrain/experimental/nanopet/model.py @@ -14,8 +14,6 @@ System, ) -from metatrain.utils.data.target_info import is_auxiliary_output - from ...utils.additive import ZBL, CompositionModel from ...utils.data import DatasetInfo, TargetInfo from ...utils.dtype import dtype_to_str @@ -463,18 +461,16 @@ def forward( # at evaluation, we also introduce the scaler and additive contributions return_dict = self.scaler(return_dict) for additive_model in self.additive_models: - # some of the outputs might not be present in the additive model - # (e.g. the composition model only provides outputs for scalar targets) outputs_for_additive_model: Dict[str, ModelOutput] = {} - for output_name in outputs: - if output_name in additive_model.outputs: - outputs_for_additive_model[output_name] = outputs[output_name] + for name, output in outputs.items(): + if name in additive_model.outputs: + outputs_for_additive_model[name] = output additive_contributions = additive_model( - systems, outputs_for_additive_model, selected_atoms + systems, + outputs_for_additive_model, + selected_atoms, ) for name in additive_contributions: - if is_auxiliary_output(name): - continue return_dict[name] = metatensor.torch.add( return_dict[name], additive_contributions[name], diff --git a/src/metatrain/experimental/pet/tests/test_functionality.py b/src/metatrain/experimental/pet/tests/test_functionality.py index 8fbafb5e2..b32287bb5 100644 --- a/src/metatrain/experimental/pet/tests/test_functionality.py +++ b/src/metatrain/experimental/pet/tests/test_functionality.py @@ -253,8 +253,8 @@ def test_vector_output(per_atom): WrappedPET(DEFAULT_HYPERS["model"], dataset_info) -def test_output_features(): - """Tests that the model can output its features and last-layer features.""" +def test_output_last_layer_features(): + """Tests that the model can output its last layer features.""" dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], @@ -288,27 +288,23 @@ def test_output_features(): [system], { "energy": ModelOutput(quantity="energy", unit="eV", per_atom=True), - "mtt::aux::energy_last_layer_features": ll_output_options, - "features": ll_output_options, + "mtt::aux::last_layer_features": ll_output_options, }, ) assert "energy" in outputs - assert "mtt::aux::energy_last_layer_features" in outputs - assert "features" in outputs - last_layer_features = outputs["mtt::aux::energy_last_layer_features"].block() - features = outputs["features"].block() - assert last_layer_features.samples.names == ["system", "atom"] + assert "mtt::aux::last_layer_features" in outputs + last_layer_features = outputs["mtt::aux::last_layer_features"].block() + assert last_layer_features.samples.names == [ + "system", + "atom", + ] assert last_layer_features.values.shape == ( 4, 768, # 768 = 3 (gnn layers) * 256 (128 for edge repr, 128 for node repr) ) - assert last_layer_features.properties.names == ["properties"] - assert features.samples.names == ["system", "atom"] - assert features.values.shape == ( - 4, - 768, # 768 = 3 (gnn layers) * 256 (128 for edge repr, 128 for node repr) - ) - assert features.properties.names == ["properties"] + assert last_layer_features.properties.names == [ + "properties", + ] # last-layer features per system: ll_output_options = ModelOutput( @@ -320,26 +316,16 @@ def test_output_features(): [system], { "energy": ModelOutput(quantity="energy", unit="eV", per_atom=True), - "mtt::aux::energy_last_layer_features": ll_output_options, - "features": ll_output_options, + "mtt::aux::last_layer_features": ll_output_options, }, ) assert "energy" in outputs - assert "mtt::aux::energy_last_layer_features" in outputs - assert "features" in outputs - assert outputs["mtt::aux::energy_last_layer_features"].block().samples.names == [ - "system" - ] - assert outputs["mtt::aux::energy_last_layer_features"].block().values.shape == ( + assert "mtt::aux::last_layer_features" in outputs + assert outputs["mtt::aux::last_layer_features"].block().samples.names == ["system"] + assert outputs["mtt::aux::last_layer_features"].block().values.shape == ( 1, 768, # 768 = 3 (gnn layers) * 256 (128 for edge repr, 128 for node repr) ) - assert outputs["mtt::aux::energy_last_layer_features"].block().properties.names == [ + assert outputs["mtt::aux::last_layer_features"].block().properties.names == [ "properties", ] - assert outputs["features"].block().samples.names == ["system"] - assert outputs["features"].block().values.shape == ( - 1, - 768, # 768 = 3 (gnn layers) * 256 (128 for edge repr, 128 for node repr) - ) - assert outputs["features"].block().properties.names == ["properties"] diff --git a/src/metatrain/utils/additive/composition.py b/src/metatrain/utils/additive/composition.py index fb93a52ff..b394a3781 100644 --- a/src/metatrain/utils/additive/composition.py +++ b/src/metatrain/utils/additive/composition.py @@ -47,7 +47,6 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo): self.new_targets = { target_name: target_info for target_name, target_info in dataset_info.targets.items() - if target_name not in self.dataset_info.targets } self.register_buffer( diff --git a/tests/utils/test_additive.py b/tests/utils/test_additive.py index 6ef94f5d9..433989745 100644 --- a/tests/utils/test_additive.py +++ b/tests/utils/test_additive.py @@ -384,9 +384,6 @@ def test_composition_model_wrong_target(): ) }, ), - types=torch.tensor([1, 1, 8]), - cell=torch.eye(3, dtype=torch.float64), - pbc=torch.tensor([True, True, True]), )