From 09e25501ea3c4791e4211f916a33ba88d660fd62 Mon Sep 17 00:00:00 2001
From: Filippo Bigi <98903385+frostedoyster@users.noreply.github.com>
Date: Mon, 25 Nov 2024 15:04:44 +0100
Subject: [PATCH] Adapt models to handle generic targets (#386)

---
 ...argets.rst => fitting-generic-targets.rst} |  51 +++-
 docs/src/advanced-concepts/index.rst          |   2 +-
 docs/src/dev-docs/new-architecture.rst        |   9 +
 .../experimental/alchemical_model/model.py    |  22 +-
 .../tests/test_functionality.py               |  34 ++-
 src/metatrain/experimental/gap/model.py       |   8 +-
 .../experimental/gap/tests/test_errors.py     |  31 +-
 src/metatrain/experimental/pet/model.py       |  17 +-
 .../pet/tests/test_functionality.py           |  31 +-
 src/metatrain/experimental/soap_bpnn/model.py | 289 +++++++++++-------
 .../soap_bpnn/tests/test_functionality.py     |  41 ++-
 .../soap_bpnn/tests/test_regression.py        |  20 +-
 src/metatrain/utils/additive/composition.py   |  36 ++-
 src/metatrain/utils/additive/remove.py        |  13 +-
 src/metatrain/utils/additive/zbl.py           |   8 +
 tests/cli/test_train_model.py                 |  17 ++
 tests/utils/test_additive.py                  |  52 ++--
 tests/utils/test_llpr.py                      |   4 +-
 18 files changed, 507 insertions(+), 178 deletions(-)
 rename docs/src/advanced-concepts/{preparing-generic-targets.rst => fitting-generic-targets.rst} (83%)

diff --git a/docs/src/advanced-concepts/preparing-generic-targets.rst b/docs/src/advanced-concepts/fitting-generic-targets.rst
similarity index 83%
rename from docs/src/advanced-concepts/preparing-generic-targets.rst
rename to docs/src/advanced-concepts/fitting-generic-targets.rst
index 949e8f9f8..427ac0769 100644
--- a/docs/src/advanced-concepts/preparing-generic-targets.rst
+++ b/docs/src/advanced-concepts/fitting-generic-targets.rst
@@ -1,11 +1,52 @@
-Preparing generic targets for reading by metatrain
-==================================================
+Fitting generic targets
+=======================
 
 Besides energy-like targets, the library also supports reading (and training on)
 more generic targets.
 
+Support for generic targets
+---------------------------
+
+Not all architectures can train on all types of targets. Here you can find the
+capabilities of the architectures in metatrain.
+
+.. list-table:: Target support by architecture
+   :header-rows: 1
+
+   * - Architecture
+     - Energy and its gradients
+     - Scalars
+     - Spherical tensors
+     - Cartesian tensors
+   * - SOAP-BPNN
+     - Energy, forces, stress/virial
+     - Yes
+     - Only with ``o3_lambda=1, o3_sigma=1``
+     - No
+   * - Alchemical Model
+     - Energy, forces, stress/virial
+     - No
+     - No
+     - No
+   * - GAP
+     - Energy, forces
+     - No
+     - No
+     - No
+   * - PET
+     - Energy, forces
+     - No
+     - No
+     - No
+
+
+Preparing generic targets for reading by metatrain
+--------------------------------------------------
+
+Only a few steps are required to fit arbitrary targets in metatrain.
+
 Input file
-----------
+##########
 
 In order to read a generic target, you will have to specify its layout in the
 input file. Suppose you want to learn a target named ``mtt::my_target``, which is
@@ -69,7 +110,7 @@ where ``o3_lambda`` specifies the L value of the spherical tensor and ``o3_sigma
 parity with respect to inversion (1 for proper tensors, -1 for pseudo-tensors).
 
 Preparing your targets -- ASE
------------------------------
+#############################
 
 If you are using the ASE readers to read your targets, you will have to save
 them either in the ``.info`` (if the target is per structure, i.e.
not per atom) or in the @@ -84,7 +125,7 @@ Reading targets with more than one spherical tensor is not supported by the ASE In that case, you should use the metatensor reader. Preparing your targets -- metatensor ------------------------------------- +#################################### If you are using the metatensor readers to read your targets, you will have to save them as a ``metatensor.torch.TensorMap`` object with ``metatensor.torch.TensorMap.save()`` diff --git a/docs/src/advanced-concepts/index.rst b/docs/src/advanced-concepts/index.rst index dd5f98f01..82fffb0a6 100644 --- a/docs/src/advanced-concepts/index.rst +++ b/docs/src/advanced-concepts/index.rst @@ -13,4 +13,4 @@ such as output naming, auxiliary outputs, and wrapper models. multi-gpu auto-restarting fine-tuning - preparing-generic-targets + fitting-generic-targets diff --git a/docs/src/dev-docs/new-architecture.rst b/docs/src/dev-docs/new-architecture.rst index aaa27cce3..68b9df0d9 100644 --- a/docs/src/dev-docs/new-architecture.rst +++ b/docs/src/dev-docs/new-architecture.rst @@ -227,3 +227,12 @@ passed to the architecture model and trainer as is. To create such a schema start by using `online tools `_ that convert the ``default-hypers.yaml`` into a JSON schema. Besides online tools, we also had success using ChatGPT/LLM for this for conversion. + +Documentation +------------- + +Each new architecture should be added to ``metatrain``'s documentation. A short page +describing the architecture and its default hyperparameters will be sufficient. You +can take inspiration from existing architectures. The various targets that the +architecture can fit should be added to the table in the "Fitting generic targets" +section. diff --git a/src/metatrain/experimental/alchemical_model/model.py b/src/metatrain/experimental/alchemical_model/model.py index 767897f05..746259f1d 100644 --- a/src/metatrain/experimental/alchemical_model/model.py +++ b/src/metatrain/experimental/alchemical_model/model.py @@ -32,14 +32,24 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo) -> None: self.atomic_types = dataset_info.atomic_types if len(dataset_info.targets) != 1: - raise ValueError("The AlchemicalModel only supports a single target") + raise ValueError("The Alchemical Model only supports a single target") target_name = next(iter(dataset_info.targets.keys())) - if dataset_info.targets[target_name].quantity != "energy": - raise ValueError("The AlchemicalModel only supports 'energies' as target") - - if dataset_info.targets[target_name].per_atom: - raise ValueError("The AlchemicalModel does not support 'per-atom' training") + target = dataset_info.targets[target_name] + if not ( + target.is_scalar + and target.quantity == "energy" + and len(target.layout.block(0).properties) == 1 + ): + raise ValueError( + "The Alchemical Model only supports total-energy-like outputs, " + f"but a {target.quantity} was provided" + ) + if target.per_atom: + raise ValueError( + "Alchemical Model only supports per-structure outputs, " + "but a per-atom output was provided" + ) self.outputs = { key: ModelOutput( diff --git a/src/metatrain/experimental/alchemical_model/tests/test_functionality.py b/src/metatrain/experimental/alchemical_model/tests/test_functionality.py index e46ccac73..5bb185629 100644 --- a/src/metatrain/experimental/alchemical_model/tests/test_functionality.py +++ b/src/metatrain/experimental/alchemical_model/tests/test_functionality.py @@ -1,9 +1,13 @@ +import pytest import torch from metatensor.torch.atomistic import 
ModelEvaluationOptions, System from metatrain.experimental.alchemical_model import AlchemicalModel from metatrain.utils.data import DatasetInfo -from metatrain.utils.data.target_info import get_energy_target_info +from metatrain.utils.data.target_info import ( + get_energy_target_info, + get_generic_target_info, +) from metatrain.utils.neighbor_lists import ( get_requested_neighbor_lists, get_system_with_neighbor_lists, @@ -40,3 +44,31 @@ def test_prediction_subset_elements(): exported = model.export() exported([system], evaluation_options, check_consistency=True) + + +@pytest.mark.parametrize("per_atom", [True, False]) +def test_vector_output(per_atom): + """Tests that the model can predict a (spherical) vector output.""" + + dataset_info = DatasetInfo( + length_unit="Angstrom", + atomic_types=[1, 6, 7, 8], + targets={ + "forces": get_generic_target_info( + { + "quantity": "forces", + "unit": "", + "type": { + "spherical": {"irreps": [{"o3_lambda": 1, "o3_sigma": 1}]} + }, + "num_subtargets": 100, + "per_atom": per_atom, + } + ) + }, + ) + + with pytest.raises( + ValueError, match="The Alchemical Model only supports total-energy-like outputs" + ): + AlchemicalModel(MODEL_HYPERS, dataset_info) diff --git a/src/metatrain/experimental/gap/model.py b/src/metatrain/experimental/gap/model.py index fad045a69..95679e8a3 100644 --- a/src/metatrain/experimental/gap/model.py +++ b/src/metatrain/experimental/gap/model.py @@ -37,9 +37,13 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo) -> None: # Check capabilities for target in dataset_info.targets.values(): - if target.quantity != "energy": + if not ( + target.is_scalar + and target.quantity == "energy" + and len(target.layout.block(0).properties) == 1 + ): raise ValueError( - "GAP only supports energy-like outputs, " + "GAP only supports total-energy-like outputs, " f"but a {target.quantity} was provided" ) if target.per_atom: diff --git a/src/metatrain/experimental/gap/tests/test_errors.py b/src/metatrain/experimental/gap/tests/test_errors.py index 3f959a869..91631678f 100644 --- a/src/metatrain/experimental/gap/tests/test_errors.py +++ b/src/metatrain/experimental/gap/tests/test_errors.py @@ -9,7 +9,10 @@ from metatrain.experimental.gap import GAP, Trainer from metatrain.utils.data import Dataset, DatasetInfo, read_systems, read_targets -from metatrain.utils.data.target_info import get_energy_target_info +from metatrain.utils.data.target_info import ( + get_energy_target_info, + get_generic_target_info, +) from . 
import DATASET_ETHANOL_PATH, DEFAULT_HYPERS @@ -85,3 +88,29 @@ def test_ethanol_regression_train_and_invariance(): val_datasets=[dataset], checkpoint_dir=".", ) + + +@pytest.mark.parametrize("per_atom", [True, False]) +def test_vector_output(per_atom): + """Tests that the model can predict a (spherical) vector output.""" + + dataset_info = DatasetInfo( + length_unit="Angstrom", + atomic_types=[1, 6, 7, 8], + targets={ + "forces": get_generic_target_info( + { + "quantity": "forces", + "unit": "", + "type": { + "spherical": {"irreps": [{"o3_lambda": 1, "o3_sigma": 1}]} + }, + "num_subtargets": 100, + "per_atom": per_atom, + } + ) + }, + ) + + with pytest.raises(ValueError, match="GAP only supports total-energy-like outputs"): + GAP(DEFAULT_HYPERS["model"], dataset_info) diff --git a/src/metatrain/experimental/pet/model.py b/src/metatrain/experimental/pet/model.py index 01697436e..31f7d6ade 100644 --- a/src/metatrain/experimental/pet/model.py +++ b/src/metatrain/experimental/pet/model.py @@ -34,8 +34,21 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo) -> None: if len(dataset_info.targets) != 1: raise ValueError("PET only supports a single target") self.target_name = next(iter(dataset_info.targets.keys())) - if dataset_info.targets[self.target_name].quantity != "energy": - raise ValueError("PET only supports energies as target") + target = dataset_info.targets[self.target_name] + if not ( + target.is_scalar + and target.quantity == "energy" + and len(target.layout.block(0).properties) == 1 + ): + raise ValueError( + "PET only supports total-energy-like outputs, " + f"but a {target.quantity} was provided" + ) + if target.per_atom: + raise ValueError( + "PET only supports per-structure outputs, " + "but a per-atom output was provided" + ) model_hypers["D_OUTPUT"] = 1 model_hypers["TARGET_TYPE"] = "atomic" diff --git a/src/metatrain/experimental/pet/tests/test_functionality.py b/src/metatrain/experimental/pet/tests/test_functionality.py index e0c43446d..47664be54 100644 --- a/src/metatrain/experimental/pet/tests/test_functionality.py +++ b/src/metatrain/experimental/pet/tests/test_functionality.py @@ -19,7 +19,10 @@ from metatrain.experimental.pet import PET as WrappedPET from metatrain.utils.architectures import get_default_hypers from metatrain.utils.data import DatasetInfo -from metatrain.utils.data.target_info import get_energy_target_info +from metatrain.utils.data.target_info import ( + get_energy_target_info, + get_generic_target_info, +) from metatrain.utils.jsonschema import validate from metatrain.utils.neighbor_lists import ( get_requested_neighbor_lists, @@ -222,3 +225,29 @@ def test_selected_atoms_functionality(): evaluation_options, check_consistency=True, ) + + +@pytest.mark.parametrize("per_atom", [True, False]) +def test_vector_output(per_atom): + """Tests that the model can predict a (spherical) vector output.""" + + dataset_info = DatasetInfo( + length_unit="Angstrom", + atomic_types=[1, 6, 7, 8], + targets={ + "forces": get_generic_target_info( + { + "quantity": "forces", + "unit": "", + "type": { + "spherical": {"irreps": [{"o3_lambda": 1, "o3_sigma": 1}]} + }, + "num_subtargets": 100, + "per_atom": per_atom, + } + ) + }, + ) + + with pytest.raises(ValueError, match="PET only supports total-energy-like outputs"): + WrappedPET(DEFAULT_HYPERS["model"], dataset_info) diff --git a/src/metatrain/experimental/soap_bpnn/model.py b/src/metatrain/experimental/soap_bpnn/model.py index 093570723..c74b38d8d 100644 --- a/src/metatrain/experimental/soap_bpnn/model.py 
+++ b/src/metatrain/experimental/soap_bpnn/model.py @@ -1,3 +1,4 @@ +import copy from pathlib import Path from typing import Dict, List, Optional, Union @@ -15,6 +16,7 @@ from metatensor.torch.learn.nn import Linear as LinearMap from metatensor.torch.learn.nn import ModuleMap +from metatrain.utils.data import TargetInfo from metatrain.utils.data.dataset import DatasetInfo from ...utils.additive import ZBL, CompositionModel @@ -29,6 +31,14 @@ def forward(self, x: TensorMap) -> TensorMap: return x +class IdentityWithExtraArg(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, s: List[System], x: TensorMap) -> TensorMap: + return x + + class MLPMap(ModuleMap): def __init__(self, atomic_types: List[int], hypers: dict) -> None: # hardcoded for now, but could be a hyperparameter @@ -93,6 +103,92 @@ def __init__(self, atomic_types: List[int], n_layer: int) -> None: super().__init__(in_keys, layernorm_per_species, out_properties) +class VectorFeaturizer(torch.nn.Module): + def __init__(self, atomic_types, num_features, soap_hypers) -> None: + super().__init__() + self.atomic_types = atomic_types + soap_vector_hypers = copy.deepcopy(soap_hypers) + soap_vector_hypers["max_angular"] = 1 + self.soap_calculator = rascaline.torch.SphericalExpansion( + radial_basis={"Gto": {}}, **soap_vector_hypers + ) + self.neighbors_species_labels = Labels( + names=["neighbor_type"], + values=torch.tensor(self.atomic_types).reshape(-1, 1), + ) + self.linear_layer = LinearMap( + Labels( + names=["o3_lambda", "o3_sigma", "center_type"], + values=torch.stack( + [ + torch.tensor([1] * len(self.atomic_types)), + torch.tensor([1] * len(self.atomic_types)), + torch.tensor(self.atomic_types), + ], + dim=1, + ), + ), + in_features=soap_vector_hypers["max_radial"] * len(self.atomic_types), + out_features=num_features, + bias=False, + out_properties=[ + Labels( + names=["property"], + values=torch.arange(num_features).reshape(-1, 1), + ) + for _ in self.atomic_types + ], + ) + + def forward(self, systems: List[System], scalar_features: TensorMap) -> TensorMap: + device = scalar_features.block(0).values.device + + spherical_expansion = self.soap_calculator(systems) + spherical_expansion = spherical_expansion.keys_to_properties( + self.neighbors_species_labels.to(device) + ) + + # drop all l=0 blocks + keys_to_drop_list: List[List[int]] = [] + for key in spherical_expansion.keys.values: + o3_lambda = int(key[0]) + o3_sigma = int(key[1]) + center_species = int(key[2]) + if o3_lambda == 0 and o3_sigma == 1: + keys_to_drop_list.append([o3_lambda, o3_sigma, center_species]) + keys_to_drop = Labels( + names=["o3_lambda", "o3_sigma", "center_type"], + values=torch.tensor(keys_to_drop_list, device=device), + ) + spherical_expansion = metatensor.torch.drop_blocks( + spherical_expansion, keys=keys_to_drop + ) + vector_features = self.linear_layer(spherical_expansion) + + overall_features = metatensor.torch.TensorMap( + keys=vector_features.keys, + blocks=[ + TensorBlock( + values=scalar_features.block( + {"center_type": int(ct)} + ).values.unsqueeze(1) + * vector_features.block({"center_type": int(ct)}).values + * 100.0, + samples=vector_features.block({"center_type": int(ct)}).samples, + components=vector_features.block( + {"center_type": int(ct)} + ).components, + properties=vector_features.block( + {"center_type": int(ct)} + ).properties, + ) + for ct in vector_features.keys.column("center_type") + ], + ) + + return overall_features + + class SoapBpnn(torch.nn.Module): __supported_devices__ = ["cuda", 
"cpu"] @@ -102,34 +198,12 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo) -> None: super().__init__() self.hypers = model_hypers self.dataset_info = dataset_info - self.new_outputs = list(dataset_info.targets.keys()) self.atomic_types = dataset_info.atomic_types self.soap_calculator = rascaline.torch.SoapPowerSpectrum( radial_basis={"Gto": {}}, **self.hypers["soap"] ) - self.outputs = { - key: ModelOutput( - quantity=value.quantity, - unit=value.unit, - per_atom=True, - ) - for key, value in dataset_info.targets.items() - } - - # the model is always capable of outputting the last layer features - self.outputs["mtt::aux::last_layer_features"] = ModelOutput( - unit="unitless", per_atom=True - ) - - # buffers cannot be indexed by strings (torchscript), so we create a single - # tensor for all output. Due to this, we need to slice the tensor when we use - # it and use the output name to select the correct slice via a dictionary - self.output_to_index = { - output_name: i for i, output_name in enumerate(self.outputs.keys()) - } - soap_size = ( (len(self.atomic_types) * (len(self.atomic_types) + 1) // 2) * self.hypers["soap"]["max_radial"] ** 2 @@ -159,33 +233,19 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo) -> None: ) if hypers_bpnn["num_hidden_layers"] == 0: - n_inputs_last_layer = hypers_bpnn["input_size"] + self.n_inputs_last_layer = hypers_bpnn["input_size"] else: - n_inputs_last_layer = hypers_bpnn["num_neurons_per_layer"] - - self.last_layer_feature_size = n_inputs_last_layer * len(self.atomic_types) - self.last_layers = torch.nn.ModuleDict( - { - output_name: LinearMap( - Labels( - "central_species", - values=torch.tensor(self.atomic_types).reshape(-1, 1), - ), - in_features=n_inputs_last_layer, - out_features=1, - bias=False, - out_properties=[ - Labels( - names=["energy"], - values=torch.tensor([[0]]), - ) - for _ in self.atomic_types - ], - ) - for output_name in self.outputs.keys() - if "mtt::aux::" not in output_name - } - ) + self.n_inputs_last_layer = hypers_bpnn["num_neurons_per_layer"] + + self.last_layer_feature_size = self.n_inputs_last_layer * len(self.atomic_types) + + self.outputs = { + "mtt::aux::last_layer_features": ModelOutput(unit="unitless", per_atom=True) + } # the model is always capable of outputting the last-layer features + self.vector_featurizers = torch.nn.ModuleDict({}) + self.last_layers = torch.nn.ModuleDict({}) + for target_name, target in dataset_info.targets.items(): + self._add_output(target_name, target) # additive models: these are handled by the trainer at training # time, and they are added to the output at evaluation time @@ -217,20 +277,12 @@ def restart(self, dataset_info: DatasetInfo) -> "SoapBpnn": ) # register new outputs as new last layers - for output_name in new_targets: - self.add_output(output_name) + for target_name, target in new_targets.items(): + self._add_output(target_name, target) self.dataset_info = merged_info self.atomic_types = sorted(self.atomic_types) - for target_name, target in new_targets.items(): - self.outputs[target_name] = ModelOutput( - quantity=target.quantity, - unit=target.unit, - per_atom=True, - ) - self.new_outputs = list(new_targets.keys()) - return self def forward( @@ -250,7 +302,6 @@ def forward( ) soap_features = self.layernorm(soap_features) - last_layer_features = self.bpnn(soap_features) # output the hidden features, if requested: @@ -265,31 +316,45 @@ def forward( _remove_center_type_from_properties(out_features) ) - atomic_energies: Dict[str, TensorMap] = {} + 
last_layer_features_by_output: Dict[str, TensorMap] = {} + for output_name, vector_featurizer in self.vector_featurizers.items(): + last_layer_features_by_output[output_name] = vector_featurizer( + systems, last_layer_features + ) + + atomic_properties: Dict[str, TensorMap] = {} for output_name, output_layer in self.last_layers.items(): if output_name in outputs: - atomic_energies[output_name] = output_layer(last_layer_features) + atomic_properties[output_name] = output_layer( + last_layer_features_by_output[output_name] + ) - # Sum the atomic energies coming from the BPNN to get the total energy - for output_name, atomic_energy in atomic_energies.items(): - atomic_energy = atomic_energy.keys_to_samples("center_type") + for output_name, atomic_property in atomic_properties.items(): + atomic_property = atomic_property.keys_to_samples("center_type") if outputs[output_name].per_atom: # this operation should just remove the center_type label return_dict[output_name] = metatensor.torch.remove_dimension( - atomic_energy, axis="samples", name="center_type" + atomic_property, axis="samples", name="center_type" ) else: + # sum the atomic property to get the total property return_dict[output_name] = metatensor.torch.sum_over_samples( - atomic_energy, ["atom", "center_type"] + atomic_property, ["atom", "center_type"] ) if not self.training: # at evaluation, we also add the additive contributions for additive_model in self.additive_models: + # some of the outputs might not be present in the additive model + # (e.g. the composition model only provides outputs for scalar targets) + outputs_for_additive_model: Dict[str, ModelOutput] = {} + for output_name in outputs: + if output_name in additive_model.outputs: + outputs_for_additive_model[output_name] = outputs[output_name] additive_contributions = additive_model( - systems, outputs, selected_atoms + systems, outputs_for_additive_model, selected_atoms ) - for name in return_dict: + for name in additive_contributions: if name.startswith("mtt::aux::"): continue # skip auxiliary outputs (not targets) return_dict[name] = metatensor.torch.add( @@ -341,43 +406,63 @@ def export(self) -> MetatensorAtomisticModel: return MetatensorAtomisticModel(self.eval(), ModelMetadata(), capabilities) - def add_output(self, output_name: str) -> None: - """Add a new output to the self.""" - # add a new row to the composition weights tensor - # initialize it with zeros - self.composition_weights = torch.cat( - [ - self.composition_weights, # type: ignore - torch.zeros( - 1, - self.composition_weights.shape[1], # type: ignore - dtype=self.composition_weights.dtype, # type: ignore - device=self.composition_weights.device, # type: ignore + def _add_output(self, target_name: str, target: TargetInfo) -> None: + + if target.is_scalar: + self.vector_featurizers[target_name] = IdentityWithExtraArg() + elif target.is_spherical: + values_list: List[List[int]] = target.layout.keys.values.tolist() + if values_list != [[1, 1]]: + raise ValueError( + "SOAP-BPNN only supports spherical targets with " + "`o3_lambda=1` and `o3_sigma=1`, " + ) + self.vector_featurizers[target_name] = VectorFeaturizer( + atomic_types=self.atomic_types, + num_features=self.n_inputs_last_layer, + soap_hypers=self.hypers["soap"], + ) + else: + raise ValueError("SOAP-BPNN only supports scalar and spherical targets.") + + if target.is_scalar: + self.last_layers[target_name] = LinearMap( + Labels( + "central_species", + values=torch.tensor(self.atomic_types).reshape(-1, 1), ), - ] - ) - 
self.output_to_index[output_name] = len(self.output_to_index) - # add a new linear layer to the last layers - hypers_bpnn = self.hypers["bpnn"] - if hypers_bpnn["num_hidden_layers"] == 0: - n_inputs_last_layer = hypers_bpnn["input_size"] + in_features=self.n_inputs_last_layer, + out_features=len(target.layout.block().properties.values), + bias=False, + out_properties=[ + target.layout.block().properties for _ in self.atomic_types + ], + ) else: - n_inputs_last_layer = hypers_bpnn["num_neurons_per_layer"] - self.last_layers[output_name] = LinearMap( - Labels( - "central_species", - values=torch.tensor(self.atomic_types).reshape(-1, 1), - ), - in_features=n_inputs_last_layer, - out_features=1, - bias=False, - out_properties=[ + self.last_layers[target_name] = LinearMap( Labels( - names=["energy"], - values=torch.tensor([[0]]), - ) - for _ in self.atomic_types - ], + names=["o3_lambda", "o3_sigma", "center_type"], + values=torch.stack( + [ + torch.tensor([1] * len(self.atomic_types)), + torch.tensor([1] * len(self.atomic_types)), + torch.tensor(self.atomic_types), + ], + dim=1, + ), + ), + in_features=self.n_inputs_last_layer, + out_features=len(target.layout.block().properties.values), + bias=False, + out_properties=[ + target.layout.block().properties for _ in self.atomic_types + ], + ) + + self.outputs[target_name] = ModelOutput( + quantity=target.quantity, + unit=target.unit, + per_atom=True, ) diff --git a/src/metatrain/experimental/soap_bpnn/tests/test_functionality.py b/src/metatrain/experimental/soap_bpnn/tests/test_functionality.py index 98336336f..bcccab073 100644 --- a/src/metatrain/experimental/soap_bpnn/tests/test_functionality.py +++ b/src/metatrain/experimental/soap_bpnn/tests/test_functionality.py @@ -8,7 +8,10 @@ from metatrain.experimental.soap_bpnn import SoapBpnn from metatrain.utils.architectures import check_architecture_options from metatrain.utils.data import DatasetInfo -from metatrain.utils.data.target_info import get_energy_target_info +from metatrain.utils.data.target_info import ( + get_energy_target_info, + get_generic_target_info, +) from . 
import DEFAULT_HYPERS, MODEL_HYPERS
@@ -235,3 +238,39 @@ def test_fixed_composition_weights_error():
         check_architecture_options(
             name="experimental.soap_bpnn", options=OmegaConf.to_container(hypers)
         )
+
+
+@pytest.mark.parametrize("per_atom", [True, False])
+def test_vector_output(per_atom):
+    """Tests that the model can predict a (spherical) vector output."""
+
+    dataset_info = DatasetInfo(
+        length_unit="Angstrom",
+        atomic_types=[1, 6, 7, 8],
+        targets={
+            "forces": get_generic_target_info(
+                {
+                    "quantity": "forces",
+                    "unit": "",
+                    "type": {
+                        "spherical": {"irreps": [{"o3_lambda": 1, "o3_sigma": 1}]}
+                    },
+                    "num_subtargets": 100,
+                    "per_atom": per_atom,
+                }
+            )
+        },
+    )
+
+    model = SoapBpnn(MODEL_HYPERS, dataset_info)
+
+    system = System(
+        types=torch.tensor([6, 6]),
+        positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]),
+        cell=torch.zeros(3, 3),
+        pbc=torch.tensor([False, False, False]),
+    )
+    model(
+        [system],
+        {"forces": model.outputs["forces"]},
+    )
diff --git a/src/metatrain/experimental/soap_bpnn/tests/test_regression.py b/src/metatrain/experimental/soap_bpnn/tests/test_regression.py
index 61d08c11b..a881ec6cc 100644
--- a/src/metatrain/experimental/soap_bpnn/tests/test_regression.py
+++ b/src/metatrain/experimental/soap_bpnn/tests/test_regression.py
@@ -41,11 +41,11 @@ def test_regression_init():
 
     expected_output = torch.tensor(
         [
-            [-0.038599025458],
-            [0.111374437809],
-            [0.091115802526],
-            [-0.056339077652],
-            [-0.025491207838],
+            [-0.066850379109],
+            [-0.012763320468],
+            [-0.076015546918],
+            [0.041823804379],
+            [-0.022180110216],
         ]
     )
 
@@ -110,11 +110,11 @@ def test_regression_train():
 
     expected_output = torch.tensor(
         [
-            [-0.106249026954],
-            [0.039981484413],
-            [-0.142682999372],
-            [-0.031701669097],
-            [-0.016210660338],
+            [-0.080593876541],
+            [0.048118606210],
+            [0.037287645042],
+            [-0.000409360975],
+            [-0.039579294622],
         ]
     )
 
diff --git a/src/metatrain/utils/additive/composition.py b/src/metatrain/utils/additive/composition.py
index c1b715cb9..2661ec639 100644
--- a/src/metatrain/utils/additive/composition.py
+++ b/src/metatrain/utils/additive/composition.py
@@ -22,6 +22,9 @@ class CompositionModel(torch.nn.Module):
     quantity.
     """
 
+    outputs: Dict[str, ModelOutput]
+    output_to_output_index: Dict[str, int]
+
     def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo):
         super().__init__()
 
@@ -31,31 +34,28 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo):
             schema={"type": "object", "additionalProperties": False},
         )
 
-        # Check capabilities
-        for target in dataset_info.targets.values():
-            if target.quantity != "energy":
-                raise ValueError(
-                    "CompositionModel only supports energy-like outputs, but a "
-                    f"{target.quantity} output was provided."
- ) - self.dataset_info = dataset_info self.atomic_types = sorted(dataset_info.atomic_types) self.outputs = { key: ModelOutput( - quantity=value.quantity, - unit=value.unit, + quantity=target_info.quantity, + unit=target_info.unit, per_atom=True, ) - for key, value in dataset_info.targets.items() + for key, target_info in dataset_info.targets.items() + if target_info.is_scalar and len(target_info.layout.block().properties) == 1 + # important: only scalars can have composition contributions + # for now, we also require that only one property is present } n_types = len(self.atomic_types) - n_targets = len(dataset_info.targets) + n_targets = len(self.outputs) self.output_to_output_index = { - target: i for i, target in enumerate(sorted(dataset_info.targets.keys())) + target: i + for i, target in enumerate(sorted(dataset_info.targets.keys())) + if target in self.outputs } self.register_buffer( @@ -126,10 +126,13 @@ def train_model( if target_key in get_all_targets(dataset): datasets_with_target.append(dataset) if len(datasets_with_target) == 0: - raise ValueError( + # this is a possibility when transfer learning + warnings.warn( f"Target {target_key} in the model's new capabilities is not " - "present in any of the training datasets." + "present in any of the training datasets.", + stacklevel=2, ) + continue targets = torch.stack( [ @@ -242,6 +245,9 @@ def forward( for target_key, target in outputs.items(): if target_key.startswith("mtt::aux::"): continue + if target_key not in self.outputs.keys(): + # non-scalar + continue weights = self.weights[self.output_to_output_index[target_key]] concatenated_types = torch.concatenate([system.types for system in systems]) diff --git a/src/metatrain/utils/additive/remove.py b/src/metatrain/utils/additive/remove.py index 899a96a07..21678c714 100644 --- a/src/metatrain/utils/additive/remove.py +++ b/src/metatrain/utils/additive/remove.py @@ -35,11 +35,20 @@ def remove_additive( additive_contribution = evaluate_model( additive_model, systems, - {key: target_info_dict[key] for key in targets.keys()}, + { + key: target_info_dict[key] + for key in targets.keys() + if key in additive_model.outputs + }, is_training=False, # we don't need any gradients w.r.t. any parameters ) - for target_key in targets: + for target_key in additive_contribution.keys(): + # note that we loop over the keys of additive_contribution, not targets, + # because the targets might contain additional keys (this is for example + # the case of the composition model, which will only provide outputs + # for scalar targets + # make the samples the same so we can use metatensor.torch.subtract # we also need to detach the values to avoid backpropagating through the # subtraction diff --git a/src/metatrain/utils/additive/zbl.py b/src/metatrain/utils/additive/zbl.py index cec2ee96e..abbef6341 100644 --- a/src/metatrain/utils/additive/zbl.py +++ b/src/metatrain/utils/additive/zbl.py @@ -42,6 +42,14 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo): "ZBL only supports energy-like outputs, but a " f"{target.quantity} output was provided." ) + if not target.is_scalar: + raise ValueError("ZBL only supports scalar outputs") + if len(target.layout.block(0).properties) > 1: + raise ValueError( + "ZBL only supports outputs with one property, but " + f"{len(target.layout.block(0).properties)} " + "properties were provided." 
+ ) if target.unit != "eV": raise ValueError( "ZBL only supports eV units, but a " diff --git a/tests/cli/test_train_model.py b/tests/cli/test_train_model.py index e6e398158..a17831e36 100644 --- a/tests/cli/test_train_model.py +++ b/tests/cli/test_train_model.py @@ -610,3 +610,20 @@ def test_train_log_order(caplog, monkeypatch, tmp_path, options): force_index = line.index("validation forces RMSE") virial_index = line.index("validation virial RMSE") assert force_index < virial_index + + +def test_train_generic_target(monkeypatch, tmp_path): + """Test training on a spherical vector target""" + monkeypatch.chdir(tmp_path) + shutil.copy(DATASET_PATH_ETHANOL, "ethanol_reduced_100.xyz") + + # run training with original options + options = OmegaConf.load(OPTIONS_PATH) + options["training_set"]["systems"]["read_from"] = "ethanol_reduced_100.xyz" + options["training_set"]["targets"]["energy"]["type"] = { + "spherical": {"irreps": [{"o3_lambda": 1, "o3_sigma": 1}]} + } + options["training_set"]["targets"]["energy"]["per_atom"] = True + options["training_set"]["targets"]["energy"]["key"] = "forces" + + train_model(options) diff --git a/tests/utils/test_additive.py b/tests/utils/test_additive.py index cdffe31e1..bbe68232b 100644 --- a/tests/utils/test_additive.py +++ b/tests/utils/test_additive.py @@ -364,35 +364,33 @@ def test_composition_model_missing_types(): def test_composition_model_wrong_target(): """ - Test the error when a non-energy is fed to the composition model. + Test the error when a non-scalar is fed to the composition model. """ + composition_model = CompositionModel( + model_hypers={}, + dataset_info=DatasetInfo( + length_unit="angstrom", + atomic_types=[1], + targets={ + "force": get_generic_target_info( + { + "quantity": "force", + "unit": "", + "type": {"cartesian": {"rank": 1}}, + "num_subtargets": 1, + "per_atom": True, + } + ) + }, + ), + ) + # This should do nothing, because the target is not scalar and it should be + # ignored by the composition model. The warning is due to the "empty" dataset + # not containing H (atomic type 1) + with pytest.warns(UserWarning, match="do not contain atomic types"): + composition_model.train_model([]) - with pytest.raises( - ValueError, - match="only supports energy-like outputs", - ): - CompositionModel( - model_hypers={}, - dataset_info=DatasetInfo( - length_unit="angstrom", - atomic_types=[1], - targets={ - "energy": get_generic_target_info( - { - "quantity": "dipole", - "unit": "D", - "per_atom": True, - "num_subtargets": 5, - "type": { - "Cartesian": { - "rank": 1, - } - }, - } - ) - }, - ), - ) + assert composition_model.weights.shape == (0, 1) def test_zbl(): diff --git a/tests/utils/test_llpr.py b/tests/utils/test_llpr.py index 280537562..ee586f199 100644 --- a/tests/utils/test_llpr.py +++ b/tests/utils/test_llpr.py @@ -242,7 +242,7 @@ def test_llpr_covariance_as_pseudo_hessian(tmpdir): params.append(param.squeeze()) weights = torch.cat(params) - n_ensemble_members = 10000 + n_ensemble_members = 1000000 # converges slowly... llpr_model.calibrate(dataloader) llpr_model.generate_ensemble({"energy": weights}, n_ensemble_members) assert "mtt::energy_ensemble" in llpr_model.capabilities.outputs @@ -282,5 +282,5 @@ def test_llpr_covariance_as_pseudo_hessian(tmpdir): ) torch.testing.assert_close( - analytical_uncertainty, ensemble_uncertainty, rtol=1e-2, atol=1e-2 + analytical_uncertainty, ensemble_uncertainty, rtol=5e-3, atol=0.0 )
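For reference, a minimal sketch of how the generic-target support introduced in this patch can be exercised end to end with SOAP-BPNN, mirroring the test_vector_output test added above. It assumes that get_default_hypers("experimental.soap_bpnn") returns the same default hyperparameters that the tests expose as MODEL_HYPERS; the atomic types, target layout, and forward call are taken directly from the test code in this patch.

import torch
from metatensor.torch.atomistic import System

from metatrain.experimental.soap_bpnn import SoapBpnn
from metatrain.utils.architectures import get_default_hypers
from metatrain.utils.data import DatasetInfo
from metatrain.utils.data.target_info import get_generic_target_info

# Declare a per-atom spherical (o3_lambda=1, o3_sigma=1) target with 100
# sub-targets, using the same layout dictionary as the new tests.
target_info = get_generic_target_info(
    {
        "quantity": "forces",
        "unit": "",
        "type": {"spherical": {"irreps": [{"o3_lambda": 1, "o3_sigma": 1}]}},
        "num_subtargets": 100,
        "per_atom": True,
    }
)

dataset_info = DatasetInfo(
    length_unit="Angstrom",
    atomic_types=[1, 6, 7, 8],
    targets={"forces": target_info},
)

# Assumption: the default SOAP-BPNN hypers match the MODEL_HYPERS fixture
# used in the test suite.
model_hypers = get_default_hypers("experimental.soap_bpnn")["model"]
model = SoapBpnn(model_hypers, dataset_info)

# A two-atom carbon system, as in test_vector_output.
system = System(
    types=torch.tensor([6, 6]),
    positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]),
    cell=torch.zeros(3, 3),
    pbc=torch.tensor([False, False, False]),
)

# Request the "forces" output registered for the spherical target.
predictions = model([system], {"forces": model.outputs["forces"]})
print(predictions["forces"].block(0).values.shape)

On this branch, the returned per-atom TensorMap should contain a single block with 2 samples (atoms), 3 spherical components, and 100 properties.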