Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PEFT] Fix scale unscale with LoRA adapters #5417

Merged
13 commits were merged into the base branch on Oct 21, 2023.
2 changes: 1 addition & 1 deletion src/diffusers/models/unet_2d_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -1121,7 +1121,7 @@ def forward(

if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self)
unscale_lora_layers(self, scale=lora_scale)

if not return_dict:
return (sample,)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,7 @@ def encode_prompt(

if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
unscale_lora_layers(self.text_encoder, lora_scale)

return prompt_embeds, negative_prompt_embeds

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ def encode_prompt(

if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
unscale_lora_layers(self.text_encoder, lora_scale)

return prompt_embeds, negative_prompt_embeds

Expand Down
2 changes: 1 addition & 1 deletion src/diffusers/pipelines/controlnet/pipeline_controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ def encode_prompt(

if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
unscale_lora_layers(self.text_encoder, lora_scale)

return prompt_embeds, negative_prompt_embeds

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,7 @@ def encode_prompt(

if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
unscale_lora_layers(self.text_encoder, lora_scale)

return prompt_embeds, negative_prompt_embeds

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -575,7 +575,7 @@ def encode_prompt(

if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
unscale_lora_layers(self.text_encoder, lora_scale)

return prompt_embeds, negative_prompt_embeds

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -461,8 +461,8 @@ def encode_prompt(

if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
unscale_lora_layers(self.text_encoder_2)
unscale_lora_layers(self.text_encoder, lora_scale)
unscale_lora_layers(self.text_encoder_2, lora_scale)

return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -429,8 +429,8 @@ def encode_prompt(

if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
unscale_lora_layers(self.text_encoder_2)
unscale_lora_layers(self.text_encoder, lora_scale)
unscale_lora_layers(self.text_encoder_2, lora_scale)

return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -473,8 +473,8 @@ def encode_prompt(

if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
unscale_lora_layers(self.text_encoder_2)
unscale_lora_layers(self.text_encoder, lora_scale)
unscale_lora_layers(self.text_encoder_2, lora_scale)

return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,7 @@ def encode_prompt(

if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
unscale_lora_layers(self.text_encoder, lora_scale)

return prompt_embeds, negative_prompt_embeds

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ def encode_prompt(

if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
unscale_lora_layers(self.text_encoder, lora_scale)

return prompt_embeds, negative_prompt_embeds

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,7 @@ def encode_prompt(

if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
unscale_lora_layers(self.text_encoder, lora_scale)

return prompt_embeds, negative_prompt_embeds

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ def encode_prompt(

if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
unscale_lora_layers(self.text_encoder, lora_scale)

return prompt_embeds, negative_prompt_embeds

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -614,7 +614,7 @@ def encode_prompt(

if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
unscale_lora_layers(self.text_encoder, lora_scale)

return prompt_embeds, negative_prompt_embeds

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,7 @@ def encode_prompt(

if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
unscale_lora_layers(self.text_encoder, lora_scale)

return prompt_embeds, negative_prompt_embeds

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ def encode_prompt(

if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
unscale_lora_layers(self.text_encoder, lora_scale)

return prompt_embeds, negative_prompt_embeds

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ def encode_prompt(

if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
unscale_lora_layers(self.text_encoder, lora_scale)

return prompt_embeds, negative_prompt_embeds

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,7 @@ def encode_prompt(

if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
unscale_lora_layers(self.text_encoder, lora_scale)

return prompt_embeds, negative_prompt_embeds

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,7 @@ def encode_prompt(

if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
unscale_lora_layers(self.text_encoder, lora_scale)

return prompt_embeds, negative_prompt_embeds

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ def encode_prompt(

if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
unscale_lora_layers(self.text_encoder, lora_scale)

return prompt_embeds, negative_prompt_embeds

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ def encode_prompt(

if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
unscale_lora_layers(self.text_encoder, lora_scale)

return prompt_embeds, negative_prompt_embeds

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ def encode_prompt(

if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
unscale_lora_layers(self.text_encoder, lora_scale)

return prompt_embeds, negative_prompt_embeds

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ def encode_prompt(

if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
unscale_lora_layers(self.text_encoder, lora_scale)

return prompt_embeds, negative_prompt_embeds

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ def encode_prompt(

if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
unscale_lora_layers(self.text_encoder, lora_scale)

return prompt_embeds, negative_prompt_embeds

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -579,7 +579,7 @@ def encode_prompt(

if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
unscale_lora_layers(self.text_encoder, lora_scale)

return prompt_embeds, negative_prompt_embeds

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ def encode_prompt(

if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
unscale_lora_layers(self.text_encoder, lora_scale)

return prompt_embeds, negative_prompt_embeds

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ def encode_prompt(

if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
unscale_lora_layers(self.text_encoder, lora_scale)

return prompt_embeds, negative_prompt_embeds

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,7 @@ def encode_prompt(

if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
unscale_lora_layers(self.text_encoder, lora_scale)

return prompt_embeds, negative_prompt_embeds

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ def encode_prompt(

if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
unscale_lora_layers(self.text_encoder, lora_scale)

return prompt_embeds, negative_prompt_embeds

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -419,8 +419,8 @@ def encode_prompt(

if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
unscale_lora_layers(self.text_encoder_2)
unscale_lora_layers(self.text_encoder, lora_scale)
unscale_lora_layers(self.text_encoder_2, lora_scale)

return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -426,8 +426,8 @@ def encode_prompt(

if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
unscale_lora_layers(self.text_encoder_2)
unscale_lora_layers(self.text_encoder, lora_scale)
unscale_lora_layers(self.text_encoder_2, lora_scale)

return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -575,8 +575,8 @@ def encode_prompt(

if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
unscale_lora_layers(self.text_encoder_2)
unscale_lora_layers(self.text_encoder, lora_scale)
unscale_lora_layers(self.text_encoder_2, lora_scale)

return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ def encode_prompt(

if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
unscale_lora_layers(self.text_encoder, lora_scale)

return prompt_embeds, negative_prompt_embeds

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -434,8 +434,8 @@ def encode_prompt(

if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
unscale_lora_layers(self.text_encoder_2)
unscale_lora_layers(self.text_encoder, lora_scale)
unscale_lora_layers(self.text_encoder_2, lora_scale)

return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ def encode_prompt(

if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
unscale_lora_layers(self.text_encoder, lora_scale)

return prompt_embeds, negative_prompt_embeds

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,7 @@ def encode_prompt(

if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
unscale_lora_layers(self.text_encoder, lora_scale)

return prompt_embeds, negative_prompt_embeds

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,7 @@ def encode_prompt(

if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
unscale_lora_layers(self.text_encoder, lora_scale)

return prompt_embeds, negative_prompt_embeds

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
)
from ...models.transformer_2d import Transformer2DModel
from ...models.unet_2d_condition import UNet2DConditionOutput
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from ...utils import USE_PEFT_BACKEND, deprecate, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import apply_freeu


Expand Down Expand Up @@ -1338,7 +1338,7 @@ def forward(

if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self)
unscale_lora_layers(self, scale=lora_scale)

if not return_dict:
return (sample,)
Expand Down
Loading