Skip to content

Commit

Permalink
Fix some controlnets OOMing when loading.
Browse files Browse the repository at this point in the history
  • Loading branch information
comfyanonymous committed Aug 25, 2024
1 parent 6ab1e6f commit 9230f65
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions comfy/controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,8 @@ def controlnet_config(sd):
else:
operations = comfy.ops.disable_weight_init

return model_config, operations, load_device, unet_dtype, manual_cast_dtype
offload_device = comfy.model_management.unet_offload_device()
return model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device

def controlnet_load_state_dict(control_model, sd):
missing, unexpected = control_model.load_state_dict(sd, strict=False)
Expand All @@ -405,12 +406,12 @@ def controlnet_load_state_dict(control_model, sd):

def load_controlnet_mmdit(sd):
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(new_sd)
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd)
num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
for k in sd:
new_sd[k] = sd[k]

control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=load_device, dtype=unet_dtype, **model_config.unet_config)
control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
control_model = controlnet_load_state_dict(control_model, new_sd)

latent_format = comfy.latent_formats.SD3()
Expand All @@ -420,9 +421,9 @@ def load_controlnet_mmdit(sd):


def load_controlnet_hunyuandit(controlnet_data):
model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(controlnet_data)
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(controlnet_data)

control_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(operations=operations, device=load_device, dtype=unet_dtype)
control_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(operations=operations, device=offload_device, dtype=unet_dtype)
control_model = controlnet_load_state_dict(control_model, controlnet_data)

latent_format = comfy.latent_formats.SDXL()
Expand All @@ -431,8 +432,8 @@ def load_controlnet_hunyuandit(controlnet_data):
return control

def load_controlnet_flux_xlabs(sd):
model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(sd)
control_model = comfy.ldm.flux.controlnet_xlabs.ControlNetFlux(operations=operations, device=load_device, dtype=unet_dtype, **model_config.unet_config)
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd)
control_model = comfy.ldm.flux.controlnet_xlabs.ControlNetFlux(operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
control_model = controlnet_load_state_dict(control_model, sd)
extra_conds = ['y', 'guidance']
control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
Expand Down Expand Up @@ -536,6 +537,7 @@ def load_controlnet(ckpt_path, model=None):
if manual_cast_dtype is not None:
controlnet_config["operations"] = comfy.ops.manual_cast
controlnet_config["dtype"] = unet_dtype
controlnet_config["device"] = comfy.model_management.unet_offload_device()
controlnet_config.pop("out_channels")
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
Expand Down

0 comments on commit 9230f65

Please sign in to comment.