[LoRA serialization] fix: duplicate unet prefix problem. (huggingface#5991)

* fix: duplicate unet prefix problem.

* Update src/diffusers/loaders/lora.py

Co-authored-by: Patrick von Platen <[email protected]>

---------

Co-authored-by: Patrick von Platen <[email protected]>
sayakpaul and patrickvonplaten authored Dec 2, 2023
1 parent 3351270 commit d486f0e
Showing 2 changed files with 18 additions and 21 deletions.
37 changes: 17 additions & 20 deletions src/diffusers/loaders/lora.py
@@ -391,6 +391,10 @@ def load_lora_into_unet(
         # their prefixes.
         keys = list(state_dict.keys())
 
+        if all(key.startswith("unet.unet") for key in keys):
+            deprecation_message = "Keys starting with 'unet.unet' are deprecated."
+            deprecate("unet.unet keys", "0.27", deprecation_message)
+
         if all(key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in keys):
             # Load the layers corresponding to UNet.
             logger.info(f"Loading {cls.unet_name}.")
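
For context on the check added above, here is a minimal sketch (not part of the diff) of how a duplicated "unet.unet" prefix can arise: if a state dict whose keys already carry the `unet.` prefix is prefixed again at save time, every key ends up starting with `unet.unet`, which is exactly what the new deprecation branch detects. The key names below are hypothetical.

# Illustrative only: duplicated prefixing and the detection used above.
already_prefixed = {
    "unet.mid_block.attentions.0.to_q.lora.down.weight": None,  # hypothetical key
    "unet.mid_block.attentions.0.to_q.lora.up.weight": None,    # hypothetical key
}

# Prefixing an already-prefixed dict duplicates the prefix.
double_prefixed = {f"unet.{k}": v for k, v in already_prefixed.items()}

keys = list(double_prefixed.keys())
print(all(key.startswith("unet.unet") for key in keys))  # True -> deprecation path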
@@ -407,8 +411,9 @@ def load_lora_into_unet(
         else:
             # Otherwise, we're dealing with the old format. This means the `state_dict` should only
             # contain the module names of the `unet` as its keys WITHOUT any prefix.
-            warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
-            logger.warn(warn_message)
+            if not USE_PEFT_BACKEND:
+                warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
+                logger.warn(warn_message)
 
         if USE_PEFT_BACKEND and len(state_dict.keys()) > 0:
             from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
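
The warning kept above spells out how to migrate an old-format (prefix-less) LoRA state dict; a minimal sketch of that conversion, with made-up module names:

# Convert an old-format UNet LoRA state dict (keys without any prefix) to the new
# prefixed format, following the recipe quoted in the warning message above.
old_state_dict = {
    "down_blocks.0.attentions.0.to_k.lora.down.weight": None,  # hypothetical module name
    "down_blocks.0.attentions.0.to_k.lora.up.weight": None,    # hypothetical module name
}

new_state_dict = {f"unet.{module_name}": params for module_name, params in old_state_dict.items()}
print(sorted(new_state_dict))  # every key now starts with "unet."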
@@ -800,29 +805,21 @@ def save_lora_weights(
             safe_serialization (`bool`, *optional*, defaults to `True`):
                 Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
         """
-        # Create a flat dictionary.
         state_dict = {}
 
-        # Populate the dictionary.
-        if unet_lora_layers is not None:
-            weights = (
-                unet_lora_layers.state_dict() if isinstance(unet_lora_layers, torch.nn.Module) else unet_lora_layers
-            )
+        def pack_weights(layers, prefix):
+            layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
+            layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
+            return layers_state_dict
 
-            unet_lora_state_dict = {f"{cls.unet_name}.{module_name}": param for module_name, param in weights.items()}
-            state_dict.update(unet_lora_state_dict)
+        if not (unet_lora_layers or text_encoder_lora_layers):
+            raise ValueError("You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`.")
 
-        if text_encoder_lora_layers is not None:
-            weights = (
-                text_encoder_lora_layers.state_dict()
-                if isinstance(text_encoder_lora_layers, torch.nn.Module)
-                else text_encoder_lora_layers
-            )
+        if unet_lora_layers:
+            state_dict.update(pack_weights(unet_lora_layers, "unet"))
 
-            text_encoder_lora_state_dict = {
-                f"{cls.text_encoder_name}.{module_name}": param for module_name, param in weights.items()
-            }
-            state_dict.update(text_encoder_lora_state_dict)
+        if text_encoder_lora_layers:
+            state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
 
         # Save the model
         cls.write_lora_layers(
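
The new `pack_weights` helper is now the single place where the `unet.` / `text_encoder.` prefixes are applied. A small usage sketch (the layer dicts and their keys are hypothetical; `pack_weights` is adapted from the diff):

import torch


def pack_weights(layers, prefix):
    # Accept either an nn.Module or a plain dict and prefix every key exactly once.
    layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
    return {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}


# Hypothetical LoRA weights keyed by module name, without any prefix yet.
unet_lora_layers = {"to_q.lora.down.weight": torch.zeros(4, 8)}
text_encoder_lora_layers = {"q_proj.lora.up.weight": torch.zeros(8, 4)}

state_dict = {}
state_dict.update(pack_weights(unet_lora_layers, "unet"))
state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
print(sorted(state_dict))
# ['text_encoder.q_proj.lora.up.weight', 'unet.to_q.lora.down.weight']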
2 changes: 1 addition & 1 deletion src/diffusers/training_utils.py
@@ -67,7 +67,7 @@ def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]:
                 current_lora_layer_sd = lora_layer.state_dict()
                 for lora_layer_matrix_name, lora_param in current_lora_layer_sd.items():
                     # The matrix name can either be "down" or "up".
-                    lora_state_dict[f"unet.{name}.lora.{lora_layer_matrix_name}"] = lora_param
+                    lora_state_dict[f"{name}.lora.{lora_layer_matrix_name}"] = lora_param
 
     return lora_state_dict
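
Why dropping the `unet.` prefix in `unet_lora_state_dict` matters: `save_lora_weights` now adds that prefix itself via `pack_weights`, so keeping it here would produce keys starting with `unet.unet.` in the saved file. A small before/after sketch with a hypothetical module name:

# Illustrative only: how the duplicated prefix showed up in saved checkpoints.
module_name = "to_k.lora.down.weight"  # hypothetical

old_helper_key = f"unet.{module_name}"   # old unet_lora_state_dict output
print(f"unet.{old_helper_key}")          # unet.unet.to_k.lora.down.weight (duplicated)

new_helper_key = module_name             # new unet_lora_state_dict output
print(f"unet.{new_helper_key}")          # unet.to_k.lora.down.weight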

