From 6a6093d72609b107a6db7db2719bdbaf9e3074b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 4 May 2024 15:48:40 +0300 Subject: [PATCH] Fix image's upcasting before `vae.encode()` when using `fp16` --- .../ledits_pp/pipeline_leditspp_stable_diffusion_xl.py | 1 - .../pipeline_stable_diffusion_xl_instruct_pix2pix.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py b/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py index cfab70926a4a..5ea7c2c14551 100644 --- a/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py @@ -1419,7 +1419,6 @@ def encode_image(self, image, dtype=None, height=None, width=None, resize_mode=" if needs_upcasting: image = image.float() self.upcast_vae() - image = image.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) x0 = self.vae.encode(image).latent_dist.mode() x0 = x0.to(dtype) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py index 5e7be370be01..d9380020b329 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py @@ -525,8 +525,8 @@ def prepare_image_latents( # make sure the VAE is in float32 mode, as it overflows in float16 needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast if needs_upcasting: + image = image.float() self.upcast_vae() - image = image.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) image_latents = retrieve_latents(self.vae.encode(image), sample_mode="argmax")