
[LoRA serialization] fix: duplicate unet prefix problem. #5991

Merged
3 commits merged into main from fix/lora-serialization-format on Dec 2, 2023

Conversation

@sayakpaul (Member) commented on Nov 30, 2023

What does this PR do?

Fixes: #5977

Trained a LoRA: https://huggingface.co/sayakpaul/corgy_dog_LoRA/tree/main (Colab: https://colab.research.google.com/gist/sayakpaul/bcd42382f111d5c0e3bf8b37b55342ff/sdxl_dreambooth_lora_.ipynb)

from huggingface_hub import hf_hub_download
import safetensors.torch

repo_id = "sayakpaul/corgy_dog_LoRA"
filename = "pytorch_lora_weights.safetensors"

# Download the serialized LoRA weights and inspect the first few key names.
path = hf_hub_download(repo_id=repo_id, filename=filename)
sd = safetensors.torch.load_file(path)
print(list(sd.keys())[:5])

Doing the above yields:

['unet.down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_k.lora.down.weight',
 'unet.down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_k.lora.up.weight',
 'unet.down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.lora.down.weight',
 'unet.down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.lora.up.weight',
 'unet.down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_q.lora.down.weight']
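
As a quick sanity check (a sketch not included in the PR; the base checkpoint below is an assumption based on the DreamBooth SDXL Colab linked above), the weights with the single "unet." prefix load back through the standard loader:

import torch
from diffusers import StableDiffusionXLPipeline

# Assumed base checkpoint; swap in whatever the LoRA was trained against.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)
# With the prefix applied exactly once, the loader resolves the keys as expected.
pipe.load_lora_weights("sayakpaul/corgy_dog_LoRA", weight_name="pytorch_lora_weights.safetensors")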

Cc: @apolinario

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

Comment on lines +812 to +823
+        def pack_weights(layers, prefix):
+            layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
+            layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
+            return layers_state_dict
-            unet_lora_state_dict = {f"{cls.unet_name}.{module_name}": param for module_name, param in weights.items()}
-            state_dict.update(unet_lora_state_dict)
+        if not (unet_lora_layers or text_encoder_lora_layers):
+            raise ValueError("You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`.")
-        if text_encoder_lora_layers is not None:
-            weights = (
-                text_encoder_lora_layers.state_dict()
-                if isinstance(text_encoder_lora_layers, torch.nn.Module)
-                else text_encoder_lora_layers
-            )
+        if unet_lora_layers:
+            state_dict.update(pack_weights(unet_lora_layers, "unet"))
-            text_encoder_lora_state_dict = {
-                f"{cls.text_encoder_name}.{module_name}": param for module_name, param in weights.items()
-            }
-            state_dict.update(text_encoder_lora_state_dict)
+        if text_encoder_lora_layers:
@sayakpaul (Member, Author) commented on Nov 30, 2023
All these changes are to keep it clean and unified with how it's done for the SDXL LoRA writer.
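
For readers skimming the diff, here is a minimal standalone sketch (not the library code itself; `build_lora_state_dict` is a made-up name) of what the unified writer effectively does, namespacing each component's weights exactly once:

import torch

def pack_weights(layers, prefix):
    # `layers` may be an nn.Module or an already-flattened state dict.
    layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
    return {f"{prefix}.{name}": param for name, param in layers_weights.items()}

def build_lora_state_dict(unet_lora_layers=None, text_encoder_lora_layers=None):
    if not (unet_lora_layers or text_encoder_lora_layers):
        raise ValueError("You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`.")
    state_dict = {}
    if unet_lora_layers:
        state_dict.update(pack_weights(unet_lora_layers, "unet"))
    if text_encoder_lora_layers:
        state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
    return state_dict

# Keys that already look like LoRA layer names get prefixed exactly once:
demo = {"down_blocks.1.attn1.to_k.lora.down.weight": torch.zeros(4, 8)}
print(list(build_lora_state_dict(unet_lora_layers=demo).keys()))
# ['unet.down_blocks.1.attn1.to_k.lora.down.weight']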

@@ -67,7 +67,7 @@ def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]:
                 current_lora_layer_sd = lora_layer.state_dict()
                 for lora_layer_matrix_name, lora_param in current_lora_layer_sd.items():
                     # The matrix name can either be "down" or "up".
-                    lora_state_dict[f"unet.{name}.lora.{lora_layer_matrix_name}"] = lora_param
+                    lora_state_dict[f"{name}.lora.{lora_layer_matrix_name}"] = lora_param
@sayakpaul (Member, Author)

This was the main culprit!
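
In other words (a toy illustration, not library code; the key below is just an example name): the training-side utility already added the `unet.` prefix, and the serializer added it again on save, producing the `unet.unet.` keys reported in #5977.

layer_key = "down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_k.lora.down.weight"

old_training_side = f"unet.{layer_key}"   # old utility prefixed here...
old_saved = f"unet.{old_training_side}"   # ...and the writer prefixed again on save
print(old_saved)                          # unet.unet.down_blocks...  <- the bug in #5977

new_saved = f"unet.{layer_key}"           # after this PR: prefixed exactly once, on save
print(new_saved)                          # unet.down_blocks...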

@apolinario (Collaborator)

Do you think we could merge this, @patrickvonplaten?

@patrickvonplaten (Contributor)

Thanks a mille for fixing this, @sayakpaul!

@patrickvonplaten (Contributor) left a review comment

Happy to merge once the deprecation is updated to 0.27.

sayakpaul merged commit d486f0e into main on Dec 2, 2023. 22 checks passed.
sayakpaul deleted the fix/lora-serialization-format branch on December 2, 2023 at 16:05.
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
…#5991)

* fix: duplicate unet prefix problem.

* Update src/diffusers/loaders/lora.py

Co-authored-by: Patrick von Platen <[email protected]>

---------

Co-authored-by: Patrick von Platen <[email protected]>
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024
…#5991)

* fix: duplicate unet prefix problem.

* Update src/diffusers/loaders/lora.py

Co-authored-by: Patrick von Platen <[email protected]>

---------

Co-authored-by: Patrick von Platen <[email protected]>
Successfully merging this pull request may close these issues:

Do not save unet.unet keys when training LoRAs (#5977)