From c78a6380718ef6f38e8e2dd08f9b26f71857cf59 Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Sun, 15 Dec 2024 18:46:11 +0100 Subject: [PATCH] Changes from review --- src/metatrain/utils/scaler.py | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/src/metatrain/utils/scaler.py b/src/metatrain/utils/scaler.py index 51aa5c61c..a79ee4ab5 100644 --- a/src/metatrain/utils/scaler.py +++ b/src/metatrain/utils/scaler.py @@ -16,6 +16,10 @@ class Scaler(torch.nn.Module): A class that scales the targets of regression problems to unit standard deviation. + In most cases, this should be used in conjunction with a composition model + (that removes the multi-dimensional "mean" across the composition space) and/or + other additive models. See the `train_model` method for more details. + :param model_hypers: A dictionary of model hyperparameters. The paramater is ignored and is only present to be consistent with the general model API. :param dataset_info: An object containing information about the dataset, including @@ -38,6 +42,8 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo): self.new_targets: Dict[str, TargetInfo] = dataset_info.targets self.outputs: Dict[str, ModelOutput] = {} + # Initially, the scales are empty. They will be expanded as new outputs + # are registered with `_add_output`. self.register_buffer("scales", torch.ones((0,), dtype=torch.float64)) self.output_name_to_output_index: Dict[str, int] = {} for target_name, target_info in self.dataset_info.targets.items(): @@ -155,26 +161,16 @@ def forward( :raises ValueError: If an output does not have a corresponding scale in the scaler model. """ - - for output_name in outputs: - if output_name.startswith("mtt::aux::") or output_name == "features": - continue - if output_name not in self.outputs.keys(): - raise ValueError( - f"output key {output_name} is not supported by this scaler " - "model." - ) - scaled_outputs: Dict[str, TensorMap] = {} for target_key, target in outputs.items(): - if target_key.startswith("mtt::aux::") or target_key == "features": - scaled_outputs[target_key] = target - else: + if target_key in self.outputs: scale = float( self.scales[self.output_name_to_output_index[target_key]].item() ) scaled_target = metatensor.torch.multiply(target, scale) scaled_outputs[target_key] = scaled_target + else: + scaled_outputs[target_key] = target return scaled_outputs