Merge pull request #114 from zhang-haojie/main
Fix pytorch lightning training
maxin-cn authored Aug 22, 2024
2 parents 18fc772 + 8041536 commit 1b74d1e
Showing 2 changed files with 4 additions and 4 deletions.
train_pl.py (2 additions, 2 deletions)
@@ -121,7 +121,7 @@ def on_save_checkpoint(self, checkpoint):
         epoch = self.trainer.current_epoch
         step = self.trainer.global_step
         checkpoint = {
-            "model": self.model.module.state_dict(),
+            "model": self.model.state_dict(),
             "ema": self.ema.state_dict(),
         }
         torch.save(checkpoint, f"{checkpoint_dir}/epoch{epoch}-step{step}.ckpt")
@@ -221,7 +221,7 @@ def main(args):
     # Trainer
     trainer = Trainer(
         accelerator="gpu",
-        devices=[3], # Specify GPU ids
+        # devices=[3], # Specify GPU ids
         strategy="auto",
         max_epochs=num_train_epochs,
         logger=tb_logger,
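The `.module` indirection removed above only exists when the inner network is wrapped in `torch.nn.parallel.DistributedDataParallel` by hand; under PyTorch Lightning the LightningModule holds the plain `nn.Module`, so its weights are read with `state_dict()` directly. A minimal sketch of the checkpoint hook after the fix (class and attribute names such as `LitDiffusion` and `self.checkpoint_dir` are illustrative, not the repository's exact code):

import os
from copy import deepcopy

import torch
import pytorch_lightning as pl


class LitDiffusion(pl.LightningModule):  # hypothetical name, for illustration only
    def __init__(self, model: torch.nn.Module, checkpoint_dir: str):
        super().__init__()
        self.model = model              # plain nn.Module, no DDP wrapper -> no .module attribute
        self.ema = deepcopy(model)      # EMA copy, assumed to be updated elsewhere in training
        self.checkpoint_dir = checkpoint_dir

    def on_save_checkpoint(self, checkpoint):
        epoch = self.trainer.current_epoch
        step = self.trainer.global_step
        ckpt = {
            "model": self.model.state_dict(),  # fixed: state_dict() called on the module itself
            "ema": self.ema.state_dict(),
        }
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        torch.save(ckpt, f"{self.checkpoint_dir}/epoch{epoch}-step{step}.ckpt")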
train_with_img_pl.py (2 additions, 2 deletions)
@@ -131,7 +131,7 @@ def on_save_checkpoint(self, checkpoint):
         epoch = self.trainer.current_epoch
         step = self.trainer.global_step
         checkpoint = {
-            "model": self.model.module.state_dict(),
+            "model": self.model.state_dict(),
             "ema": self.ema.state_dict(),
         }
         torch.save(checkpoint, f"{checkpoint_dir}/epoch{epoch}-step{step}.ckpt")
@@ -231,7 +231,7 @@ def main(args):
     # Trainer
     trainer = Trainer(
         accelerator="gpu",
-        devices=[3], # Specify GPU ids
+        # devices=[3], # Specify GPU ids
         strategy="auto",
         max_epochs=num_train_epochs,
         logger=tb_logger,
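Commenting out `devices=[3]` stops pinning training to a single hard-coded GPU id and lets Lightning fall back to its default device selection. A sketch of the resulting Trainer setup; the `devices="auto"` shown here and the placeholder values for `num_train_epochs` and `tb_logger` are assumptions for illustration, not part of the commit:

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger

num_train_epochs = 100                  # placeholder value for illustration
tb_logger = TensorBoardLogger("logs/")  # placeholder logger setup

trainer = Trainer(
    accelerator="gpu",
    devices="auto",                     # let Lightning pick available GPUs instead of a fixed id
    strategy="auto",
    max_epochs=num_train_epochs,
    logger=tb_logger,
)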
