diff --git a/comfy/sd.py b/comfy/sd.py index 4ba00a62ce5..e343e1fa8eb 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -563,24 +563,32 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o return (model_patcher, clip, vae, clipvision) -def load_unet_state_dict(sd): #load unet in diffusers format +def load_unet_state_dict(sd): #load unet in diffusers or regular format + + #Allow loading unets from checkpoint files + checkpoint = False + diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd) + temp_sd = comfy.utils.state_dict_prefix_replace(sd, {diffusion_model_prefix: ""}, filter_keys=True) + if len(temp_sd) > 0: + sd = temp_sd + checkpoint = True + parameters = comfy.utils.calculate_parameters(sd) unet_dtype = model_management.unet_dtype(model_params=parameters) load_device = model_management.get_torch_device() - if 'transformer_blocks.0.attn.add_q_proj.weight' in sd: #MMDIT SD3 + if checkpoint or "input_blocks.0.0.weight" in sd or 'clf.1.weight' in sd: #ldm or stable cascade + model_config = model_detection.model_config_from_unet(sd, "") + if model_config is None: + return None + new_sd = sd + elif 'transformer_blocks.0.attn.add_q_proj.weight' in sd: #MMDIT SD3 new_sd = model_detection.convert_diffusers_mmdit(sd, "") if new_sd is None: return None model_config = model_detection.model_config_from_unet(new_sd, "") if model_config is None: return None - elif "input_blocks.0.0.weight" in sd or 'clf.1.weight' in sd: #ldm or stable cascade - model_config = model_detection.model_config_from_unet(sd, "") - if model_config is None: - return None - new_sd = sd - else: #diffusers model_config = model_detection.model_config_from_diffusers_unet(sd) if model_config is None: