diff --git a/magvit2_pytorch/trainer.py b/magvit2_pytorch/trainer.py index 7ac9414..cd0f054 100644 --- a/magvit2_pytorch/trainer.py +++ b/magvit2_pytorch/trainer.py @@ -267,6 +267,10 @@ def is_main(self): def is_local_main(self): return self.accelerator.is_local_main_process + @property + def unwrapped_model(self): + return self.accelerator.unwrap_model(self.model) + def wait(self): return self.accelerator.wait_for_everyone() @@ -461,7 +465,7 @@ def valid_step( valid_video = valid_video.to(self.device) with self.accelerator.autocast(): - loss, _ = self.model(valid_video, return_recon_loss_only = True) + loss, _ = self.unwrapped_model(valid_video, return_recon_loss_only = True) ema_loss, ema_recon_video = self.ema_model(valid_video, return_recon_loss_only = True) recon_loss += loss / self.grad_accum_every diff --git a/magvit2_pytorch/version.py b/magvit2_pytorch/version.py index a987347..908c0bb 100644 --- a/magvit2_pytorch/version.py +++ b/magvit2_pytorch/version.py @@ -1 +1 @@ -__version__ = '0.4.2' +__version__ = '0.4.3'