Skip to content

Commit

Permalink
Support official SD3.5 Controlnets.
Browse files Browse the repository at this point in the history
  • Loading branch information
comfyanonymous committed Nov 26, 2024
1 parent 15c39ea commit 4c82741
Show file tree
Hide file tree
Showing 2 changed files with 210 additions and 3 deletions.
122 changes: 122 additions & 0 deletions comfy/cldm/dit_embedder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
import math
from typing import List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from torch import Tensor

from comfy.ldm.modules.diffusionmodules.mmdit import DismantledBlock, PatchEmbed, VectorEmbedder, TimestepEmbedder, get_2d_sincos_pos_embed_torch


class ControlNetEmbedder(nn.Module):

def __init__(
self,
img_size: int,
patch_size: int,
in_chans: int,
attention_head_dim: int,
num_attention_heads: int,
adm_in_channels: int,
num_layers: int,
main_model_double: int,
double_y_emb: bool,
device: torch.device,
dtype: torch.dtype,
pos_embed_max_size: Optional[int] = None,
operations = None,
):
super().__init__()
self.main_model_double = main_model_double
self.dtype = dtype
self.hidden_size = num_attention_heads * attention_head_dim
self.patch_size = patch_size
self.x_embedder = PatchEmbed(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=self.hidden_size,
strict_img_size=pos_embed_max_size is None,
device=device,
dtype=dtype,
operations=operations,
)

self.t_embedder = TimestepEmbedder(self.hidden_size, dtype=dtype, device=device, operations=operations)

self.double_y_emb = double_y_emb
if self.double_y_emb:
self.orig_y_embedder = VectorEmbedder(
adm_in_channels, self.hidden_size, dtype, device, operations=operations
)
self.y_embedder = VectorEmbedder(
self.hidden_size, self.hidden_size, dtype, device, operations=operations
)
else:
self.y_embedder = VectorEmbedder(
adm_in_channels, self.hidden_size, dtype, device, operations=operations
)

self.transformer_blocks = nn.ModuleList(
DismantledBlock(
hidden_size=self.hidden_size, num_heads=num_attention_heads, qkv_bias=True,
dtype=dtype, device=device, operations=operations
)
for _ in range(num_layers)
)

# self.use_y_embedder = pooled_projection_dim != self.time_text_embed.text_embedder.linear_1.in_features
# TODO double check this logic when 8b
self.use_y_embedder = True

self.controlnet_blocks = nn.ModuleList([])
for _ in range(len(self.transformer_blocks)):
controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
self.controlnet_blocks.append(controlnet_block)

self.pos_embed_input = PatchEmbed(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=self.hidden_size,
strict_img_size=False,
device=device,
dtype=dtype,
operations=operations,
)

def forward(
self,
x: torch.Tensor,
timesteps: torch.Tensor,
y: Optional[torch.Tensor] = None,
context: Optional[torch.Tensor] = None,
hint = None,
) -> Tuple[Tensor, List[Tensor]]:
x_shape = list(x.shape)
x = self.x_embedder(x)
if not self.double_y_emb:
h = (x_shape[-2] + 1) // self.patch_size
w = (x_shape[-1] + 1) // self.patch_size
x += get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, device=x.device)
c = self.t_embedder(timesteps, dtype=x.dtype)
if y is not None and self.y_embedder is not None:
if self.double_y_emb:
y = self.orig_y_embedder(y)
y = self.y_embedder(y)
c = c + y

x = x + self.pos_embed_input(hint)

block_out = ()

repeat = math.ceil(self.main_model_double / len(self.transformer_blocks))
for i in range(len(self.transformer_blocks)):
out = self.transformer_blocks[i](x, c)
if not self.double_y_emb:
x = out
block_out += (self.controlnet_blocks[i](out),) * repeat

return {"output": block_out}
91 changes: 88 additions & 3 deletions comfy/controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
import comfy.cldm.mmdit
import comfy.ldm.hydit.controlnet
import comfy.ldm.flux.controlnet

import comfy.cldm.dit_embedder

def broadcast_image_to(tensor, target_batch_size, batched_number):
current_batch_size = tensor.shape[0]
Expand Down Expand Up @@ -78,6 +78,7 @@ def __init__(self):
self.concat_mask = False
self.extra_concat_orig = []
self.extra_concat = None
self.preprocess_image = lambda a: a

def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
self.cond_hint_original = cond_hint
Expand Down Expand Up @@ -129,6 +130,7 @@ def copy_to(self, c):
c.strength_type = self.strength_type
c.concat_mask = self.concat_mask
c.extra_concat_orig = self.extra_concat_orig.copy()
c.preprocess_image = self.preprocess_image

def inference_memory_requirements(self, dtype):
if self.previous_controlnet is not None:
Expand Down Expand Up @@ -181,7 +183,7 @@ def set_extra_arg(self, argument, value=None):


