Skip to content

Commit

Permalink
Fix final checkpoint bug (#433)
Browse files (browse the repository at this point in the history)
  • Loading branch information
frostedoyster authored Dec 17, 2024
1 parent 8ae4088 commit 7ca524b
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/metatrain/experimental/nanopet/trainer.py
Original file line number Diff line number Diff line change
@@ -259,6 +259,7 @@ def systems_and_targets_to_dtype(
if self.best_loss is None:
self.best_loss = float("inf")
logger.info("Starting training")
epoch = start_epoch
for epoch in range(start_epoch, start_epoch + self.hypers["num_epochs"]):
if is_distributed:
sampler.set_epoch(epoch)
@@ -314,8 +315,6 @@ def systems_and_targets_to_dtype(
if self.hypers["log_mae"]:
train_mae_calculator.update(predictions, targets)

# count += 1

finalized_train_info = train_rmse_calculator.finalize(
not_per_atom=["positions_gradients"] + per_structure_targets,
is_distributed=is_distributed,
@@ -461,6 +460,7 @@ def systems_and_targets_to_dtype(
)

# prepare for the checkpoint that will be saved outside the function
self.epoch = epoch
self.optimizer_state_dict = optimizer.state_dict()
self.scheduler_state_dict = lr_scheduler.state_dict()

Expand Down
2 changes: 2 additions & 0 deletions src/metatrain/experimental/soap_bpnn/trainer.py
Original file line number Diff line number Diff line change
@@ -255,6 +255,7 @@ def train(
if self.best_loss is None:
self.best_loss = float("inf")
logger.info("Starting training")
epoch = start_epoch
for epoch in range(start_epoch, start_epoch + self.hypers["num_epochs"]):
if is_distributed:
sampler.set_epoch(epoch)
@@ -449,6 +450,7 @@ def train(
)

# prepare for the checkpoint that will be saved outside the function
self.epoch = epoch
self.optimizer_state_dict = optimizer.state_dict()
self.scheduler_state_dict = lr_scheduler.state_dict()

Expand Down

0 comments on commit 7ca524b

Please sign in to comment.