Skip to content

Commit

Permalink
Add model_options to load_controlnet function.
Browse files Browse the repository at this point in the history
  • Loading branch information
comfyanonymous committed Sep 19, 2024
1 parent de8e8e3 commit ad66f7c
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 26 deletions.
56 changes: 31 additions & 25 deletions comfy/controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ def forward(self, input):


class ControlLora(ControlNet):
def __init__(self, control_weights, global_average_pooling=False, device=None):
def __init__(self, control_weights, global_average_pooling=False, device=None, model_options={}): #TODO? model_options
ControlBase.__init__(self, device)
self.control_weights = control_weights
self.global_average_pooling = global_average_pooling
Expand Down Expand Up @@ -392,19 +392,22 @@ def get_models(self):
def inference_memory_requirements(self, dtype):
return comfy.utils.calculate_parameters(self.control_weights) * comfy.model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)

def controlnet_config(sd):
def controlnet_config(sd, model_options={}):
model_config = comfy.model_detection.model_config_from_unet(sd, "", True)

supported_inference_dtypes = model_config.supported_inference_dtypes

controlnet_config = model_config.unet_config
unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
unet_dtype = model_options.get("dtype", comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes))
load_device = comfy.model_management.get_torch_device()
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
if manual_cast_dtype is not None:
operations = comfy.ops.manual_cast
else:
operations = comfy.ops.disable_weight_init

operations = model_options.get("custom_operations", None)
if operations is None:
if manual_cast_dtype is not None:
operations = comfy.ops.manual_cast
else:
operations = comfy.ops.disable_weight_init

offload_device = comfy.model_management.unet_offload_device()
return model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device
Expand All @@ -419,9 +422,9 @@ def controlnet_load_state_dict(control_model, sd):
logging.debug("unexpected controlnet keys: {}".format(unexpected))
return control_model

def load_controlnet_mmdit(sd):
def load_controlnet_mmdit(sd, model_options={}):
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd)
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd, model_options=model_options)
num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
for k in sd:
new_sd[k] = sd[k]
Expand All @@ -440,8 +443,8 @@ def load_controlnet_mmdit(sd):
return control


def load_controlnet_hunyuandit(controlnet_data):
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(controlnet_data)
def load_controlnet_hunyuandit(controlnet_data, model_options={}):
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(controlnet_data, model_options=model_options)

control_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(operations=operations, device=offload_device, dtype=unet_dtype)
control_model = controlnet_load_state_dict(control_model, controlnet_data)
Expand All @@ -451,17 +454,17 @@ def load_controlnet_hunyuandit(controlnet_data):
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds, strength_type=StrengthType.CONSTANT)
return control

def load_controlnet_flux_xlabs_mistoline(sd, mistoline=False):
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd)
def load_controlnet_flux_xlabs_mistoline(sd, mistoline=False, model_options={}):
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options)
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(mistoline=mistoline, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
control_model = controlnet_load_state_dict(control_model, sd)
extra_conds = ['y', 'guidance']
control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
return control

def load_controlnet_flux_instantx(sd):
def load_controlnet_flux_instantx(sd, model_options={}):
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd)
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd, model_options=model_options)
for k in sd:
new_sd[k] = sd[k]

Expand All @@ -487,13 +490,13 @@ def convert_mistoline(sd):
return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})


def load_controlnet(ckpt_path, model=None):
def load_controlnet(ckpt_path, model=None, model_options={}):
controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
if 'after_proj_list.18.bias' in controlnet_data.keys(): #Hunyuan DiT
return load_controlnet_hunyuandit(controlnet_data)
return load_controlnet_hunyuandit(controlnet_data, model_options=model_options)

if "lora_controlnet" in controlnet_data:
return ControlLora(controlnet_data)
return ControlLora(controlnet_data, model_options=model_options)

controlnet_config = None
supported_inference_dtypes = None
Expand Down Expand Up @@ -550,13 +553,13 @@ def load_controlnet(ckpt_path, model=None):
controlnet_data = new_sd
elif "controlnet_blocks.0.weight" in controlnet_data:
if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data:
return load_controlnet_flux_xlabs_mistoline(controlnet_data)
return load_controlnet_flux_xlabs_mistoline(controlnet_data, model_options=model_options)
elif "pos_embed_input.proj.weight" in controlnet_data:
return load_controlnet_mmdit(controlnet_data) #SD3 diffusers controlnet
return load_controlnet_mmdit(controlnet_data, model_options=model_options) #SD3 diffusers controlnet
elif "controlnet_x_embedder.weight" in controlnet_data:
return load_controlnet_flux_instantx(controlnet_data)
return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux
return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True)
return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options)

pth_key = 'control_model.zero_convs.0.0.weight'
pth = False
Expand All @@ -568,7 +571,7 @@ def load_controlnet(ckpt_path, model=None):
elif key in controlnet_data:
prefix = ""
else:
net = load_t2i_adapter(controlnet_data)
net = load_t2i_adapter(controlnet_data, model_options=model_options)
if net is None:
logging.error("error checkpoint does not contain controlnet or t2i adapter data {}".format(ckpt_path))
return net
Expand All @@ -587,7 +590,10 @@ def load_controlnet(ckpt_path, model=None):
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
if manual_cast_dtype is not None:
controlnet_config["operations"] = comfy.ops.manual_cast
controlnet_config["dtype"] = unet_dtype
if "custom_operations" in model_options:
controlnet_config["operations"] = model_options["custom_operations"]
if "dtype" in model_options:
controlnet_config["dtype"] = model_options["dtype"]
controlnet_config["device"] = comfy.model_management.unet_offload_device()
controlnet_config.pop("out_channels")
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
Expand Down Expand Up @@ -685,7 +691,7 @@ def copy(self):
self.copy_to(c)
return c

def load_t2i_adapter(t2i_data):
def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options
compression_ratio = 8
upscale_algorithm = 'nearest-exact'

Expand Down
2 changes: 1 addition & 1 deletion comfy/sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,7 +645,7 @@ def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffuse

manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
model_config.custom_operations = model_options.get("custom_operations", None)
model_config.custom_operations = model_options.get("custom_operations", model_config.custom_operations)
model = model_config.get_model(new_sd, "")
model = model.to(offload_device)
model.load_model_weights(new_sd, "")
Expand Down

0 comments on commit ad66f7c

Please sign in to comment.