Skip to content

Commit

Permalink
Support SD3 diffusers lora.
Browse files Browse the repository at this point in the history
  • Loading branch information
comfyanonymous committed Jun 13, 2024
1 parent 37a08a4 commit ac151ac
Showing 1 changed file with 21 additions and 1 deletion.
22 changes: 21 additions & 1 deletion comfy/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def load_lora(lora, to_load):

regular_lora = "{}.lora_up.weight".format(x)
diffusers_lora = "{}_lora.up.weight".format(x)
diffusers2_lora = "{}.lora_B.weight".format(x)
transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
A_name = None

Expand All @@ -40,6 +41,10 @@ def load_lora(lora, to_load):
A_name = diffusers_lora
B_name = "{}_lora.down.weight".format(x)
mid_name = None
elif diffusers2_lora in lora.keys():
A_name = diffusers2_lora
B_name = "{}.lora_A.weight".format(x)
mid_name = None
elif transformers_lora in lora.keys():
A_name = transformers_lora
B_name ="{}.lora_linear_layer.down.weight".format(x)
Expand Down Expand Up @@ -164,6 +169,7 @@ def load_lora(lora, to_load):
for x in lora.keys():
if x not in loaded_keys:
logging.warning("lora key not loaded: {}".format(x))

return patch_dict

def model_lora_keys_clip(model, key_map={}):
Expand Down Expand Up @@ -217,7 +223,8 @@ def model_lora_keys_clip(model, key_map={}):
return key_map

def model_lora_keys_unet(model, key_map={}):
sdk = model.state_dict().keys()
sd = model.state_dict()
sdk = sd.keys()

for k in sdk:
if k.startswith("diffusion_model.") and k.endswith(".weight"):
Expand All @@ -238,4 +245,17 @@ def model_lora_keys_unet(model, key_map={}):
if diffusers_lora_key.endswith(".to_out.0"):
diffusers_lora_key = diffusers_lora_key[:-2]
key_map[diffusers_lora_key] = unet_key

if isinstance(model, comfy.model_base.SD3): #Diffusers lora SD3
for i in range(model.model_config.unet_config.get("depth", 0)):
k = "transformer.transformer_blocks.{}.attn.".format(i)
qkv = "diffusion_model.joint_blocks.{}.x_block.attn.qkv.weight".format(i)
proj = "diffusion_model.joint_blocks.{}.x_block.attn.proj.weight".format(i)
if qkv in sd:
offset = sd[qkv].shape[0] // 3
key_map["{}to_q".format(k)] = (qkv, (0, 0, offset))
key_map["{}to_k".format(k)] = (qkv, (0, offset, offset))
key_map["{}to_v".format(k)] = (qkv, (0, offset * 2, offset))
key_map["{}to_out.0".format(k)] = proj

return key_map

0 comments on commit ac151ac

Please sign in to comment.