From d266deb0c527279528004537eaceee6b264adabd Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Tue, 7 May 2024 07:16:06 -0700 Subject: [PATCH] address https://github.com/lucidrains/magvit2-pytorch/issues/18 --- magvit2_pytorch/trainer.py | 6 +++++- magvit2_pytorch/version.py | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/magvit2_pytorch/trainer.py b/magvit2_pytorch/trainer.py index 7ac9414..cd0f054 100644 --- a/magvit2_pytorch/trainer.py +++ b/magvit2_pytorch/trainer.py @@ -267,6 +267,10 @@ def is_main(self): def is_local_main(self): return self.accelerator.is_local_main_process + @property + def unwrapped_model(self): + return self.accelerator.unwrap_model(self.model) + def wait(self): return self.accelerator.wait_for_everyone() @@ -461,7 +465,7 @@ def valid_step( valid_video = valid_video.to(self.device) with self.accelerator.autocast(): - loss, _ = self.model(valid_video, return_recon_loss_only = True) + loss, _ = self.unwrapped_model(valid_video, return_recon_loss_only = True) ema_loss, ema_recon_video = self.ema_model(valid_video, return_recon_loss_only = True) recon_loss += loss / self.grad_accum_every diff --git a/magvit2_pytorch/version.py b/magvit2_pytorch/version.py index a987347..908c0bb 100644 --- a/magvit2_pytorch/version.py +++ b/magvit2_pytorch/version.py @@ -1 +1 @@ -__version__ = '0.4.2' +__version__ = '0.4.3'