diff --git a/train_pl.py b/train_pl.py index e12d3b5..d967be5 100644 --- a/train_pl.py +++ b/train_pl.py @@ -121,7 +121,7 @@ def on_save_checkpoint(self, checkpoint): epoch = self.trainer.current_epoch step = self.trainer.global_step checkpoint = { - "model": self.model.module.state_dict(), + "model": self.model.state_dict(), "ema": self.ema.state_dict(), } torch.save(checkpoint, f"{checkpoint_dir}/epoch{epoch}-step{step}.ckpt") @@ -221,7 +221,7 @@ def main(args): # Trainer trainer = Trainer( accelerator="gpu", - devices=[3], # Specify GPU ids + # devices=[3], # Specify GPU ids strategy="auto", max_epochs=num_train_epochs, logger=tb_logger, diff --git a/train_with_img_pl.py b/train_with_img_pl.py index 3b319c2..ff6da7a 100644 --- a/train_with_img_pl.py +++ b/train_with_img_pl.py @@ -131,7 +131,7 @@ def on_save_checkpoint(self, checkpoint): epoch = self.trainer.current_epoch step = self.trainer.global_step checkpoint = { - "model": self.model.module.state_dict(), + "model": self.model.state_dict(), "ema": self.ema.state_dict(), } torch.save(checkpoint, f"{checkpoint_dir}/epoch{epoch}-step{step}.ckpt") @@ -231,7 +231,7 @@ def main(args): # Trainer trainer = Trainer( accelerator="gpu", - devices=[3], # Specify GPU ids + # devices=[3], # Specify GPU ids strategy="auto", max_epochs=num_train_epochs, logger=tb_logger,