From a9e459c2a466c23b56ba8d62d8f3aeffdfb971fc Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 29 Sep 2024 11:13:53 -0400 Subject: [PATCH] Use torch.nn.functional.linear in RGB preview code. Add an optional bias to the latent RGB preview code. --- comfy/latent_formats.py | 1 + latent_preview.py | 16 ++++++++++++---- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index ee19faeae74..78397d75a65 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -4,6 +4,7 @@ class LatentFormat: scale_factor = 1.0 latent_channels = 4 latent_rgb_factors = None + latent_rgb_factors_bias = None taesd_decoder_name = None def process_in(self, latent): diff --git a/latent_preview.py b/latent_preview.py index e14c72ce4d0..ae9211a27ad 100644 --- a/latent_preview.py +++ b/latent_preview.py @@ -36,12 +36,20 @@ def decode_latent_to_preview(self, x0): class Latent2RGBPreviewer(LatentPreviewer): - def __init__(self, latent_rgb_factors): - self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu") + def __init__(self, latent_rgb_factors, latent_rgb_factors_bias=None): + self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu").transpose(0, 1) + self.latent_rgb_factors_bias = None + if latent_rgb_factors_bias is not None: + self.latent_rgb_factors_bias = torch.tensor(latent_rgb_factors_bias, device="cpu") def decode_latent_to_preview(self, x0): self.latent_rgb_factors = self.latent_rgb_factors.to(dtype=x0.dtype, device=x0.device) - latent_image = x0[0].permute(1, 2, 0) @ self.latent_rgb_factors + if self.latent_rgb_factors_bias is not None: + self.latent_rgb_factors_bias = self.latent_rgb_factors_bias.to(dtype=x0.dtype, device=x0.device) + + latent_image = torch.nn.functional.linear(x0[0].permute(1, 2, 0), self.latent_rgb_factors, bias=self.latent_rgb_factors_bias) + # latent_image = x0[0].permute(1, 2, 0) @ self.latent_rgb_factors + return preview_to_image(latent_image) @@ -71,7 +79,7 @@ def get_previewer(device, latent_format): if previewer is None: if latent_format.latent_rgb_factors is not None: - previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors) + previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors, latent_format.latent_rgb_factors_bias) return previewer def prepare_callback(model, steps, x0_output_dict=None):