From 4c82741b545c6cedcfa397034f56ce1377b3675a Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 26 Nov 2024 11:31:25 -0500 Subject: [PATCH] Support official SD3.5 Controlnets. --- comfy/cldm/dit_embedder.py | 122 +++++++++++++++++++++++++++++++++++++ comfy/controlnet.py | 91 ++++++++++++++++++++++++++- 2 files changed, 210 insertions(+), 3 deletions(-) create mode 100644 comfy/cldm/dit_embedder.py diff --git a/comfy/cldm/dit_embedder.py b/comfy/cldm/dit_embedder.py new file mode 100644 index 00000000000..e9cdd49910b --- /dev/null +++ b/comfy/cldm/dit_embedder.py @@ -0,0 +1,122 @@ +import math +from typing import List, Optional, Tuple + +import numpy as np +import torch +import torch.nn as nn +from einops import rearrange +from torch import Tensor + +from comfy.ldm.modules.diffusionmodules.mmdit import DismantledBlock, PatchEmbed, VectorEmbedder, TimestepEmbedder, get_2d_sincos_pos_embed_torch + + +class ControlNetEmbedder(nn.Module): + + def __init__( + self, + img_size: int, + patch_size: int, + in_chans: int, + attention_head_dim: int, + num_attention_heads: int, + adm_in_channels: int, + num_layers: int, + main_model_double: int, + double_y_emb: bool, + device: torch.device, + dtype: torch.dtype, + pos_embed_max_size: Optional[int] = None, + operations = None, + ): + super().__init__() + self.main_model_double = main_model_double + self.dtype = dtype + self.hidden_size = num_attention_heads * attention_head_dim + self.patch_size = patch_size + self.x_embedder = PatchEmbed( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=self.hidden_size, + strict_img_size=pos_embed_max_size is None, + device=device, + dtype=dtype, + operations=operations, + ) + + self.t_embedder = TimestepEmbedder(self.hidden_size, dtype=dtype, device=device, operations=operations) + + self.double_y_emb = double_y_emb + if self.double_y_emb: + self.orig_y_embedder = VectorEmbedder( + adm_in_channels, self.hidden_size, dtype, device, operations=operations + ) + self.y_embedder = VectorEmbedder( + self.hidden_size, self.hidden_size, dtype, device, operations=operations + ) + else: + self.y_embedder = VectorEmbedder( + adm_in_channels, self.hidden_size, dtype, device, operations=operations + ) + + self.transformer_blocks = nn.ModuleList( + DismantledBlock( + hidden_size=self.hidden_size, num_heads=num_attention_heads, qkv_bias=True, + dtype=dtype, device=device, operations=operations + ) + for _ in range(num_layers) + ) + + # self.use_y_embedder = pooled_projection_dim != self.time_text_embed.text_embedder.linear_1.in_features + # TODO double check this logic when 8b + self.use_y_embedder = True + + self.controlnet_blocks = nn.ModuleList([]) + for _ in range(len(self.transformer_blocks)): + controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device) + self.controlnet_blocks.append(controlnet_block) + + self.pos_embed_input = PatchEmbed( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=self.hidden_size, + strict_img_size=False, + device=device, + dtype=dtype, + operations=operations, + ) + + def forward( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + y: Optional[torch.Tensor] = None, + context: Optional[torch.Tensor] = None, + hint = None, + ) -> Tuple[Tensor, List[Tensor]]: + x_shape = list(x.shape) + x = self.x_embedder(x) + if not self.double_y_emb: + h = (x_shape[-2] + 1) // self.patch_size + w = (x_shape[-1] + 1) // self.patch_size + x += get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, device=x.device) + c = self.t_embedder(timesteps, dtype=x.dtype) + if y is not None and self.y_embedder is not None: + if self.double_y_emb: + y = self.orig_y_embedder(y) + y = self.y_embedder(y) + c = c + y + + x = x + self.pos_embed_input(hint) + + block_out = () + + repeat = math.ceil(self.main_model_double / len(self.transformer_blocks)) + for i in range(len(self.transformer_blocks)): + out = self.transformer_blocks[i](x, c) + if not self.double_y_emb: + x = out + block_out += (self.controlnet_blocks[i](out),) * repeat + + return {"output": block_out} diff --git a/comfy/controlnet.py b/comfy/controlnet.py index d2744e427da..cf9f894cae1 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -35,7 +35,7 @@ import comfy.cldm.mmdit import comfy.ldm.hydit.controlnet import comfy.ldm.flux.controlnet - +import comfy.cldm.dit_embedder def broadcast_image_to(tensor, target_batch_size, batched_number): current_batch_size = tensor.shape[0] @@ -78,6 +78,7 @@ def __init__(self): self.concat_mask = False self.extra_concat_orig = [] self.extra_concat = None + self.preprocess_image = lambda a: a def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]): self.cond_hint_original = cond_hint @@ -129,6 +130,7 @@ def copy_to(self, c): c.strength_type = self.strength_type c.concat_mask = self.concat_mask c.extra_concat_orig = self.extra_concat_orig.copy() + c.preprocess_image = self.preprocess_image def inference_memory_requirements(self, dtype): if self.previous_controlnet is not None: @@ -181,7 +183,7 @@ def set_extra_arg(self, argument, value=None): class ControlNet(ControlBase): - def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT, concat_mask=False): + def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT, concat_mask=False, preprocess_image=lambda a: a): super().__init__() self.control_model = control_model self.load_device = load_device @@ -196,6 +198,7 @@ def __init__(self, control_model=None, global_average_pooling=False, compression self.extra_conds += extra_conds self.strength_type = strength_type self.concat_mask = concat_mask + self.preprocess_image = preprocess_image def get_control(self, x_noisy, t, cond, batched_number): control_prev = None @@ -224,6 +227,7 @@ def get_control(self, x_noisy, t, cond, batched_number): if self.latent_format is not None: raise ValueError("This Controlnet needs a VAE but none was provided, please use a ControlNetApply node with a VAE input and connect it.") self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center") + self.cond_hint = self.preprocess_image(self.cond_hint) if self.vae is not None: loaded_models = comfy.model_management.loaded_models(only_currently_used=True) self.cond_hint = self.vae.encode(self.cond_hint.movedim(1, -1)) @@ -427,6 +431,7 @@ def controlnet_load_state_dict(control_model, sd): logging.debug("unexpected controlnet keys: {}".format(unexpected)) return control_model + def load_controlnet_mmdit(sd, model_options={}): new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "") model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd, model_options=model_options) @@ -448,6 +453,83 @@ def load_controlnet_mmdit(sd, model_options={}): return control +class ControlNetSD35(ControlNet): + def pre_run(self, model, percent_to_timestep_function): + if self.control_model.double_y_emb: + missing, unexpected = self.control_model.orig_y_embedder.load_state_dict(model.diffusion_model.y_embedder.state_dict(), strict=False) + else: + missing, unexpected = self.control_model.x_embedder.load_state_dict(model.diffusion_model.x_embedder.state_dict(), strict=False) + super().pre_run(model, percent_to_timestep_function) + + def copy(self): + c = ControlNetSD35(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype) + c.control_model = self.control_model + c.control_model_wrapped = self.control_model_wrapped + self.copy_to(c) + return c + +def load_controlnet_sd35(sd, model_options={}): + control_type = -1 + if "control_type" in sd: + control_type = round(sd.pop("control_type").item()) + + # blur_cnet = control_type == 0 + canny_cnet = control_type == 1 + depth_cnet = control_type == 2 + + print(control_type, canny_cnet, depth_cnet) + new_sd = {} + for k in comfy.utils.MMDIT_MAP_BASIC: + if k[1] in sd: + new_sd[k[0]] = sd.pop(k[1]) + for k in sd: + new_sd[k] = sd[k] + sd = new_sd + + y_emb_shape = sd["y_embedder.mlp.0.weight"].shape + depth = y_emb_shape[0] // 64 + hidden_size = 64 * depth + num_heads = depth + head_dim = hidden_size // num_heads + num_blocks = comfy.model_detection.count_blocks(new_sd, 'transformer_blocks.{}.') + + load_device = comfy.model_management.get_torch_device() + offload_device = comfy.model_management.unet_offload_device() + unet_dtype = comfy.model_management.unet_dtype(model_params=-1) + + manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device) + + operations = model_options.get("custom_operations", None) + if operations is None: + operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype, disable_fast_fp8=True) + + control_model = comfy.cldm.dit_embedder.ControlNetEmbedder(img_size=None, + patch_size=2, + in_chans=16, + num_layers=num_blocks, + main_model_double=depth, + double_y_emb=y_emb_shape[0] == y_emb_shape[1], + attention_head_dim=head_dim, + num_attention_heads=num_heads, + adm_in_channels=2048, + device=offload_device, + dtype=unet_dtype, + operations=operations) + + control_model = controlnet_load_state_dict(control_model, sd) + + latent_format = comfy.latent_formats.SD3() + preprocess_image = lambda a: a + if canny_cnet: + preprocess_image = lambda a: (a * 255 * 0.5 + 0.5) + elif depth_cnet: + preprocess_image = lambda a: 1.0 - a + + control = ControlNetSD35(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, preprocess_image=preprocess_image) + return control + + + def load_controlnet_hunyuandit(controlnet_data, model_options={}): model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(controlnet_data, model_options=model_options) @@ -560,7 +642,10 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}): if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data: return load_controlnet_flux_xlabs_mistoline(controlnet_data, model_options=model_options) elif "pos_embed_input.proj.weight" in controlnet_data: - return load_controlnet_mmdit(controlnet_data, model_options=model_options) #SD3 diffusers controlnet + if "transformer_blocks.0.adaLN_modulation.1.bias" in controlnet_data: + return load_controlnet_sd35(controlnet_data, model_options=model_options) #Stability sd3.5 format + else: + return load_controlnet_mmdit(controlnet_data, model_options=model_options) #SD3 diffusers controlnet elif "controlnet_x_embedder.weight" in controlnet_data: return load_controlnet_flux_instantx(controlnet_data, model_options=model_options) elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux