diff --git a/comfy/cldm/mmdit.py b/comfy/cldm/mmdit.py new file mode 100644 index 00000000000..6e72474ce90 --- /dev/null +++ b/comfy/cldm/mmdit.py @@ -0,0 +1,91 @@ +import torch +from typing import Dict, Optional +import comfy.ldm.modules.diffusionmodules.mmdit +import comfy.latent_formats + +class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT): + def __init__( + self, + num_blocks = None, + dtype = None, + device = None, + operations = None, + **kwargs, + ): + super().__init__(dtype=dtype, device=device, operations=operations, final_layer=False, num_blocks=num_blocks, **kwargs) + # controlnet_blocks + self.controlnet_blocks = torch.nn.ModuleList([]) + for _ in range(len(self.joint_blocks)): + self.controlnet_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, device=device, dtype=dtype)) + + self.pos_embed_input = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed( + None, + self.patch_size, + self.in_channels, + self.hidden_size, + bias=True, + strict_img_size=False, + dtype=dtype, + device=device, + operations=operations + ) + + self.latent_format = comfy.latent_formats.SD3() + + def forward( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + y: Optional[torch.Tensor] = None, + context: Optional[torch.Tensor] = None, + hint = None, + ) -> torch.Tensor: + + #weird sd3 controlnet specific stuff + hint = hint * self.latent_format.scale_factor # self.latent_format.process_in(hint) + y = torch.zeros_like(y) + + + if self.context_processor is not None: + context = self.context_processor(context) + + hw = x.shape[-2:] + x = self.x_embedder(x) + self.cropped_pos_embed(hw, device=x.device).to(dtype=x.dtype, device=x.device) + x += self.pos_embed_input(hint) + + c = self.t_embedder(timesteps, dtype=x.dtype) + if y is not None and self.y_embedder is not None: + y = self.y_embedder(y) + c = c + y + + if context is not None: + context = self.context_embedder(context) + + if self.register_length > 0: + context = torch.cat( + ( + repeat(self.register, "1 ... -> b ...", b=x.shape[0]), + default(context, torch.Tensor([]).type_as(x)), + ), + 1, + ) + + output = [] + + blocks = len(self.joint_blocks) + for i in range(blocks): + context, x = self.joint_blocks[i]( + context, + x, + c=c, + use_checkpoint=self.use_checkpoint, + ) + + out = self.controlnet_blocks[i](x) + count = self.depth // blocks + if i == blocks - 1: + count -= 1 + for j in range(count): + output.append(out) + + return {"output": output} diff --git a/comfy/controlnet.py b/comfy/controlnet.py index f50df68357d..9202c31944f 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -11,6 +11,7 @@ import comfy.cldm.cldm import comfy.t2i_adapter.adapter import comfy.ldm.cascade.controlnet +import comfy.cldm.mmdit def broadcast_image_to(tensor, target_batch_size, batched_number): @@ -94,13 +95,17 @@ def control_merge(self, control, control_prev, output_dtype): for key in control: control_output = control[key] + applied_to = set() for i in range(len(control_output)): x = control_output[i] if x is not None: if self.global_average_pooling: x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3]) - x *= self.strength + if x not in applied_to: #memory saving strategy, allow shared tensors and only apply strength to shared tensors once + applied_to.add(x) + x *= self.strength + if x.dtype != output_dtype: x = x.to(output_dtype) @@ -120,17 +125,18 @@ def control_merge(self, control, control_prev, output_dtype): if o[i].shape[0] < prev_val.shape[0]: o[i] = prev_val + o[i] else: - o[i] += prev_val + o[i] = prev_val + o[i] #TODO: change back to inplace add if shared tensors stop being an issue return out class ControlNet(ControlBase): - def __init__(self, control_model=None, global_average_pooling=False, device=None, load_device=None, manual_cast_dtype=None): + def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, device=None, load_device=None, manual_cast_dtype=None): super().__init__(device) self.control_model = control_model self.load_device = load_device if control_model is not None: self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device()) + self.compression_ratio = compression_ratio self.global_average_pooling = global_average_pooling self.model_sampling_current = None self.manual_cast_dtype = manual_cast_dtype @@ -308,6 +314,37 @@ def get_models(self): def inference_memory_requirements(self, dtype): return comfy.utils.calculate_parameters(self.control_weights) * comfy.model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype) +def load_controlnet_mmdit(sd): + new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "") + model_config = comfy.model_detection.model_config_from_unet(new_sd, "", True) + num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.') + for k in sd: + new_sd[k] = sd[k] + + supported_inference_dtypes = model_config.supported_inference_dtypes + + controlnet_config = model_config.unet_config + unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes) + load_device = comfy.model_management.get_torch_device() + manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device) + if manual_cast_dtype is not None: + operations = comfy.ops.manual_cast + else: + operations = comfy.ops.disable_weight_init + + control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=load_device, dtype=unet_dtype, **controlnet_config) + missing, unexpected = control_model.load_state_dict(new_sd, strict=False) + + if len(missing) > 0: + logging.warning("missing controlnet keys: {}".format(missing)) + + if len(unexpected) > 0: + logging.debug("unexpected controlnet keys: {}".format(unexpected)) + + control = ControlNet(control_model, compression_ratio=1, load_device=load_device, manual_cast_dtype=manual_cast_dtype) + return control + + def load_controlnet(ckpt_path, model=None): controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True) if "lora_controlnet" in controlnet_data: @@ -360,6 +397,8 @@ def load_controlnet(ckpt_path, model=None): if len(leftover_keys) > 0: logging.warning("leftover keys: {}".format(leftover_keys)) controlnet_data = new_sd + elif "controlnet_blocks.0.weight" in controlnet_data: #SD3 diffusers format + return load_controlnet_mmdit(controlnet_data) pth_key = 'control_model.zero_convs.0.0.weight' pth = False diff --git a/comfy/ldm/modules/diffusionmodules/mmdit.py b/comfy/ldm/modules/diffusionmodules/mmdit.py index 20d3a321a02..927451534d7 100644 --- a/comfy/ldm/modules/diffusionmodules/mmdit.py +++ b/comfy/ldm/modules/diffusionmodules/mmdit.py @@ -745,6 +745,8 @@ def __init__( qkv_bias: bool = True, context_processor_layers = None, context_size = 4096, + num_blocks = None, + final_layer = True, dtype = None, #TODO device = None, operations = None, @@ -766,7 +768,10 @@ def __init__( # apply magic --> this defines a head_size of 64 self.hidden_size = 64 * depth num_heads = depth + if num_blocks is None: + num_blocks = depth + self.depth = depth self.num_heads = num_heads self.x_embedder = PatchEmbed( @@ -821,7 +826,7 @@ def __init__( mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, attn_mode=attn_mode, - pre_only=i == depth - 1, + pre_only=(i == num_blocks - 1) and final_layer, rmsnorm=rmsnorm, scale_mod_only=scale_mod_only, swiglu=swiglu, @@ -830,11 +835,12 @@ def __init__( device=device, operations=operations ) - for i in range(depth) + for i in range(num_blocks) ] ) - self.final_layer = FinalLayer(self.hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations) + if final_layer: + self.final_layer = FinalLayer(self.hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations) if compile_core: assert False @@ -893,6 +899,7 @@ def forward_core_with_concat( x: torch.Tensor, c_mod: torch.Tensor, context: Optional[torch.Tensor] = None, + control = None, ) -> torch.Tensor: if self.register_length > 0: context = torch.cat( @@ -905,13 +912,20 @@ def forward_core_with_concat( # context is B, L', D # x is B, L, D - for block in self.joint_blocks: - context, x = block( + blocks = len(self.joint_blocks) + for i in range(blocks): + context, x = self.joint_blocks[i]( context, x, c=c_mod, use_checkpoint=self.use_checkpoint, ) + if control is not None: + control_o = control.get("output") + if i < len(control_o): + add = control_o[i] + if add is not None: + x += add x = self.final_layer(x, c_mod) # (N, T, patch_size ** 2 * out_channels) return x @@ -922,6 +936,7 @@ def forward( t: torch.Tensor, y: Optional[torch.Tensor] = None, context: Optional[torch.Tensor] = None, + control = None, ) -> torch.Tensor: """ Forward pass of DiT. @@ -943,7 +958,7 @@ def forward( if context is not None: context = self.context_embedder(context) - x = self.forward_core_with_concat(x, c, context) + x = self.forward_core_with_concat(x, c, context, control) x = self.unpatchify(x, hw=hw) # (N, out_channels, H, W) return x[:,:,:hw[-2],:hw[-1]] @@ -956,7 +971,8 @@ def forward( timesteps: torch.Tensor, context: Optional[torch.Tensor] = None, y: Optional[torch.Tensor] = None, + control = None, **kwargs, ) -> torch.Tensor: - return super().forward(x, timesteps, context=context, y=y) + return super().forward(x, timesteps, context=context, y=y, control=control) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index e09dd381ad9..0b678480f19 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -41,7 +41,9 @@ def detect_unet_config(state_dict, key_prefix): unet_config["in_channels"] = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[1] patch_size = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[2] unet_config["patch_size"] = patch_size - unet_config["out_channels"] = state_dict['{}final_layer.linear.weight'.format(key_prefix)].shape[0] // (patch_size * patch_size) + final_layer = '{}final_layer.linear.weight'.format(key_prefix) + if final_layer in state_dict: + unet_config["out_channels"] = state_dict[final_layer].shape[0] // (patch_size * patch_size) unet_config["depth"] = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[0] // 64 unet_config["input_size"] = None @@ -435,10 +437,11 @@ def model_config_from_diffusers_unet(state_dict): return None def convert_diffusers_mmdit(state_dict, output_prefix=""): - depth = count_blocks(state_dict, 'transformer_blocks.{}.') - if depth > 0: + num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.') + if num_blocks > 0: + depth = state_dict["pos_embed.proj.weight"].shape[0] // 64 out_sd = {} - sd_map = comfy.utils.mmdit_to_diffusers({"depth": depth}, output_prefix=output_prefix) + sd_map = comfy.utils.mmdit_to_diffusers({"depth": depth, "num_blocks": num_blocks}, output_prefix=output_prefix) for k in sd_map: weight = state_dict.get(k, None) if weight is not None: diff --git a/comfy/utils.py b/comfy/utils.py index ed6c58a64e7..48618e07616 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -298,7 +298,8 @@ def mmdit_to_diffusers(mmdit_config, output_prefix=""): key_map = {} depth = mmdit_config.get("depth", 0) - for i in range(depth): + num_blocks = mmdit_config.get("num_blocks", depth) + for i in range(num_blocks): block_from = "transformer_blocks.{}".format(i) block_to = "{}joint_blocks.{}".format(output_prefix, i)