
Commit

Merge branch 'main' into new-loss-default
frostedoyster committed Dec 16, 2024
2 parents 82904a7 + c3d51ae commit 112307a
Showing 17 changed files with 627 additions and 70 deletions.
2 changes: 2 additions & 0 deletions docs/src/architectures/nanopet.rst
@@ -102,6 +102,8 @@ The hyperparameters for training are
:param scheduler_factor: Factor to reduce the learning rate by
:param log_interval: Interval at which to log training metrics
:param checkpoint_interval: Interval at which to save model checkpoints
:param scale_targets: Whether to scale the targets to have unit standard deviation
across the training set during training.
:param fixed_composition_weights: Weights for fixed atomic contributions to scalar
targets
:param per_structure_targets: Targets to calculate per-structure losses for
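The `scale_targets` option documented above normalizes each target to unit standard deviation over the training set during training, and predictions are brought back to the original units afterwards. A minimal sketch of the idea in plain PyTorch, independent of the metatrain internals (the variable names below are illustrative only):

```python
import torch

# Toy training targets; in metatrain these would come from the training set.
energies = torch.tensor([-396.2, -395.8, -397.1, -396.5])

# One scale per target: its standard deviation across the training set.
scale = energies.std()

# During training the model is fitted against targets divided by the scale
# (unit standard deviation) ...
scaled = energies / scale
print(scaled.std())  # ~1.0

# ... and at evaluation the prediction is multiplied back to original units.
prediction_in_scaled_space = torch.tensor([0.3])
prediction = prediction_in_scaled_space * scale
```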
2 changes: 2 additions & 0 deletions docs/src/architectures/soap-bpnn.rst
@@ -159,6 +159,8 @@ The parameters for training are
:param learning_rate: learning rate
:param log_interval: number of epochs that elapse between reporting new training results
:param checkpoint_interval: Interval to save a checkpoint to disk.
:param scale_targets: Whether to scale the targets to have unit standard deviation
across the training set during training.
:param fixed_composition_weights: allows to set fixed isolated atom energies from
outside. These are per target name and per (integer) atom type. For example,
``fixed_composition_weights: {"energy": {1: -396.0, 6: -500.0}, "mtt::U0": {1: 0.0,
7 changes: 7 additions & 0 deletions docs/src/dev-docs/utils/scaler.rst
@@ -0,0 +1,7 @@
Scaler
######

.. automodule:: metatrain.utils.scaler
:members:
:undoc-members:
:show-inheritance:
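The `metatrain.utils.scaler` module documented here is wired into both architectures later in this commit. The sketch below is pieced together only from the call sites visible in the diff (`Scaler(...)`, `train_model`, `remove_scale`, calling the scaler on an output dictionary, and `get_scales_dict`); argument names and internals are assumptions and may differ from the actual API:

```python
from typing import Dict, List

from metatrain.utils.scaler import Scaler, remove_scale


def scaler_lifecycle_sketch(
    dataset_info,
    train_datasets: List,
    additive_models,
    targets: Dict,
    raw_predictions: Dict,
):
    """Illustrative only: mirrors the call sites added in this commit."""
    # The model owns a Scaler built from the dataset metadata.
    scaler = Scaler(model_hypers={}, dataset_info=dataset_info)

    # The trainer fits one scale per target, after the additive
    # (composition/ZBL) contributions have been accounted for.
    scaler.train_model(train_datasets, additive_models)

    # Training loop: targets are divided by their scales before the loss.
    scaled_targets = remove_scale(targets, scaler)

    # Evaluation: the model multiplies its raw predictions back to original
    # units by calling the scaler on the output dictionary.
    predictions = scaler(raw_predictions)

    # Per-target scales, used by the trainers to log MAE/RMSE in original units.
    scales = scaler.get_scales_dict()
    return scaled_targets, predictions, scales
```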
1 change: 1 addition & 0 deletions src/metatrain/experimental/nanopet/default-hypers.yaml
@@ -22,6 +22,7 @@ architecture:
scheduler_factor: 0.8
log_interval: 10
checkpoint_interval: 100
scale_targets: true
fixed_composition_weights: {}
per_structure_targets: []
log_mae: False
38 changes: 18 additions & 20 deletions src/metatrain/experimental/nanopet/model.py
@@ -1,4 +1,3 @@
import copy
from math import prod
from pathlib import Path
from typing import Dict, List, Optional, Union
@@ -18,6 +17,7 @@
from ...utils.additive import ZBL, CompositionModel
from ...utils.data import DatasetInfo, TargetInfo
from ...utils.dtype import dtype_to_str
from ...utils.scaler import Scaler
from .modules.encoder import Encoder
from .modules.nef import (
edge_array_to_nef,
@@ -126,6 +126,9 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo) -> None:
self.head_types = self.hypers["heads"]
self.last_layers = torch.nn.ModuleDict()
self.output_shapes: Dict[str, List[int]] = {}
self.key_labels: Dict[str, Labels] = {}
self.component_labels: Dict[str, List[Labels]] = {}
self.property_labels: Dict[str, Labels] = {}
for target_name, target_info in dataset_info.targets.items():
self._add_output(target_name, target_info)

@@ -158,24 +161,10 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo) -> None:
additive_models.append(ZBL(model_hypers, dataset_info))
self.additive_models = torch.nn.ModuleList(additive_models)

# cache keys, components, properties labels
# scaler: this is also handled by the trainer at training time
self.scaler = Scaler(model_hypers={}, dataset_info=dataset_info)

self.single_label = Labels.single()
self.key_labels = {
output_name: copy.deepcopy(dataset_info.targets[output_name].layout.keys)
for output_name in self.dataset_info.targets.keys()
}
self.component_labels = {
output_name: copy.deepcopy(
dataset_info.targets[output_name].layout.block().components
)
for output_name in self.dataset_info.targets.keys()
}
self.property_labels = {
output_name: copy.deepcopy(
dataset_info.targets[output_name].layout.block().properties
)
for output_name in self.dataset_info.targets.keys()
}

def restart(self, dataset_info: DatasetInfo) -> "NanoPET":
# merge old and new dataset info
@@ -188,6 +177,7 @@ def restart(dataset_info: DatasetInfo) -> "NanoPET":
for key, value in merged_info.targets.items()
if key not in self.dataset_info.targets
}
self.has_new_targets = len(new_targets) > 0

if len(new_atomic_types) > 0:
raise ValueError(
@@ -200,7 +190,10 @@ def restart(dataset_info: DatasetInfo) -> "NanoPET":
self._add_output(target_name, target)

self.dataset_info = merged_info
self.atomic_types = sorted(self.atomic_types)

# restart the composition and scaler models
self.additive_models[0].restart(dataset_info)
self.scaler.restart(dataset_info)

return self

@@ -465,7 +458,8 @@ def forward(
)

if not self.training:
# at evaluation, we also add the additive contributions
# at evaluation, we also introduce the scaler and additive contributions
return_dict = self.scaler(return_dict)
for additive_model in self.additive_models:
outputs_for_additive_model: Dict[str, ModelOutput] = {}
for name, output in outputs.items():
@@ -566,3 +560,7 @@ def _add_output(self, target_name: str, target_info: TargetInfo) -> None:
prod(self.output_shapes[target_name]),
bias=False,
)

self.key_labels[target_name] = target_info.layout.keys
self.component_labels[target_name] = target_info.layout.block().components
self.property_labels[target_name] = target_info.layout.block().properties
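With this change the raw network output lives in scaled, additive-free space, and `forward` applies the scaler and then the additive models only at evaluation time, in that order. A hedged sketch of the ordering, with plain tensors standing in for the `TensorMap` plumbing:

```python
import torch


def eval_time_postprocess(
    raw_output: torch.Tensor,             # what the network predicts directly
    scale: torch.Tensor,                  # per-target scale fitted by the Scaler
    additive_contribution: torch.Tensor,  # e.g. composition / ZBL terms
) -> torch.Tensor:
    # The scale was fitted on targets with the additive terms already removed,
    # so the inverse is applied in the same order as in forward():
    # first undo the scaling, then add the additive contributions back.
    return raw_output * scale + additive_contribution
```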
3 changes: 3 additions & 0 deletions src/metatrain/experimental/nanopet/schema-hypers.json
@@ -73,6 +73,9 @@
"checkpoint_interval": {
"type": "integer"
},
"scale_targets": {
"type": "boolean"
},
"fixed_composition_weights": {
"type": "object",
"patternProperties": {
10 changes: 5 additions & 5 deletions src/metatrain/experimental/nanopet/tests/test_regression.py
@@ -115,11 +115,11 @@ def test_regression_train():

expected_output = torch.tensor(
[
[-0.162086367607],
[-0.022639824077],
[0.000784186646],
[0.019549313933],
[0.063824169338],
[-0.016902115196],
[0.100093543530],
[0.038387011737],
[0.097679324448],
[0.118228666484],
]
)

31 changes: 29 additions & 2 deletions src/metatrain/experimental/nanopet/trainer.py
@@ -24,6 +24,7 @@
get_system_with_neighbor_lists,
)
from ...utils.per_atom import average_by_num_atoms
from ...utils.scaler import remove_scale
from .model import NanoPET
from .modules.augmentation import apply_random_augmentations

@@ -107,6 +108,10 @@ def train(
train_datasets, self.hypers["fixed_composition_weights"]
)

if self.hypers["scale_targets"]:
logger.info("Calculating scaling weights")
model.scaler.train_model(train_datasets, model.additive_models)

if is_distributed:
model = DistributedDataParallel(model, device_ids=[device])

@@ -207,7 +212,10 @@ def train(
model.parameters(), lr=self.hypers["learning_rate"]
)
if self.optimizer_state_dict is not None:
optimizer.load_state_dict(self.optimizer_state_dict)
# try to load the optimizer state dict, but this is only possible
# if there are no new targets in the model (new parameters)
if not model.has_new_targets:
optimizer.load_state_dict(self.optimizer_state_dict)

# Create a scheduler:
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
@@ -216,7 +224,9 @@
patience=self.hypers["scheduler_patience"],
)
if self.scheduler_state_dict is not None:
lr_scheduler.load_state_dict(self.scheduler_state_dict)
# same as the optimizer, try to load the scheduler state dict
if not model.has_new_targets:
lr_scheduler.load_state_dict(self.scheduler_state_dict)

# per-atom targets:
per_structure_targets = self.hypers["per_structure_targets"]
@@ -274,6 +284,9 @@ def systems_and_targets_to_dtype(
targets = remove_additive(
systems, targets, additive_model, train_targets
)
targets = remove_scale(
targets, (model.module if is_distributed else model).scaler
)
systems, targets = systems_and_targets_to_dtype(systems, targets, dtype)
predictions = evaluate_model(
model,
@@ -330,6 +343,9 @@ def systems_and_targets_to_dtype(
targets = remove_additive(
systems, targets, additive_model, train_targets
)
targets = remove_scale(
targets, (model.module if is_distributed else model).scaler
)
systems = [system.to(dtype=dtype) for system in systems]
targets = {key: value.to(dtype=dtype) for key, value in targets.items()}
predictions = evaluate_model(
@@ -377,13 +393,24 @@ def systems_and_targets_to_dtype(
}

if epoch == start_epoch:
scaler_scales = (
model.module if is_distributed else model
).scaler.get_scales_dict()
metric_logger = MetricLogger(
log_obj=logger,
dataset_info=(
model.module if is_distributed else model
).dataset_info,
initial_metrics=[finalized_train_info, finalized_val_info],
names=["training", "validation"],
scales={
key: (
scaler_scales[key.split(" ")[0]]
if ("MAE" in key or "RMSE" in key)
else 1.0
)
for key in finalized_train_info.keys()
},
)
if epoch % self.hypers["log_interval"] == 0:
metric_logger.log(
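The `scales` dictionary passed to `MetricLogger` above maps each logged metric to the scale of its target, so MAE/RMSE values can be reported in the original units while training happens on scaled targets. A small sketch of the lookup; the exact metric key format (assumed here to look like `"energy RMSE (per atom)"`) is not shown in this diff:

```python
def metric_scales(finalized_train_info: dict, scaler_scales: dict) -> dict:
    # "energy RMSE (per atom)" -> look up the scale stored under "energy";
    # non-error entries such as "loss" keep a scale of 1.0.
    return {
        key: scaler_scales[key.split(" ")[0]] if ("MAE" in key or "RMSE" in key) else 1.0
        for key in finalized_train_info
    }


print(metric_scales({"loss": 0.12, "energy RMSE (per atom)": 0.03}, {"energy": 2.5}))
# {'loss': 1.0, 'energy RMSE (per atom)': 2.5}
```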
1 change: 1 addition & 0 deletions src/metatrain/experimental/soap_bpnn/default-hypers.yaml
@@ -34,6 +34,7 @@ architecture:
scheduler_factor: 0.8
log_interval: 5
checkpoint_interval: 25
scale_targets: true
fixed_composition_weights: {}
per_structure_targets: []
log_mae: False
13 changes: 11 additions & 2 deletions src/metatrain/experimental/soap_bpnn/model.py
@@ -21,6 +21,7 @@

from ...utils.additive import ZBL, CompositionModel
from ...utils.dtype import dtype_to_str
from ...utils.scaler import Scaler


class Identity(torch.nn.Module):
@@ -297,6 +298,9 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo) -> None:
additive_models.append(ZBL(model_hypers, dataset_info))
self.additive_models = torch.nn.ModuleList(additive_models)

# scaler: this is also handled by the trainer at training time
self.scaler = Scaler(model_hypers={}, dataset_info=dataset_info)

def restart(self, dataset_info: DatasetInfo) -> "SoapBpnn":
# merge old and new dataset info
merged_info = self.dataset_info.union(dataset_info)
@@ -308,6 +312,7 @@ def restart(self, dataset_info: DatasetInfo) -> "SoapBpnn":
for key, value in merged_info.targets.items()
if key not in self.dataset_info.targets
}
self.has_new_targets = len(new_targets) > 0

if len(new_atomic_types) > 0:
raise ValueError(
@@ -320,7 +325,10 @@ def restart(self, dataset_info: DatasetInfo) -> "SoapBpnn":
self._add_output(target_name, target)

self.dataset_info = merged_info
self.atomic_types = sorted(self.atomic_types)

# restart the composition and scaler models
self.additive_models[0].restart(dataset_info)
self.scaler.restart(dataset_info)

return self

@@ -409,7 +417,8 @@ def forward(
)

if not self.training:
# at evaluation, we also add the additive contributions
# at evaluation, we also introduce the scaler and additive contributions
return_dict = self.scaler(return_dict)
for additive_model in self.additive_models:
outputs_for_additive_model: Dict[str, ModelOutput] = {}
for name, output in outputs.items():
3 changes: 3 additions & 0 deletions src/metatrain/experimental/soap_bpnn/schema-hypers.json
@@ -130,6 +130,9 @@
"checkpoint_interval": {
"type": "integer"
},
"scale_targets": {
"type": "boolean"
},
"fixed_composition_weights": {
"type": "object",
"patternProperties": {
31 changes: 29 additions & 2 deletions src/metatrain/experimental/soap_bpnn/trainer.py
@@ -23,6 +23,7 @@
get_system_with_neighbor_lists,
)
from ...utils.per_atom import average_by_num_atoms
from ...utils.scaler import remove_scale
from ...utils.transfer import (
systems_and_targets_to_device,
systems_and_targets_to_dtype,
@@ -120,6 +121,10 @@ def train(
train_datasets, self.hypers["fixed_composition_weights"]
)

if self.hypers["scale_targets"]:
logger.info("Calculating scaling weights")
model.scaler.train_model(train_datasets, model.additive_models)

if is_distributed:
model = DistributedDataParallel(model, device_ids=[device])

@@ -220,7 +225,10 @@ def train(
model.parameters(), lr=self.hypers["learning_rate"]
)
if self.optimizer_state_dict is not None:
optimizer.load_state_dict(self.optimizer_state_dict)
# try to load the optimizer state dict, but this is only possible
# if there are no new targets in the model (new parameters)
if not model.has_new_targets:
optimizer.load_state_dict(self.optimizer_state_dict)

# Create a scheduler:
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
@@ -230,7 +238,9 @@
threshold=0.001,
)
if self.scheduler_state_dict is not None:
lr_scheduler.load_state_dict(self.scheduler_state_dict)
# same as the optimizer, try to load the scheduler state dict
if not model.has_new_targets:
lr_scheduler.load_state_dict(self.scheduler_state_dict)

# per-atom targets:
per_structure_targets = self.hypers["per_structure_targets"]
@@ -269,6 +279,9 @@ def train(
targets = remove_additive(
systems, targets, additive_model, train_targets
)
targets = remove_scale(
targets, (model.module if is_distributed else model).scaler
)
systems, targets = systems_and_targets_to_dtype(systems, targets, dtype)
predictions = evaluate_model(
model,
@@ -322,6 +335,9 @@ def train(
targets = remove_additive(
systems, targets, additive_model, train_targets
)
targets = remove_scale(
targets, (model.module if is_distributed else model).scaler
)
systems, targets = systems_and_targets_to_dtype(systems, targets, dtype)
predictions = evaluate_model(
model,
@@ -365,13 +381,24 @@ def train(
finalized_val_info = {"loss": val_loss, **finalized_val_info}

if epoch == start_epoch:
scaler_scales = (
model.module if is_distributed else model
).scaler.get_scales_dict()
metric_logger = MetricLogger(
log_obj=logger,
dataset_info=(
model.module if is_distributed else model
).dataset_info,
initial_metrics=[finalized_train_info, finalized_val_info],
names=["training", "validation"],
scales={
key: (
scaler_scales[key.split(" ")[0]]
if ("MAE" in key or "RMSE" in key)
else 1.0
)
for key in finalized_train_info.keys()
},
)
if epoch % self.hypers["log_interval"] == 0:
metric_logger.log(
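Both trainers now restore the optimizer and scheduler state only when the restarted model has no new targets, because new output heads add parameters that the saved optimizer state does not cover. A toy illustration of why the guard is needed, with a plain `torch.nn` model standing in for the architectures:

```python
import torch

# Saved state from a previous run (here: a model without the new head).
old_model = torch.nn.Linear(4, 1)
old_optimizer = torch.optim.Adam(old_model.parameters(), lr=1e-3)
optimizer_state_dict = old_optimizer.state_dict()

# Restarted model that gained a new output head -> new parameters.
has_new_targets = True
model = torch.nn.ModuleDict({"old": old_model, "new_head": torch.nn.Linear(4, 1)})
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

if optimizer_state_dict is not None and not has_new_targets:
    # Only safe when the parameter groups are unchanged; with new
    # parameters, load_state_dict would raise a mismatch error.
    optimizer.load_state_dict(optimizer_state_dict)
```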
(The remaining changed files in this commit are not shown here.)
