From ad66f7c7d8c3dc2985f2fba4e40b503cb45be03a Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 19 Sep 2024 05:01:00 -0400 Subject: [PATCH] Add model_options to load_controlnet function. --- comfy/controlnet.py | 56 +++++++++++++++++++++++++-------------------- comfy/sd.py | 2 +- 2 files changed, 32 insertions(+), 26 deletions(-) diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 1ea00eccf72..4bd075e9d0d 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -335,7 +335,7 @@ def forward(self, input): class ControlLora(ControlNet): - def __init__(self, control_weights, global_average_pooling=False, device=None): + def __init__(self, control_weights, global_average_pooling=False, device=None, model_options={}): #TODO? model_options ControlBase.__init__(self, device) self.control_weights = control_weights self.global_average_pooling = global_average_pooling @@ -392,19 +392,22 @@ def get_models(self): def inference_memory_requirements(self, dtype): return comfy.utils.calculate_parameters(self.control_weights) * comfy.model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype) -def controlnet_config(sd): +def controlnet_config(sd, model_options={}): model_config = comfy.model_detection.model_config_from_unet(sd, "", True) supported_inference_dtypes = model_config.supported_inference_dtypes controlnet_config = model_config.unet_config - unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes) + unet_dtype = model_options.get("dtype", comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)) load_device = comfy.model_management.get_torch_device() manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device) - if manual_cast_dtype is not None: - operations = comfy.ops.manual_cast - else: - operations = comfy.ops.disable_weight_init + + operations = model_options.get("custom_operations", None) + if operations is None: + if manual_cast_dtype is not None: + operations = comfy.ops.manual_cast + else: + operations = comfy.ops.disable_weight_init offload_device = comfy.model_management.unet_offload_device() return model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device @@ -419,9 +422,9 @@ def controlnet_load_state_dict(control_model, sd): logging.debug("unexpected controlnet keys: {}".format(unexpected)) return control_model -def load_controlnet_mmdit(sd): +def load_controlnet_mmdit(sd, model_options={}): new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "") - model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd) + model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd, model_options=model_options) num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.') for k in sd: new_sd[k] = sd[k] @@ -440,8 +443,8 @@ def load_controlnet_mmdit(sd): return control -def load_controlnet_hunyuandit(controlnet_data): - model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(controlnet_data) +def load_controlnet_hunyuandit(controlnet_data, model_options={}): + model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(controlnet_data, model_options=model_options) control_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(operations=operations, device=offload_device, dtype=unet_dtype) control_model = controlnet_load_state_dict(control_model, controlnet_data) @@ -451,17 +454,17 @@ def load_controlnet_hunyuandit(controlnet_data): control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds, strength_type=StrengthType.CONSTANT) return control -def load_controlnet_flux_xlabs_mistoline(sd, mistoline=False): - model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd) +def load_controlnet_flux_xlabs_mistoline(sd, mistoline=False, model_options={}): + model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options) control_model = comfy.ldm.flux.controlnet.ControlNetFlux(mistoline=mistoline, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config) control_model = controlnet_load_state_dict(control_model, sd) extra_conds = ['y', 'guidance'] control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds) return control -def load_controlnet_flux_instantx(sd): +def load_controlnet_flux_instantx(sd, model_options={}): new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "") - model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd) + model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd, model_options=model_options) for k in sd: new_sd[k] = sd[k] @@ -487,13 +490,13 @@ def convert_mistoline(sd): return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."}) -def load_controlnet(ckpt_path, model=None): +def load_controlnet(ckpt_path, model=None, model_options={}): controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True) if 'after_proj_list.18.bias' in controlnet_data.keys(): #Hunyuan DiT - return load_controlnet_hunyuandit(controlnet_data) + return load_controlnet_hunyuandit(controlnet_data, model_options=model_options) if "lora_controlnet" in controlnet_data: - return ControlLora(controlnet_data) + return ControlLora(controlnet_data, model_options=model_options) controlnet_config = None supported_inference_dtypes = None @@ -550,13 +553,13 @@ def load_controlnet(ckpt_path, model=None): controlnet_data = new_sd elif "controlnet_blocks.0.weight" in controlnet_data: if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data: - return load_controlnet_flux_xlabs_mistoline(controlnet_data) + return load_controlnet_flux_xlabs_mistoline(controlnet_data, model_options=model_options) elif "pos_embed_input.proj.weight" in controlnet_data: - return load_controlnet_mmdit(controlnet_data) #SD3 diffusers controlnet + return load_controlnet_mmdit(controlnet_data, model_options=model_options) #SD3 diffusers controlnet elif "controlnet_x_embedder.weight" in controlnet_data: - return load_controlnet_flux_instantx(controlnet_data) + return load_controlnet_flux_instantx(controlnet_data, model_options=model_options) elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux - return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True) + return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options) pth_key = 'control_model.zero_convs.0.0.weight' pth = False @@ -568,7 +571,7 @@ def load_controlnet(ckpt_path, model=None): elif key in controlnet_data: prefix = "" else: - net = load_t2i_adapter(controlnet_data) + net = load_t2i_adapter(controlnet_data, model_options=model_options) if net is None: logging.error("error checkpoint does not contain controlnet or t2i adapter data {}".format(ckpt_path)) return net @@ -587,7 +590,10 @@ def load_controlnet(ckpt_path, model=None): manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device) if manual_cast_dtype is not None: controlnet_config["operations"] = comfy.ops.manual_cast - controlnet_config["dtype"] = unet_dtype + if "custom_operations" in model_options: + controlnet_config["operations"] = model_options["custom_operations"] + if "dtype" in model_options: + controlnet_config["dtype"] = model_options["dtype"] controlnet_config["device"] = comfy.model_management.unet_offload_device() controlnet_config.pop("out_channels") controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1] @@ -685,7 +691,7 @@ def copy(self): self.copy_to(c) return c -def load_t2i_adapter(t2i_data): +def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options compression_ratio = 8 upscale_algorithm = 'nearest-exact' diff --git a/comfy/sd.py b/comfy/sd.py index 8c5b058ceae..99859d24175 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -645,7 +645,7 @@ def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffuse manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) model_config.set_inference_dtype(unet_dtype, manual_cast_dtype) - model_config.custom_operations = model_options.get("custom_operations", None) + model_config.custom_operations = model_options.get("custom_operations", model_config.custom_operations) model = model_config.get_model(new_sd, "") model = model.to(offload_device) model.load_model_weights(new_sd, "")