class ControlNet(ControlBase):
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT, concat_mask=False):
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT, concat_mask=False, preprocess_image=lambda a: a):
super().__init__()
self.control_model = control_model
self.load_device = load_device
Expand All @@ -196,6 +198,7 @@ def __init__(self, control_model=None, global_average_pooling=False, compression
self.extra_conds += extra_conds
self.strength_type = strength_type
self.concat_mask = concat_mask
self.preprocess_image = preprocess_image

def get_control(self, x_noisy, t, cond, batched_number):
control_prev = None
Expand Down Expand Up @@ -224,6 +227,7 @@ def get_control(self, x_noisy, t, cond, batched_number):
if self.latent_format is not None:
raise ValueError("This Controlnet needs a VAE but none was provided, please use a ControlNetApply node with a VAE input and connect it.")
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
self.cond_hint = self.preprocess_image(self.cond_hint)
if self.vae is not None:
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
self.cond_hint = self.vae.encode(self.cond_hint.movedim(1, -1))
Expand Down Expand Up @@ -427,6 +431,7 @@ def controlnet_load_state_dict(control_model, sd):
logging.debug("unexpected controlnet keys: {}".format(unexpected))
return control_model


def load_controlnet_mmdit(sd, model_options={}):
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd, model_options=model_options)
Expand All @@ -448,6 +453,83 @@ def load_controlnet_mmdit(sd, model_options={}):
return control


class ControlNetSD35(ControlNet):
def pre_run(self, model, percent_to_timestep_function):
if self.control_model.double_y_emb:
missing, unexpected = self.control_model.orig_y_embedder.load_state_dict(model.diffusion_model.y_embedder.state_dict(), strict=False)
else:
missing, unexpected = self.control_model.x_embedder.load_state_dict(model.diffusion_model.x_embedder.state_dict(), strict=False)
super().pre_run(model, percent_to_timestep_function)

def copy(self):
c = ControlNetSD35(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
c.control_model = self.control_model
c.control_model_wrapped = self.control_model_wrapped
self.copy_to(c)
return c

def load_controlnet_sd35(sd, model_options={}):
control_type = -1
if "control_type" in sd:
control_type = round(sd.pop("control_type").item())

# blur_cnet = control_type == 0
canny_cnet = control_type == 1
depth_cnet = control_type == 2

print(control_type, canny_cnet, depth_cnet)
new_sd = {}
for k in comfy.utils.MMDIT_MAP_BASIC:
if k[1] in sd:
new_sd[k[0]] = sd.pop(k[1])
for k in sd:
new_sd[k] = sd[k]
sd = new_sd

y_emb_shape = sd["y_embedder.mlp.0.weight"].shape
depth = y_emb_shape[0] // 64
hidden_size = 64 * depth
num_heads = depth
head_dim = hidden_size // num_heads
num_blocks = comfy.model_detection.count_blocks(new_sd, 'transformer_blocks.{}.')

load_device = comfy.model_management.get_torch_device()
offload_device = comfy.model_management.unet_offload_device()
unet_dtype = comfy.model_management.unet_dtype(model_params=-1)

manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)

operations = model_options.get("custom_operations", None)
if operations is None:
operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype, disable_fast_fp8=True)

control_model = comfy.cldm.dit_embedder.ControlNetEmbedder(img_size=None,
patch_size=2,
in_chans=16,
num_layers=num_blocks,
main_model_double=depth,
double_y_emb=y_emb_shape[0] == y_emb_shape[1],
attention_head_dim=head_dim,
num_attention_heads=num_heads,
adm_in_channels=2048,
device=offload_device,
dtype=unet_dtype,
operations=operations)

control_model = controlnet_load_state_dict(control_model, sd)

latent_format = comfy.latent_formats.SD3()
preprocess_image = lambda a: a
if canny_cnet:
preprocess_image = lambda a: (a * 255 * 0.5 + 0.5)
elif depth_cnet:
preprocess_image = lambda a: 1.0 - a

control = ControlNetSD35(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, preprocess_image=preprocess_image)
return control



def load_controlnet_hunyuandit(controlnet_data, model_options={}):
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(controlnet_data, model_options=model_options)

Expand Down Expand Up @@ -560,7 +642,10 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data:
return load_controlnet_flux_xlabs_mistoline(controlnet_data, model_options=model_options)
elif "pos_embed_input.proj.weight" in controlnet_data:
return load_controlnet_mmdit(controlnet_data, model_options=model_options) #SD3 diffusers controlnet
if "transformer_blocks.0.adaLN_modulation.1.bias" in controlnet_data:
return load_controlnet_sd35(controlnet_data, model_options=model_options) #Stability sd3.5 format
else:
return load_controlnet_mmdit(controlnet_data, model_options=model_options) #SD3 diffusers controlnet
elif "controlnet_x_embedder.weight" in controlnet_data:
return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux
Expand Down

0 comments on commit 4c82741

Please sign in to comment.