diff --git a/src/metatrain/experimental/pet/trainer.py b/src/metatrain/experimental/pet/trainer.py index 560e94c25..24c77629d 100644 --- a/src/metatrain/experimental/pet/trainer.py +++ b/src/metatrain/experimental/pet/trainer.py @@ -610,9 +610,11 @@ def train( else: lora_state_dict = None last_model_checkpoint = { + "architecture_name": "experimental.pet", "trainer_state_dict": trainer_state_dict, "model_state_dict": last_model_state_dict, "best_model_state_dict": self.best_model_state_dict, + "best_metric": self.best_metric, "hypers": self.hypers, "epoch": self.epoch, "dataset_info": model.dataset_info,