[LoRA serialization] fix: duplicate unet prefix problem. #5991
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
The changes to src/diffusers/loaders/lora.py, shown as a unified diff (`-` removed, `+` added):

+        def pack_weights(layers, prefix):
+            layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
+            layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
+            return layers_state_dict
+
-            unet_lora_state_dict = {f"{cls.unet_name}.{module_name}": param for module_name, param in weights.items()}
-            state_dict.update(unet_lora_state_dict)
+        if not (unet_lora_layers or text_encoder_lora_layers):
+            raise ValueError("You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`.")
+
-        if text_encoder_lora_layers is not None:
-            weights = (
-                text_encoder_lora_layers.state_dict()
-                if isinstance(text_encoder_lora_layers, torch.nn.Module)
-                else text_encoder_lora_layers
-            )
+        if unet_lora_layers:
+            state_dict.update(pack_weights(unet_lora_layers, "unet"))
+
-            text_encoder_lora_state_dict = {
-                f"{cls.text_encoder_name}.{module_name}": param for module_name, param in weights.items()
-            }
-            state_dict.update(text_encoder_lora_state_dict)
+        if text_encoder_lora_layers:
All these changes are to keep things clean and unified with respect to how it's done for the SDXL LoRA writer.
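To make the helper's effect concrete, here is a minimal standalone sketch (illustration only; the toy module name is hypothetical, not taken from a real checkpoint). It shows how a raw LoRA state dict, whether a torch.nn.Module or a plain dict, gets exactly one prefix prepended to every key:

import torch

# Standalone copy of the pack_weights helper from the diff above, for illustration only.
def pack_weights(layers, prefix):
    layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
    return {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}

# Hypothetical un-prefixed LoRA key, the form the fixed helper now emits.
toy_unet_lora = {"down_blocks.0.attentions.0.to_q.lora.down.weight": torch.zeros(4, 8)}

packed = pack_weights(toy_unet_lora, "unet")
print(list(packed))
# ['unet.down_blocks.0.attentions.0.to_q.lora.down.weight']  -> a single "unet." prefix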
@@ -67,7 +67,7 @@ def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]:
     current_lora_layer_sd = lora_layer.state_dict()
     for lora_layer_matrix_name, lora_param in current_lora_layer_sd.items():
         # The matrix name can either be "down" or "up".
-        lora_state_dict[f"unet.{name}.lora.{lora_layer_matrix_name}"] = lora_param
+        lora_state_dict[f"{name}.lora.{lora_layer_matrix_name}"] = lora_param
This was the main culprit!
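In other words, the training-side helper was already prefixing every key with "unet.", and the serialization path then prefixed the whole dict with "unet." again, so the saved file carried unet.unet.* keys that the loader does not recognize. A minimal sketch of the mismatch (hypothetical key names, not library code):

# Before the fix: the helper output already carries "unet.", and the writer adds it once more.
helper_keys_old = {"unet.down_blocks.0.attentions.0.to_q.lora.down.weight": 0}
serialized_old = {f"unet.{k}": v for k, v in helper_keys_old.items()}
print(list(serialized_old))
# ['unet.unet.down_blocks.0.attentions.0.to_q.lora.down.weight']  -> duplicated prefix

# After the fix: the helper emits bare module names and only the writer adds "unet.".
helper_keys_new = {"down_blocks.0.attentions.0.to_q.lora.down.weight": 0}
serialized_new = {f"unet.{k}": v for k, v in helper_keys_new.items()}
print(list(serialized_new))
# ['unet.down_blocks.0.attentions.0.to_q.lora.down.weight']  -> single prefix, loadable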
Do you think we could merge, @patrickvonplaten?
Thanks a mille for fixing this, @sayakpaul!
Happy to merge once deprecation is updated to 0.27
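For context, version bumps like this typically go through diffusers' deprecate utility, which takes the removal version as its second argument. A hypothetical sketch of pointing a deprecation at 0.27 (the argument name below is made up; only the pattern matters):

from diffusers.utils import deprecate

# Hypothetical deprecated keyword: warn until diffusers 0.27, at which point it is removed.
def save_lora_weights_compat(old_style_arg=None, **kwargs):
    if old_style_arg is not None:
        deprecate(
            "old_style_arg",
            "0.27.0",
            "Passing `old_style_arg` is deprecated and will be removed in diffusers 0.27.",
        )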
Co-authored-by: Patrick von Platen <[email protected]>
…#5991)
* fix: duplicate unet prefix problem.
* Update src/diffusers/loaders/lora.py
Co-authored-by: Patrick von Platen <[email protected]>
---------
Co-authored-by: Patrick von Platen <[email protected]>
What does this PR do?
Fixes: #5977
Trained a LoRA: https://huggingface.co/sayakpaul/corgy_dog_LoRA/tree/main (Colab: https://colab.research.google.com/gist/sayakpaul/bcd42382f111d5c0e3bf8b37b55342ff/sdxl_dreambooth_lora_.ipynb)
Doing the above yields:
Cc: @apolinario
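For reviewers who want to reproduce the check locally, a minimal verification sketch (assumed usage of public diffusers loader APIs, not part of this PR; the repo id is the checkpoint linked above):

from diffusers import StableDiffusionXLPipeline

# Fetch only the serialized LoRA state dict from the Hub checkpoint linked above
# and confirm every UNet key carries a single "unet." prefix.
state_dict, _ = StableDiffusionXLPipeline.lora_state_dict("sayakpaul/corgy_dog_LoRA")
unet_keys = [k for k in state_dict if k.startswith("unet.")]
assert unet_keys, "expected UNet LoRA keys"
assert not any(k.startswith("unet.unet.") for k in state_dict), "duplicated prefix found"
print(f"{len(unet_keys)} UNet LoRA keys, each with a single prefix")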