Skip to content

Commit

Permalink
Fixed PET.restart() method
Browse files Browse the repository at this point in the history
  • Loading branch information
abmazitov committed Nov 24, 2024
1 parent d6e2f23 commit b186ecb
Showing 1 changed file with 22 additions and 3 deletions.
25 changes: 22 additions & 3 deletions src/metatrain/experimental/pet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,30 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo) -> None:
self.additive_models = torch.nn.ModuleList(additive_models)

def restart(self, dataset_info: DatasetInfo) -> "PET":
if dataset_info != self.dataset_info:
self.fine_tuning_mode = True
merged_info = self.dataset_info.union(dataset_info)
new_atomic_types = [
at for at in merged_info.atomic_types if at not in self.atomic_types
]
new_targets = {
key: value
for key, value in merged_info.targets.items()
if key not in self.dataset_info.targets
}

if len(new_atomic_types) > 0:
raise ValueError(
f"New atomic types found in the dataset: {new_atomic_types}. "
"The PET model does not support adding new atomic types."
)

if len(new_targets) > 0:
raise ValueError(
"PET cannot be restarted with different dataset information"
f"New targets found in the training options: {new_targets}. "
"The PET model does not support adding new training targets."
)

self.dataset_info = merged_info
self.atomic_types = sorted(self.atomic_types)
return self

def set_trained_model(self, trained_model: RawPET) -> None:
Expand Down

0 comments on commit b186ecb

Please sign in to comment.