Skip to content

Commit

Permalink
Apply Filippo's suggestions
Browse files (browse the repository at this point in the history)
  • Loading branch information
abmazitov committed Nov 25, 2024
1 parent 739388f commit 18e1a90
Showing 1 changed file with 14 additions and 14 deletions.
28 changes: 14 additions & 14 deletions src/metatrain/experimental/pet/trainer.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -51,7 +51,7 @@ def __init__(self, train_hypers):
self.pet_dir = None
self.pet_trainer_state = None
self.epoch = None
self.best_loss = None
self.best_metric = None
self.best_model_state_dict = None

def train(
Expand Down Expand Up @@ -386,8 +386,8 @@ def train(
)
TIME_TRAINING_STARTED = time.time()
last_elapsed_time = 0
if self.best_loss is None:
self.best_loss = float("inf")
if self.best_metric is None:
self.best_metric = float("inf")
start_epoch = 1 if self.epoch is None else self.epoch + 1
for epoch in range(start_epoch, start_epoch + FITTING_SCHEME.EPOCH_NUM):
pet_model.train(True)
Expand Down Expand Up @@ -667,8 +667,8 @@ def save_model(model_name, model_keeper):
summary += f"{energies_rmse_model_keeper.best_error} "
summary += f"at epoch {energies_rmse_model_keeper.best_epoch}\n"

if energies_mae_model_keeper.best_error < self.best_loss:
self.best_loss = energies_mae_model_keeper.best_error
if energies_mae_model_keeper.best_error < self.best_metric:
self.best_metric = energies_mae_model_keeper.best_error
self.best_model_state_dict = (
energies_mae_model_keeper.best_model.state_dict()
)
Expand All @@ -684,8 +684,8 @@ def save_model(model_name, model_keeper):
)
summary += f"at epoch {forces_rmse_model_keeper.best_epoch}\n"

if forces_mae_model_keeper.best_error < self.best_loss:
self.best_loss = forces_mae_model_keeper.best_error
if forces_mae_model_keeper.best_error < self.best_metric:
self.best_metric = forces_mae_model_keeper.best_error
self.best_model_state_dict = (
forces_mae_model_keeper.best_model.state_dict()
)
Expand All @@ -709,8 +709,8 @@ def save_model(model_name, model_keeper):
)
summary += f"{multiplication_rmse_model_keeper.best_epoch}\n"

if multiplication_mae_model_keeper.best_error < self.best_loss:
self.best_loss = multiplication_mae_model_keeper.best_error
if multiplication_mae_model_keeper.best_error < self.best_metric:
self.best_metric = multiplication_mae_model_keeper.best_error
self.best_model_state_dict = (
multiplication_mae_model_keeper.best_model.state_dict()
)
Expand Down Expand Up @@ -757,7 +757,7 @@ def save_checkpoint(self, model, path: Union[str, Path]):
"trainer_state_dict": trainer_state_dict,
"model_state_dict": last_model_state_dict,
"best_model_state_dict": self.best_model_state_dict,
"best_loss": self.best_loss,
"best_metric": self.best_metric,
"hypers": self.hypers,
"epoch": self.epoch,
"dataset_info": model.dataset_info,
Expand All @@ -768,7 +768,7 @@ def save_checkpoint(self, model, path: Union[str, Path]):
"trainer_state_dict": None,
"model_state_dict": self.best_model_state_dict,
"best_model_state_dict": None,
"best_loss": None,
"best_metric": None,
"hypers": self.hypers,
"epoch": None,
"dataset_info": model.dataset_info,
Expand Down Expand Up @@ -796,7 +796,7 @@ def load_checkpoint(cls, path: Union[str, Path], train_hypers) -> "Trainer":
trainer.epoch = checkpoint["epoch"]
old_fitting_scheme = checkpoint["hypers"]["FITTING_SCHEME"]
new_fitting_scheme = train_hypers
best_loss = checkpoint["best_loss"]
best_metric = checkpoint["best_metric"]
best_model_state_dict = checkpoint["best_model_state_dict"]
# The following code is not reached in the current implementation
# because the check for the train targets is done earlier in the
Expand All @@ -812,8 +812,8 @@ def load_checkpoint(cls, path: Union[str, Path], train_hypers) -> "Trainer":
"The `best model` and the `best loss` parts of the checkpoint "
"will be reset to avoid inconsistencies."
)
best_loss = None
best_metric = None
best_model_state_dict = None
trainer.best_loss = best_loss
trainer.best_metric = best_metric
trainer.best_model_state_dict = best_model_state_dict
return trainer

0 comments on commit 18e1a90

Please sign in to comment.