diff --git a/docs/src/architectures/nanopet.rst b/docs/src/architectures/nanopet.rst index 7b0515678..5a1fa3697 100644 --- a/docs/src/architectures/nanopet.rst +++ b/docs/src/architectures/nanopet.rst @@ -102,6 +102,8 @@ The hyperparameters for training are :param scheduler_factor: Factor to reduce the learning rate by :param log_interval: Interval at which to log training metrics :param checkpoint_interval: Interval at which to save model checkpoints +:param scale_targets: Whether to scale the targets to have unit standard deviation + across the training set during training. :param fixed_composition_weights: Weights for fixed atomic contributions to scalar targets :param per_structure_targets: Targets to calculate per-structure losses for diff --git a/docs/src/architectures/soap-bpnn.rst b/docs/src/architectures/soap-bpnn.rst index 7e3c08082..022f52414 100644 --- a/docs/src/architectures/soap-bpnn.rst +++ b/docs/src/architectures/soap-bpnn.rst @@ -159,6 +159,8 @@ The parameters for training are :param learning_rate: learning rate :param log_interval: number of epochs that elapse between reporting new training results :param checkpoint_interval: Interval to save a checkpoint to disk. +:param scale_targets: Whether to scale the targets to have unit standard deviation + across the training set during training. :param fixed_composition_weights: allows to set fixed isolated atom energies from outside. These are per target name and per (integer) atom type. For example, ``fixed_composition_weights: {"energy": {1: -396.0, 6: -500.0}, "mtt::U0": {1: 0.0, diff --git a/docs/src/dev-docs/utils/scaler.rst b/docs/src/dev-docs/utils/scaler.rst new file mode 100644 index 000000000..e893012c1 --- /dev/null +++ b/docs/src/dev-docs/utils/scaler.rst @@ -0,0 +1,7 @@ +Scaler +###### + +.. 
automodule:: metatrain.utils.scaler + :members: + :undoc-members: + :show-inheritance: diff --git a/src/metatrain/experimental/nanopet/default-hypers.yaml b/src/metatrain/experimental/nanopet/default-hypers.yaml index 66326116c..56388c236 100644 --- a/src/metatrain/experimental/nanopet/default-hypers.yaml +++ b/src/metatrain/experimental/nanopet/default-hypers.yaml @@ -22,6 +22,7 @@ architecture: scheduler_factor: 0.8 log_interval: 10 checkpoint_interval: 100 + scale_targets: true fixed_composition_weights: {} per_structure_targets: [] log_mae: False diff --git a/src/metatrain/experimental/nanopet/model.py b/src/metatrain/experimental/nanopet/model.py index 511dc4452..a28e7e663 100644 --- a/src/metatrain/experimental/nanopet/model.py +++ b/src/metatrain/experimental/nanopet/model.py @@ -1,4 +1,3 @@ -import copy from math import prod from pathlib import Path from typing import Dict, List, Optional, Union @@ -18,6 +17,7 @@ from ...utils.additive import ZBL, CompositionModel from ...utils.data import DatasetInfo, TargetInfo from ...utils.dtype import dtype_to_str +from ...utils.scaler import Scaler from .modules.encoder import Encoder from .modules.nef import ( edge_array_to_nef, @@ -126,6 +126,9 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo) -> None: self.head_types = self.hypers["heads"] self.last_layers = torch.nn.ModuleDict() self.output_shapes: Dict[str, List[int]] = {} + self.key_labels: Dict[str, Labels] = {} + self.component_labels: Dict[str, List[Labels]] = {} + self.property_labels: Dict[str, Labels] = {} for target_name, target_info in dataset_info.targets.items(): self._add_output(target_name, target_info) @@ -158,24 +161,10 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo) -> None: additive_models.append(ZBL(model_hypers, dataset_info)) self.additive_models = torch.nn.ModuleList(additive_models) - # cache keys, components, properties labels + # scaler: this is also handled by the trainer at training time + self.scaler = Scaler(model_hypers={}, dataset_info=dataset_info) + self.single_label = Labels.single() - self.key_labels = { - output_name: copy.deepcopy(dataset_info.targets[output_name].layout.keys) - for output_name in self.dataset_info.targets.keys() - } - self.component_labels = { - output_name: copy.deepcopy( - dataset_info.targets[output_name].layout.block().components - ) - for output_name in self.dataset_info.targets.keys() - } - self.property_labels = { - output_name: copy.deepcopy( - dataset_info.targets[output_name].layout.block().properties - ) - for output_name in self.dataset_info.targets.keys() - } def restart(self, dataset_info: DatasetInfo) -> "NanoPET": # merge old and new dataset info @@ -188,6 +177,7 @@ def restart(self, dataset_info: DatasetInfo) -> "NanoPET": for key, value in merged_info.targets.items() if key not in self.dataset_info.targets } + self.has_new_targets = len(new_targets) > 0 if len(new_atomic_types) > 0: raise ValueError( @@ -200,7 +190,10 @@ def restart(self, dataset_info: DatasetInfo) -> "NanoPET": self._add_output(target_name, target) self.dataset_info = merged_info - self.atomic_types = sorted(self.atomic_types) + + # restart the composition and scaler models + self.additive_models[0].restart(dataset_info) + self.scaler.restart(dataset_info) return self @@ -465,7 +458,8 @@ def forward( ) if not self.training: - # at evaluation, we also add the additive contributions + # at evaluation, we also introduce the scaler and additive contributions + return_dict = self.scaler(return_dict) for 
additive_model in self.additive_models: outputs_for_additive_model: Dict[str, ModelOutput] = {} for name, output in outputs.items(): @@ -566,3 +560,7 @@ def _add_output(self, target_name: str, target_info: TargetInfo) -> None: prod(self.output_shapes[target_name]), bias=False, ) + + self.key_labels[target_name] = target_info.layout.keys + self.component_labels[target_name] = target_info.layout.block().components + self.property_labels[target_name] = target_info.layout.block().properties diff --git a/src/metatrain/experimental/nanopet/schema-hypers.json b/src/metatrain/experimental/nanopet/schema-hypers.json index 543961362..bb2d11816 100644 --- a/src/metatrain/experimental/nanopet/schema-hypers.json +++ b/src/metatrain/experimental/nanopet/schema-hypers.json @@ -73,6 +73,9 @@ "checkpoint_interval": { "type": "integer" }, + "scale_targets": { + "type": "boolean" + }, "fixed_composition_weights": { "type": "object", "patternProperties": { diff --git a/src/metatrain/experimental/nanopet/tests/test_regression.py b/src/metatrain/experimental/nanopet/tests/test_regression.py index 853fe32d8..891be0d2e 100644 --- a/src/metatrain/experimental/nanopet/tests/test_regression.py +++ b/src/metatrain/experimental/nanopet/tests/test_regression.py @@ -115,11 +115,11 @@ def test_regression_train(): expected_output = torch.tensor( [ - [-0.162086367607], - [-0.022639824077], - [0.000784186646], - [0.019549313933], - [0.063824169338], + [-0.016902115196], + [0.100093543530], + [0.038387011737], + [0.097679324448], + [0.118228666484], ] ) diff --git a/src/metatrain/experimental/nanopet/trainer.py b/src/metatrain/experimental/nanopet/trainer.py index 6ab6e8d27..291370e14 100644 --- a/src/metatrain/experimental/nanopet/trainer.py +++ b/src/metatrain/experimental/nanopet/trainer.py @@ -24,6 +24,7 @@ get_system_with_neighbor_lists, ) from ...utils.per_atom import average_by_num_atoms +from ...utils.scaler import remove_scale from .model import NanoPET from .modules.augmentation import apply_random_augmentations @@ -107,6 +108,10 @@ def train( train_datasets, self.hypers["fixed_composition_weights"] ) + if self.hypers["scale_targets"]: + logger.info("Calculating scaling weights") + model.scaler.train_model(train_datasets, model.additive_models) + if is_distributed: model = DistributedDataParallel(model, device_ids=[device]) @@ -207,7 +212,10 @@ def train( model.parameters(), lr=self.hypers["learning_rate"] ) if self.optimizer_state_dict is not None: - optimizer.load_state_dict(self.optimizer_state_dict) + # try to load the optimizer state dict, but this is only possible + # if there are no new targets in the model (new parameters) + if not model.has_new_targets: + optimizer.load_state_dict(self.optimizer_state_dict) # Create a scheduler: lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( @@ -216,7 +224,9 @@ def train( patience=self.hypers["scheduler_patience"], ) if self.scheduler_state_dict is not None: - lr_scheduler.load_state_dict(self.scheduler_state_dict) + # same as the optimizer, try to load the scheduler state dict + if not model.has_new_targets: + lr_scheduler.load_state_dict(self.scheduler_state_dict) # per-atom targets: per_structure_targets = self.hypers["per_structure_targets"] @@ -274,6 +284,9 @@ def systems_and_targets_to_dtype( targets = remove_additive( systems, targets, additive_model, train_targets ) + targets = remove_scale( + targets, (model.module if is_distributed else model).scaler + ) systems, targets = systems_and_targets_to_dtype(systems, targets, dtype) predictions = 
evaluate_model( model, @@ -330,6 +343,9 @@ def systems_and_targets_to_dtype( targets = remove_additive( systems, targets, additive_model, train_targets ) + targets = remove_scale( + targets, (model.module if is_distributed else model).scaler + ) systems = [system.to(dtype=dtype) for system in systems] targets = {key: value.to(dtype=dtype) for key, value in targets.items()} predictions = evaluate_model( @@ -377,6 +393,9 @@ def systems_and_targets_to_dtype( } if epoch == start_epoch: + scaler_scales = ( + model.module if is_distributed else model + ).scaler.get_scales_dict() metric_logger = MetricLogger( log_obj=logger, dataset_info=( @@ -384,6 +403,14 @@ def systems_and_targets_to_dtype( ).dataset_info, initial_metrics=[finalized_train_info, finalized_val_info], names=["training", "validation"], + scales={ + key: ( + scaler_scales[key.split(" ")[0]] + if ("MAE" in key or "RMSE" in key) + else 1.0 + ) + for key in finalized_train_info.keys() + }, ) if epoch % self.hypers["log_interval"] == 0: metric_logger.log( diff --git a/src/metatrain/experimental/soap_bpnn/default-hypers.yaml b/src/metatrain/experimental/soap_bpnn/default-hypers.yaml index eea7ab870..85c6cf3f6 100644 --- a/src/metatrain/experimental/soap_bpnn/default-hypers.yaml +++ b/src/metatrain/experimental/soap_bpnn/default-hypers.yaml @@ -34,6 +34,7 @@ architecture: scheduler_factor: 0.8 log_interval: 5 checkpoint_interval: 25 + scale_targets: true fixed_composition_weights: {} per_structure_targets: [] log_mae: False diff --git a/src/metatrain/experimental/soap_bpnn/model.py b/src/metatrain/experimental/soap_bpnn/model.py index ca5b73d3f..9fe1e699e 100644 --- a/src/metatrain/experimental/soap_bpnn/model.py +++ b/src/metatrain/experimental/soap_bpnn/model.py @@ -21,6 +21,7 @@ from ...utils.additive import ZBL, CompositionModel from ...utils.dtype import dtype_to_str +from ...utils.scaler import Scaler class Identity(torch.nn.Module): @@ -297,6 +298,9 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo) -> None: additive_models.append(ZBL(model_hypers, dataset_info)) self.additive_models = torch.nn.ModuleList(additive_models) + # scaler: this is also handled by the trainer at training time + self.scaler = Scaler(model_hypers={}, dataset_info=dataset_info) + def restart(self, dataset_info: DatasetInfo) -> "SoapBpnn": # merge old and new dataset info merged_info = self.dataset_info.union(dataset_info) @@ -308,6 +312,7 @@ def restart(self, dataset_info: DatasetInfo) -> "SoapBpnn": for key, value in merged_info.targets.items() if key not in self.dataset_info.targets } + self.has_new_targets = len(new_targets) > 0 if len(new_atomic_types) > 0: raise ValueError( @@ -320,7 +325,10 @@ def restart(self, dataset_info: DatasetInfo) -> "SoapBpnn": self._add_output(target_name, target) self.dataset_info = merged_info - self.atomic_types = sorted(self.atomic_types) + + # restart the composition and scaler models + self.additive_models[0].restart(dataset_info) + self.scaler.restart(dataset_info) return self @@ -409,7 +417,8 @@ def forward( ) if not self.training: - # at evaluation, we also add the additive contributions + # at evaluation, we also introduce the scaler and additive contributions + return_dict = self.scaler(return_dict) for additive_model in self.additive_models: outputs_for_additive_model: Dict[str, ModelOutput] = {} for name, output in outputs.items(): diff --git a/src/metatrain/experimental/soap_bpnn/schema-hypers.json b/src/metatrain/experimental/soap_bpnn/schema-hypers.json index 3e74f69a1..87d02a16d 100644 --- 
a/src/metatrain/experimental/soap_bpnn/schema-hypers.json +++ b/src/metatrain/experimental/soap_bpnn/schema-hypers.json @@ -130,6 +130,9 @@ "checkpoint_interval": { "type": "integer" }, + "scale_targets": { + "type": "boolean" + }, "fixed_composition_weights": { "type": "object", "patternProperties": { diff --git a/src/metatrain/experimental/soap_bpnn/trainer.py b/src/metatrain/experimental/soap_bpnn/trainer.py index 270ae4d4b..fb9c299f6 100644 --- a/src/metatrain/experimental/soap_bpnn/trainer.py +++ b/src/metatrain/experimental/soap_bpnn/trainer.py @@ -23,6 +23,7 @@ get_system_with_neighbor_lists, ) from ...utils.per_atom import average_by_num_atoms +from ...utils.scaler import remove_scale from ...utils.transfer import ( systems_and_targets_to_device, systems_and_targets_to_dtype, @@ -120,6 +121,10 @@ def train( train_datasets, self.hypers["fixed_composition_weights"] ) + if self.hypers["scale_targets"]: + logger.info("Calculating scaling weights") + model.scaler.train_model(train_datasets, model.additive_models) + if is_distributed: model = DistributedDataParallel(model, device_ids=[device]) @@ -220,7 +225,10 @@ def train( model.parameters(), lr=self.hypers["learning_rate"] ) if self.optimizer_state_dict is not None: - optimizer.load_state_dict(self.optimizer_state_dict) + # try to load the optimizer state dict, but this is only possible + # if there are no new targets in the model (new parameters) + if not model.has_new_targets: + optimizer.load_state_dict(self.optimizer_state_dict) # Create a scheduler: lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( @@ -230,7 +238,9 @@ def train( threshold=0.001, ) if self.scheduler_state_dict is not None: - lr_scheduler.load_state_dict(self.scheduler_state_dict) + # same as the optimizer, try to load the scheduler state dict + if not model.has_new_targets: + lr_scheduler.load_state_dict(self.scheduler_state_dict) # per-atom targets: per_structure_targets = self.hypers["per_structure_targets"] @@ -269,6 +279,9 @@ def train( targets = remove_additive( systems, targets, additive_model, train_targets ) + targets = remove_scale( + targets, (model.module if is_distributed else model).scaler + ) systems, targets = systems_and_targets_to_dtype(systems, targets, dtype) predictions = evaluate_model( model, @@ -322,6 +335,9 @@ def train( targets = remove_additive( systems, targets, additive_model, train_targets ) + targets = remove_scale( + targets, (model.module if is_distributed else model).scaler + ) systems, targets = systems_and_targets_to_dtype(systems, targets, dtype) predictions = evaluate_model( model, @@ -365,6 +381,9 @@ def train( finalized_val_info = {"loss": val_loss, **finalized_val_info} if epoch == start_epoch: + scaler_scales = ( + model.module if is_distributed else model + ).scaler.get_scales_dict() metric_logger = MetricLogger( log_obj=logger, dataset_info=( @@ -372,6 +391,14 @@ def train( ).dataset_info, initial_metrics=[finalized_train_info, finalized_val_info], names=["training", "validation"], + scales={ + key: ( + scaler_scales[key.split(" ")[0]] + if ("MAE" in key or "RMSE" in key) + else 1.0 + ) + for key in finalized_train_info.keys() + }, ) if epoch % self.hypers["log_interval"] == 0: metric_logger.log( diff --git a/src/metatrain/utils/additive/composition.py b/src/metatrain/utils/additive/composition.py index 5679c317e..b394a3781 100644 --- a/src/metatrain/utils/additive/composition.py +++ b/src/metatrain/utils/additive/composition.py @@ -11,19 +11,18 @@ class CompositionModel(torch.nn.Module): - """A simple model that 
calculates the energy based on the stoichiometry in a system. + """A simple model that calculates the contributions to scalar targets + based on the stoichiometry in a system. :param model_hypers: A dictionary of model hyperparameters. The paramater is ignored and is only present to be consistent with the general model API. :param dataset_info: An object containing information about the dataset, including target quantities and atomic types. - - :raises ValueError: If any target quantity in the dataset info is not an energy-like - quantity. """ + weights: torch.Tensor outputs: Dict[str, ModelOutput] - output_to_output_index: Dict[str, int] + output_name_to_output_index: Dict[str, int] def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo): super().__init__() @@ -45,27 +44,18 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo): "Please report this issue and help us improve!" ) - self.outputs = { - key: ModelOutput( - quantity=target_info.quantity, - unit=target_info.unit, - per_atom=True, - ) - for key, target_info in dataset_info.targets.items() - } - - n_types = len(self.atomic_types) - n_targets = len(self.outputs) - - self.output_to_output_index = { - target: i - for i, target in enumerate(sorted(dataset_info.targets.keys())) - if target in self.outputs + self.new_targets = { + target_name: target_info + for target_name, target_info in dataset_info.targets.items() } self.register_buffer( - "weights", torch.zeros((n_targets, n_types), dtype=torch.float64) + "weights", torch.zeros((0, len(self.atomic_types)), dtype=torch.float64) ) + self.output_name_to_output_index: Dict[str, int] = {} + self.outputs: Dict[str, ModelOutput] = {} + for target_name, target_info in self.dataset_info.targets.items(): + self._add_output(target_name, target_info) # cache some labels self.keys_label = Labels.single() @@ -82,6 +72,7 @@ def train_model( :param fixed_weights: Optional fixed weights to use for the composition model, for one or more target quantities. + :raises ValueError: If the provided datasets contain unknown targets. :raises ValueError: If the provided datasets contain unknown atomic types. :raises RuntimeError: If the linear system to calculate the composition weights cannot be solved. @@ -110,8 +101,9 @@ def train_model( stacklevel=2, ) - # Fill the weights for each target in the dataset info - for target_key in self.output_to_output_index.keys(): + # Fill the weights for each "new" target (i.e. those that do not already + # have composition weights from a previous training run) + for target_key in self.new_targets: if target_key in fixed_weights: # The fixed weights are provided for this target. Use them: @@ -121,9 +113,11 @@ def train_model( f"atomic types {self.atomic_types}." ) - self.weights[self.output_to_output_index[target_key]] = torch.tensor( - [fixed_weights[target_key][i] for i in self.atomic_types], - dtype=self.weights.dtype, + self.weights[self.output_name_to_output_index[target_key]] = ( + torch.tensor( + [fixed_weights[target_key][i] for i in self.atomic_types], + dtype=self.weights.dtype, + ) ) else: datasets_with_target = [] @@ -182,7 +176,7 @@ def train_model( "ill-conditioned." 
) try: - self.weights[self.output_to_output_index[target_key]] = ( + self.weights[self.output_name_to_output_index[target_key]] = ( torch.linalg.solve( composition_features.T @ composition_features + regularizer @@ -199,10 +193,6 @@ def train_model( regularizer *= 10.0 def restart(self, dataset_info: DatasetInfo) -> "CompositionModel": - """Restart the model with a new dataset info. - - :param dataset_info: New dataset information to be used. - """ for target_info in dataset_info.targets.values(): if not self.is_valid_target(target_info): raise ValueError( @@ -211,7 +201,31 @@ def restart(self, dataset_info: DatasetInfo) -> "CompositionModel": "Please report this issue and help us improve!" ) - return self({}, self.dataset_info.union(dataset_info)) + # merge old and new dataset info + merged_info = self.dataset_info.union(dataset_info) + new_atomic_types = [ + at for at in merged_info.atomic_types if at not in self.atomic_types + ] + + if len(new_atomic_types) > 0: + raise ValueError( + f"New atomic types found in the dataset: {new_atomic_types}. " + "The composition model does not support adding new atomic types." + ) + + self.new_targets = { + target_name: target_info + for target_name, target_info in merged_info.targets.items() + if target_name not in self.dataset_info.targets + } + + # register new outputs + for target_name, target in self.new_targets.items(): + self._add_output(target_name, target) + + self.dataset_info = merged_info + + return self def forward( self, @@ -240,7 +254,7 @@ def forward( self.properties_label = self.properties_label.to(device) for output_name in outputs: - if output_name not in self.output_to_output_index: + if output_name not in self.output_name_to_output_index: raise ValueError( f"output key {output_name} is not supported by this composition " "model." @@ -254,7 +268,7 @@ def forward( # number of atoms per atomic type. targets_out: Dict[str, TensorMap] = {} for target_key, target in outputs.items(): - weights = self.weights[self.output_to_output_index[target_key]] + weights = self.weights[self.output_name_to_output_index[target_key]] concatenated_types = torch.concatenate([system.types for system in systems]) targets = torch.empty(len(concatenated_types), dtype=dtype, device=device) @@ -299,6 +313,22 @@ def forward( return targets_out + def _add_output(self, target_name: str, target_info: TargetInfo) -> None: + n_types = len(self.atomic_types) + + # important: only scalars can have composition contributions + # for now, we also require that only one property is present + if target_info.is_scalar and len(target_info.layout.block().properties) == 1: + self.outputs[target_name] = ModelOutput( + quantity=target_info.quantity, + unit=target_info.unit, + per_atom=True, + ) + self.weights = torch.concatenate( + [self.weights, torch.zeros((1, n_types), dtype=self.weights.dtype)] + ) + self.output_name_to_output_index[target_name] = len(self.weights) - 1 + @staticmethod def is_valid_target(target_info: TargetInfo) -> bool: """Finds if a ``TargetInfo`` object is compatible with a composition model. 
diff --git a/src/metatrain/utils/logging.py b/src/metatrain/utils/logging.py index 99a88e028..4e78ba44f 100644 --- a/src/metatrain/utils/logging.py +++ b/src/metatrain/utils/logging.py @@ -24,6 +24,7 @@ def __init__( dataset_info: Union[ModelCapabilities, DatasetInfo], initial_metrics: Union[Dict[str, float], List[Dict[str, float]]], names: Union[str, List[str]] = "", + scales: Optional[Dict[str, float]] = None, ): """ Simple interface to log training metrics logging instance. @@ -65,12 +66,17 @@ def __init__( self.names = names + if scales is None: + scales = {target_name: 1.0 for target_name in initial_metrics[0].keys()} + self.scales = scales + # Since the quantities are supposed to decrease, we want to store the # number of digits at the start of the training, so that we can align # the output later: self.digits = {} for name, metrics_dict in zip(names, initial_metrics): for key, value in metrics_dict.items(): + value *= scales[key] target_name = key.split(" ", 1)[0] if key == "loss": # losses will be printed in scientific notation @@ -111,7 +117,7 @@ def log( for name, metrics_dict in zip(self.names, metrics): for key in _sort_metric_names(metrics_dict.keys()): - value = metrics_dict[key] + value = metrics_dict[key] * self.scales[key] new_key = key if key != "loss": # special case: not a metric associated with a target diff --git a/src/metatrain/utils/scaler.py b/src/metatrain/utils/scaler.py new file mode 100644 index 000000000..a79ee4ab5 --- /dev/null +++ b/src/metatrain/utils/scaler.py @@ -0,0 +1,233 @@ +from typing import Dict, List, Union + +import metatensor.torch +import numpy as np +import torch +from metatensor.torch import TensorMap +from metatensor.torch.atomistic import ModelOutput + +from .additive import remove_additive +from .data import Dataset, DatasetInfo, TargetInfo, get_all_targets +from .jsonschema import validate + + +class Scaler(torch.nn.Module): + """ + A class that scales the targets of regression problems to unit standard + deviation. + + In most cases, this should be used in conjunction with a composition model + (that removes the multi-dimensional "mean" across the composition space) and/or + other additive models. See the `train_model` method for more details. + + :param model_hypers: A dictionary of model hyperparameters. The parameter is ignored + and is only present to be consistent with the general model API. + :param dataset_info: An object containing information about the dataset, including + target quantities and atomic types. + """ + + outputs: Dict[str, ModelOutput] + scales: torch.Tensor + + def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo): + super().__init__() + + # `model_hypers` should be an empty dictionary + validate( + instance=model_hypers, + schema={"type": "object", "additionalProperties": False}, + ) + + self.dataset_info = dataset_info + + self.new_targets: Dict[str, TargetInfo] = dataset_info.targets + self.outputs: Dict[str, ModelOutput] = {} + # Initially, the scales are empty. They will be expanded as new outputs + # are registered with `_add_output`. + self.register_buffer("scales", torch.ones((0,), dtype=torch.float64)) + self.output_name_to_output_index: Dict[str, int] = {} + for target_name, target_info in self.dataset_info.targets.items(): + self._add_output(target_name, target_info) + + def train_model( + self, + datasets: List[Union[Dataset, torch.utils.data.Subset]], + additive_models: List[torch.nn.Module], + ) -> None: + """ + Calculate the scaling weights for all the targets in the datasets.
+ + :param datasets: Dataset(s) to calculate the scaling weights for. + :param additive_models: Additive models to be removed from the targets + before calculating the statistics. + + :raises ValueError: If the provided datasets contain targets unknown + to the scaler. + """ + if not isinstance(datasets, list): + datasets = [datasets] + + # Fill the scales for each "new" target (i.e. those that do not already + # have scales from a previous training run) + for target_key in self.new_targets: + + datasets_with_target = [] + for dataset in datasets: + if target_key in get_all_targets(dataset): + datasets_with_target.append(dataset) + if len(datasets_with_target) == 0: + raise ValueError( + f"Target {target_key} in the model's new capabilities is not " + "present in any of the training datasets." + ) + + sum_of_squared_targets = 0.0 + total_num_elements = 0 + + for dataset in datasets_with_target: + for sample in dataset: + systems = [sample["system"]] + targets = {target_key: sample[target_key]} + + for additive_model in additive_models: + target_info_dict = {target_key: self.new_targets[target_key]} + targets = remove_additive( + systems, + targets, + additive_model, + target_info_dict, + ) + + target_info = self.new_targets[target_key] + if ( + target_info.quantity == "energy" + and "positions" in target_info.gradients + ): + # special case: here we want to scale with respect to the forces + # rather than the energies + sum_of_squared_targets += torch.sum( + targets[target_key].block().gradient("positions").values + ** 2 + ).item() + total_num_elements += ( + targets[target_key] + .block() + .gradient("positions") + .values.numel() + ) + else: + sum_of_squared_targets += sum( + torch.sum(block.values**2).item() + for block in targets[target_key].blocks() + ) + total_num_elements += sum( + block.values.numel() + for block in targets[target_key].blocks() + ) + + self.scales[self.output_name_to_output_index[target_key]] = np.sqrt( + sum_of_squared_targets / total_num_elements + ) + + def restart(self, dataset_info: DatasetInfo) -> "Scaler": + # merge old and new dataset info + merged_info = self.dataset_info.union(dataset_info) + + self.new_targets = { + key: value + for key, value in merged_info.targets.items() + if key not in self.dataset_info.targets + } + + # register new outputs + for target_name, target in self.new_targets.items(): + self._add_output(target_name, target) + + self.dataset_info = merged_info + + return self + + def forward( + self, + outputs: Dict[str, TensorMap], + ) -> Dict[str, TensorMap]: + """ + Scales all the targets in the outputs dictionary back to their + original scale. + + :param outputs: A dictionary of target quantities and their values + to be scaled. + + :raises ValueError: If an output does not have a corresponding + scale in the scaler model. 
+ """ + scaled_outputs: Dict[str, TensorMap] = {} + for target_key, target in outputs.items(): + if target_key in self.outputs: + scale = float( + self.scales[self.output_name_to_output_index[target_key]].item() + ) + scaled_target = metatensor.torch.multiply(target, scale) + scaled_outputs[target_key] = scaled_target + else: + scaled_outputs[target_key] = target + + return scaled_outputs + + def _add_output(self, target_name: str, target_info: TargetInfo) -> None: + + self.outputs[target_name] = ModelOutput( + quantity=target_info.quantity, + unit=target_info.unit, + per_atom=True, + ) + + self.scales = torch.cat( + [self.scales, torch.tensor([1.0], dtype=self.scales.dtype)] + ) + self.output_name_to_output_index[target_name] = len(self.scales) - 1 + + def get_scales_dict(self) -> Dict[str, float]: + """ + Return a dictionary with the scales for each output and output gradient. + + :return: A dictionary with the scales for each output and output gradient. + These correspond to the standard deviation of the targets in the + original dataset. The scales for each output gradient are the same + as the corresponding output. + """ + + scales_dict = { + output_name: self.scales[output_index].item() + for output_name, output_index in self.output_name_to_output_index.items() + } + # Add gradients if present. They have the same scale as the corresponding output + for output_name in list(scales_dict.keys()): + gradient_names_for_output = self.dataset_info.targets[output_name].gradients + for gradient_name in gradient_names_for_output: + scales_dict[output_name + "_" + gradient_name + "_gradients"] = ( + scales_dict[output_name] + ) + return scales_dict + + +def remove_scale( + targets: Dict[str, TensorMap], + scaler: Scaler, +): + """ + Scale all targets to a standard deviation of one. + + :param targets: Dictionary containing the targets to be scaled. + :param scaler: The scaler used to scale the targets. 
+ """ + scaled_targets = {} + for target_key in targets.keys(): + scale = float( + scaler.scales[scaler.output_name_to_output_index[target_key]].item() + ) + scaled_targets[target_key] = metatensor.torch.multiply( + targets[target_key], 1.0 / scale + ) + + return scaled_targets diff --git a/tests/utils/test_additive.py b/tests/utils/test_additive.py index 5ba102cfa..433989745 100644 --- a/tests/utils/test_additive.py +++ b/tests/utils/test_additive.py @@ -93,7 +93,7 @@ def test_composition_model_train(): composition_model.train_model(dataset) assert composition_model.weights.shape[0] == 1 assert composition_model.weights.shape[1] == 2 - assert composition_model.output_to_output_index == {"energy": 0} + assert composition_model.output_name_to_output_index == {"energy": 0} assert composition_model.atomic_types == [1, 8] torch.testing.assert_close( composition_model.weights, torch.tensor([[2.0, 1.0]], dtype=torch.float64) @@ -102,7 +102,7 @@ def test_composition_model_train(): composition_model.train_model([dataset]) assert composition_model.weights.shape[0] == 1 assert composition_model.weights.shape[1] == 2 - assert composition_model.output_to_output_index == {"energy": 0} + assert composition_model.output_name_to_output_index == {"energy": 0} assert composition_model.atomic_types == [1, 8] torch.testing.assert_close( composition_model.weights, torch.tensor([[2.0, 1.0]], dtype=torch.float64) @@ -111,7 +111,7 @@ def test_composition_model_train(): composition_model.train_model([dataset, dataset, dataset]) assert composition_model.weights.shape[0] == 1 assert composition_model.weights.shape[1] == 2 - assert composition_model.output_to_output_index == {"energy": 0} + assert composition_model.output_name_to_output_index == {"energy": 0} assert composition_model.atomic_types == [1, 8] torch.testing.assert_close( composition_model.weights, torch.tensor([[2.0, 1.0]], dtype=torch.float64) diff --git a/tests/utils/test_scaler.py b/tests/utils/test_scaler.py new file mode 100644 index 000000000..5c3527096 --- /dev/null +++ b/tests/utils/test_scaler.py @@ -0,0 +1,208 @@ +from pathlib import Path + +import torch +from metatensor.torch import Labels, TensorBlock, TensorMap +from metatensor.torch.atomistic import System +from omegaconf import OmegaConf + +from metatrain.utils.data import Dataset, DatasetInfo +from metatrain.utils.data.readers import read_systems, read_targets +from metatrain.utils.data.target_info import get_energy_target_info +from metatrain.utils.scaler import Scaler, remove_scale + + +RESOURCES_PATH = Path(__file__).parents[1] / "resources" + + +def test_scaler_train(): + """Test the calculation of scaling weights.""" + + # Here we use three synthetic structures: + # - O atom, with an energy of 3.0 + # - H2O molecule, with an energy of 4.0 + # - H4O2 molecule, with an energy of 12.0 + # The expected standard deviation is 13/sqrt(3). 
+ + systems = [ + System( + positions=torch.tensor([[0.0, 0.0, 0.0]], dtype=torch.float64), + types=torch.tensor([8]), + cell=torch.eye(3, dtype=torch.float64), + pbc=torch.tensor([True, True, True]), + ), + System( + positions=torch.tensor( + [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], dtype=torch.float64 + ), + types=torch.tensor([1, 1, 8]), + cell=torch.eye(3, dtype=torch.float64), + pbc=torch.tensor([True, True, True]), + ), + System( + positions=torch.tensor( + [ + [0.0, 0.0, 0.0], + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + [1.0, 0.0, 1.0], + [0.0, 1.0, 1.0], + ], + dtype=torch.float64, + ), + types=torch.tensor([1, 1, 8, 1, 1, 8]), + cell=torch.eye(3, dtype=torch.float64), + pbc=torch.tensor([True, True, True]), + ), + ] + energies = [3.0, 4.0, 12.0] + energies = [ + TensorMap( + keys=Labels(names=["_"], values=torch.tensor([[0]])), + blocks=[ + TensorBlock( + values=torch.tensor([[e]], dtype=torch.float64), + samples=Labels(names=["system"], values=torch.tensor([[i]])), + components=[], + properties=Labels(names=["energy"], values=torch.tensor([[0]])), + ) + ], + ) + for i, e in enumerate(energies) + ] + dataset = Dataset.from_dict({"system": systems, "energy": energies}) + + scaler = Scaler( + model_hypers={}, + dataset_info=DatasetInfo( + length_unit="angstrom", + atomic_types=[1, 8], + targets={"energy": get_energy_target_info({"unit": "eV"})}, + ), + ) + + scaler.train_model(dataset, additive_models=[]) + assert scaler.scales.shape == (1,) + assert scaler.output_name_to_output_index == {"energy": 0} + torch.testing.assert_close( + scaler.scales, torch.tensor([13.0 / 3**0.5], dtype=torch.float64) + ) + + scaler.train_model([dataset], additive_models=[]) + assert scaler.scales.shape == (1,) + assert scaler.output_name_to_output_index == {"energy": 0} + torch.testing.assert_close( + scaler.scales, torch.tensor([13.0 / 3**0.5], dtype=torch.float64) + ) + + scaler.train_model([dataset, dataset, dataset], additive_models=[]) + assert scaler.scales.shape == (1,) + assert scaler.output_name_to_output_index == {"energy": 0} + torch.testing.assert_close( + scaler.scales, torch.tensor([13.0 / 3**0.5], dtype=torch.float64) + ) + + +def test_scale(): + """Test the scaling of the scale, both at training and prediction + time.""" + + dataset_path = RESOURCES_PATH / "qm9_reduced_100.xyz" + systems = read_systems(dataset_path) + + conf = { + "mtt::U0": { + "quantity": "energy", + "read_from": dataset_path, + "file_format": ".xyz", + "reader": "ase", + "key": "U0", + "unit": "eV", + "type": "scalar", + "per_atom": False, + "num_subtargets": 1, + "forces": False, + "stress": False, + "virial": False, + } + } + targets, target_info = read_targets(OmegaConf.create(conf)) + dataset = Dataset.from_dict({"system": systems, "mtt::U0": targets["mtt::U0"]}) + + scaler = Scaler( + model_hypers={}, + dataset_info=DatasetInfo( + length_unit="angstrom", + atomic_types=[1, 6, 7, 8], + targets=target_info, + ), + ) + + scaler.train_model(dataset, additive_models=[]) + scale = scaler.scales[0].item() + + fake_output_or_target = TensorMap( + keys=Labels.single(), + blocks=[ + TensorBlock( + values=torch.tensor([[1.0], [2.0], [3.0]], dtype=torch.float64), + samples=Labels( + names=["system"], + values=torch.tensor([[0], [1], [2]]), + ), + components=[], + properties=Labels.single(), + ) + ], + ) + fake_output_or_target = {"mtt::U0": fake_output_or_target} + + scaled_output = scaler(fake_output_or_target) + assert "mtt::U0" in scaled_output + torch.testing.assert_close( + 
scaled_output["mtt::U0"].block().values, + torch.tensor([[1.0], [2.0], [3.0]], dtype=torch.float64) * scale, + ) + + # Test the remove_scale function + scaled_output = remove_scale(fake_output_or_target, scaler) + assert "mtt::U0" in fake_output_or_target + torch.testing.assert_close( + scaled_output["mtt::U0"].block().values, + torch.tensor([[1.0], [2.0], [3.0]], dtype=torch.float64) / scale, + ) + + +def test_scaler_torchscript(tmpdir): + """Test the torchscripting, saving and loading of a scaler model.""" + + scaler = Scaler( + model_hypers={}, + dataset_info=DatasetInfo( + length_unit="angstrom", + atomic_types=[1, 8], + targets={"energy": get_energy_target_info({"unit": "eV"})}, + ), + ) + + fake_output = TensorMap( + keys=Labels.single(), + blocks=[ + TensorBlock( + values=torch.tensor([[1.0], [2.0], [3.0]], dtype=torch.float64), + samples=Labels( + names=["system"], + values=torch.tensor([[0], [1], [2]]), + ), + components=[], + properties=Labels.single(), + ) + ], + ) + fake_output = {"energy": fake_output} + + scaler = torch.jit.script(scaler) + scaler(fake_output) + torch.jit.save(scaler, tmpdir / "scaler.pt") + scaler = torch.jit.load(tmpdir / "scaler.pt") + scaler(fake_output)