Make LoRACompatibleConv padding_mode work. #6031
Conversation
Thanks for your PR. Could you:
Once done, I will run our SLOW tests (such as https://github.com/younesbelkada/diffusers/blob/peft-part-2/tests/lora/test_lora_layers_peft.py) to ensure nothing broke.
src/diffusers/models/lora.py
Outdated
```python
if self.padding_mode != 'zeros':
    return F.conv2d(F.pad(hidden_states, self._reversed_padding_repeated_twice, mode=self.padding_mode),
                    self.weight, self.bias, self.stride,
                    (0, 0), self.dilation, self.groups)
return F.conv2d(hidden_states, self.weight, self.bias, self.stride,
                self.padding, self.dilation, self.groups)
```
Can we compute `original_outputs` one time and then reuse it? That way, I think the code will remain cleaner.
Thanks for your guidance.
I haven't found a better, cleaner solution so far. If you have one, please let me know.
testing script:

```python
import torch
from diffusers import DiffusionPipeline

# Modify the padding mode of every Conv2d in the network
def set_pad_mode(network, mode="circular"):
    for _, module in network.named_children():
        if len(module._modules) > 0:
            set_pad_mode(module, mode)
        else:
            if isinstance(module, torch.nn.Conv2d):
                module.padding_mode = mode

base = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
)
base.to("cuda")

n_steps = 30
prompt = "interior design, Equirectangular, Panoramic, Panorama and 360"

set_pad_mode(base.vae, "circular")
set_pad_mode(base.unet, "circular")

image = base(
    prompt=prompt,
    height=1024,
    width=2048,
    num_inference_steps=n_steps,
    output_type="pil",
).images[0]
```
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
What exactly is the use case here that we're trying to solve? Why do we need this padding_mode in the first place?
src/diffusers/models/lora.py
Outdated
```python
@@ -355,13 +355,34 @@ def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
    if self.lora_layer is None:
        # make sure to the functional Conv2D function as otherwise torch.compile's graph will break
        # see: https://github.com/huggingface/diffusers/pull/4315
        if self.padding_mode != "zeros":
```
Can we give this branch a better name? E.g. `self.padding_mode == "reversed_adding"`?
The padding mode of torch conv2d only supports 'zeros', 'reflect', 'replicate' or 'circular'. Default: 'zeros'.
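Since `F.conv2d` itself only ever zero-pads, the non-`'zeros'` modes have to be emulated with `F.pad` followed by a zero-padding convolution, which is what `nn.Conv2d` does internally. A standalone sanity check of that equivalence (the module sizes and tensor here are illustrative):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
conv = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1, padding_mode="circular")
x = torch.randn(1, 3, 16, 16)

# nn.Conv2d pads with the requested mode first, then convolves with no padding;
# _reversed_padding_repeated_twice is the padding sequence in F.pad's ordering.
manual = F.conv2d(
    F.pad(x, conv._reversed_padding_repeated_twice, mode=conv.padding_mode),
    conv.weight, conv.bias, conv.stride, (0, 0), conv.dilation, conv.groups,
)
assert torch.allclose(conv(x), manual, atol=1e-6)
```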
@patrickvonplaten could you review this once?
Can we maybe try to simplify the code a bit?
Simplify the code by patrickvonplaten. Co-authored-by: Patrick von Platen <[email protected]>
Thank you for your new implementation ideas. However, if the user switches the padding mode from "circular" back to "zeros", problems may occur because `self.padding` has already been set to (0, 0).
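To illustrate that concern (a standalone sketch, not the PR's actual code): if `forward` permanently overwrote `self.padding` with `(0, 0)` for a non-zeros mode, a later switch back to `"zeros"` would silently convolve with no padding at all:

```python
import torch

x = torch.randn(1, 3, 8, 8)
conv = torch.nn.Conv2d(3, 3, kernel_size=3, padding=1, padding_mode="circular")

# Simulate the problematic approach: the attribute is overwritten in-place.
conv.padding = (0, 0)

# Now the user switches back to zero padding at runtime...
conv.padding_mode = "zeros"

# ...and the output silently shrinks from 8x8 to 6x6, because the
# original padding=1 was lost when the attribute was mutated.
print(conv(x).shape)  # torch.Size([1, 3, 6, 6])
```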
src/diffusers/models/lora.py
Outdated
```python
if self.padding_mode != "zeros":
    original_outputs = F.conv2d(
        F.pad(hidden_states, self._reversed_padding_repeated_twice, mode=self.padding_mode),
        self.weight,
        self.bias,
        self.stride,
        (0, 0),
        self.dilation,
        self.groups,
    )
else:
    original_outputs = F.conv2d(
        hidden_states, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
    )
```
Can we refactor this part in the same way?
@patrickvonplaten Hi, thank you for your guidance. I have refactored the code, please review it.
src/diffusers/models/lora.py
Outdated
```python
if self.padding_mode != "zeros":
    hidden_states_pad = F.pad(hidden_states, self._reversed_padding_repeated_twice, mode=self.padding_mode)
    original_outputs = F.conv2d(
        hidden_states_pad,
        self.weight,
        self.bias,
        self.stride,
        (0, 0),
        self.dilation,
        self.groups,
    )
else:
    # make sure to the functional Conv2D function as otherwise torch.compile's graph will break
    # see: https://github.com/huggingface/diffusers/pull/4315
    original_outputs = F.conv2d(
        hidden_states, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
    )
if self.lora_layer is None:
    return original_outputs
else:
```
Suggested change:

```python
if self.padding_mode != "zeros":
    hidden_states = F.pad(hidden_states, self._reversed_padding_repeated_twice, mode=self.padding_mode)
    padding = (0, 0)
else:
    padding = self.padding
original_outputs = F.conv2d(
    hidden_states, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
)
if self.lora_layer is None:
    return original_outputs
else:
```
Wouldn't that be easier?
Or is the padding wrong in this case?
Following your suggestion, should the code on lines 361-363 use `padding` instead of `self.padding`?
Can you apply @patrickvonplaten 's suggestion here?
It does exactly the same thing but simplifies your code - that's all.
Hi, I think his suggestion is great and I agree with his modifications, but I think it should use

```python
original_outputs = F.conv2d(
    hidden_states, self.weight, self.bias, self.stride, padding, self.dilation, self.groups
)
```

instead of

```python
original_outputs = F.conv2d(
    hidden_states, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
)
```

Is that so?
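Putting the agreed fix together, the branch can be written with a local `padding` variable so the module attribute is never mutated. A hypothetical standalone version of the logic (`conv_forward` here is an illustrative free function, not the actual `LoRACompatibleConv.forward` method):

```python
import torch
import torch.nn.functional as F

def conv_forward(conv, hidden_states):
    if conv.padding_mode != "zeros":
        # Pad with the requested mode, then convolve with no extra padding.
        hidden_states = F.pad(hidden_states, conv._reversed_padding_repeated_twice, mode=conv.padding_mode)
        padding = (0, 0)
    else:
        padding = conv.padding
    # Note the local `padding`, not conv.padding: switching padding_mode
    # back and forth at runtime keeps working.
    return F.conv2d(hidden_states, conv.weight, conv.bias, conv.stride, padding, conv.dilation, conv.groups)

conv = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1, padding_mode="circular")
x = torch.randn(1, 3, 16, 16)
assert torch.allclose(conv_forward(conv, x), conv(x), atol=1e-6)
```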
Gentle ping here: @yiyixuxu
hey @jinghuan-Chen
yes you're right!
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
hey @jinghuan-Chen we will merge this once you resolve the merge conflicts and tests pass :)
thanks!
ohh tests fail - can you look into fixing them?
What does this PR do?
fix #5957
Before submitting: see the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.