Skip to content

Commit

Permalink
Changes from review
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed Dec 15, 2024
1 parent 53566b8 commit c78a638
Showing 1 changed file with 9 additions and 13 deletions.
22 changes: 9 additions & 13 deletions src/metatrain/utils/scaler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ class Scaler(torch.nn.Module):
A class that scales the targets of regression problems to unit standard
deviation.
In most cases, this should be used in conjunction with a composition model
(that removes the multi-dimensional "mean" across the composition space) and/or
other additive models. See the `train_model` method for more details.
:param model_hypers: A dictionary of model hyperparameters. The paramater is ignored
and is only present to be consistent with the general model API.
:param dataset_info: An object containing information about the dataset, including
Expand All @@ -38,6 +42,8 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo):

self.new_targets: Dict[str, TargetInfo] = dataset_info.targets
self.outputs: Dict[str, ModelOutput] = {}
# Initially, the scales are empty. They will be expanded as new outputs
# are registered with `_add_output`.
self.register_buffer("scales", torch.ones((0,), dtype=torch.float64))
self.output_name_to_output_index: Dict[str, int] = {}
for target_name, target_info in self.dataset_info.targets.items():
Expand Down Expand Up @@ -155,26 +161,16 @@ def forward(
:raises ValueError: If an output does not have a corresponding
scale in the scaler model.
"""

for output_name in outputs:
if output_name.startswith("mtt::aux::") or output_name == "features":
continue
if output_name not in self.outputs.keys():
raise ValueError(
f"output key {output_name} is not supported by this scaler "
"model."
)

scaled_outputs: Dict[str, TensorMap] = {}
for target_key, target in outputs.items():
if target_key.startswith("mtt::aux::") or target_key == "features":
scaled_outputs[target_key] = target
else:
if target_key in self.outputs:
scale = float(
self.scales[self.output_name_to_output_index[target_key]].item()
)
scaled_target = metatensor.torch.multiply(target, scale)
scaled_outputs[target_key] = scaled_target
else:
scaled_outputs[target_key] = target

Check warning on line 173 in src/metatrain/utils/scaler.py

View check run for this annotation

Codecov / codecov/patch

src/metatrain/utils/scaler.py#L173

Added line #L173 was not covered by tests

return scaled_outputs

Expand Down

0 comments on commit c78a638

Please sign in to comment.