[Feature] Support IP-Adapter Plus #5915
Conversation
oh thanks so much @okotaku, the first example (SD 1.5) does not look nice. Is the result expected?
@yiyixuxu I think the quality of IP-Adapter-Plus-XL is promising, but IP-Adapter-Plus-SDv1.5 is not good.
@yiyixuxu Another example of IP-Adapter-Plus-SDv1.5:
(Input / Output image comparison)
the plus model doesn't seem to work....
the official example here for ip-adapter plus face seems ok, and it uses SD 1.5 too. Can we try this example? https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter-plus-face_demo.ipynb Also, can we try a multimodal prompts example? Here: https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter-plus_demo.ipynb Thanks! YiYi
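For reference, a sketch of how an IP-Adapter Plus example can be run through diffusers once this PR lands (the hub repo and weight name follow h94/IP-Adapter; the local image path is a placeholder):

import torch
from PIL import Image
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
# load_ip_adapter also pulls in the matching image encoder from the hub repo
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin")

reference = Image.open("reference.png")  # illustrative path
image = pipe(
    prompt="best quality, high quality",
    ip_adapter_image=reference,
    num_inference_steps=50,
).images[0]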
thanks! @okotaku
I left my feedback, but I will wait for @patrickvonplaten's review before making any changes :)
if isinstance(self.unet.encoder_hid_proj, ImageProjection):
    # IP-Adapter
    image_embeds = self.image_encoder(image).image_embeds
    image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
    uncond_image_embeds = torch.zeros_like(image_embeds)
else:
    # IP-Adapter Plus
    image_embeds = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
    image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
    uncond_image_embeds = self.image_encoder(
        torch.zeros_like(image), output_hidden_states=True
    ).hidden_states[-2]
    uncond_image_embeds = uncond_image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
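For readers following along, a standalone sketch of why the two branches return differently shaped tensors (the CLIP checkpoint here is only illustrative; IP-Adapter checkpoints ship their own image encoder):

import torch
from transformers import CLIPVisionModelWithProjection

encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
pixels = torch.randn(1, 3, 224, 224)  # stand-in for preprocessed pixel values

# base IP-Adapter: one pooled embedding per image -> (1, 768)
pooled = encoder(pixels).image_embeds
# IP-Adapter Plus: per-patch hidden states from the penultimate layer -> (1, 257, 1024)
hidden = encoder(pixels, output_hidden_states=True).hidden_states[-2]
print(pooled.shape, hidden.shape)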
can we add an argument return_hidden_states to the encode_image method? I want to keep encode_image generic and not specific to IP-Adapter. I also updated the variable names so it is clear whether we are returning the image embeddings or the hidden states.
Suggested change:

if not return_hidden_states:
    # IP-Adapter
    image_embeds = self.image_encoder(image).image_embeds
    image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
    uncond_image_embeds = torch.zeros_like(image_embeds)
    return image_embeds, uncond_image_embeds
else:
    # IP-Adapter Plus
    image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
    image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
    uncond_image_enc_hidden_states = self.image_encoder(
        torch.zeros_like(image), output_hidden_states=True
    ).hidden_states[-2]
    uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
    return image_enc_hidden_states, uncond_image_enc_hidden_states
src/diffusers/models/embeddings.py (outdated)

@@ -790,3 +790,155 @@ def forward(self, caption, force_drop_ids=None):
    hidden_states = self.act_1(hidden_states)
    hidden_states = self.linear_2(hidden_states)
    return hidden_states


class PerceiverAttention(nn.Module):
should we refactor this using Attention? cc @patrickvonplaten
Yes, it would be important to use the Attention class here IMO, to make sure it can be used with torch's scaled-dot-product attention.
@patrickvonplaten @yiyixuxu When I refactored PerceiverAttention, I noticed that there are still some differences between out1 and out2. What are your thoughts on this?
https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/resampler.py#L72
import math

import torch
import torch.nn.functional as F

scale = 1 / math.sqrt(64)              # standard 1/sqrt(d) attention scale
scale2 = 1 / math.sqrt(math.sqrt(64))  # split scale, applied to q and k separately

query = torch.rand(1, 4, 8, 64, device="cuda", dtype=torch.float16)
key = torch.rand(1, 4, 8, 64, device="cuda", dtype=torch.float16)
value = torch.rand(1, 4, 8, 64, device="cuda", dtype=torch.float16)

# 1) PyTorch's fused SDPA
out1 = F.scaled_dot_product_attention(query, key, value, scale=scale)

# 2) original IP-Adapter resampler style: scale q and k before the matmul
weight = (query * scale2) @ (key * scale2).transpose(-2, -1)
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
out2 = weight @ value
print(torch.allclose(out1, out2, atol=1e-4))

# 3) textbook style: scale the scores after the matmul
weight2 = query @ key.transpose(-2, -1) * scale
weight2 = torch.softmax(weight2.float(), dim=-1).type(weight2.dtype)
out3 = weight2 @ value
print(torch.allclose(out1, out3, atol=1e-4))
print(torch.allclose(out2, out3, atol=1e-4))

print(torch.abs(out1 - out2).sum())
print(torch.abs(out3 - out2).sum())
print(torch.abs(out1 - out3).sum())

---
False
False
False
tensor(0.1144, device='cuda:0', dtype=torch.float16)
tensor(0.0535, device='cuda:0', dtype=torch.float16)
tensor(0.1212, device='cuda:0', dtype=torch.float16)
The same comparison in float32 shows the three formulations agree, so the float16 gaps above are just rounding:

import math

import torch
import torch.nn.functional as F

scale = 1 / math.sqrt(64)
scale2 = 1 / math.sqrt(math.sqrt(64))

query = torch.rand(1, 4, 8, 64, device="cuda", dtype=torch.float32)
key = torch.rand(1, 4, 8, 64, device="cuda", dtype=torch.float32)
value = torch.rand(1, 4, 8, 64, device="cuda", dtype=torch.float32)

out1 = F.scaled_dot_product_attention(query, key, value, scale=scale)

weight = (query * scale2) @ (key * scale2).transpose(-2, -1)
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
out2 = weight @ value
print(torch.allclose(out1, out2, atol=1e-4))

weight2 = query @ key.transpose(-2, -1) * scale
weight2 = torch.softmax(weight2.float(), dim=-1).type(weight2.dtype)
out3 = weight2 @ value
print(torch.allclose(out1, out3, atol=1e-4))
print(torch.allclose(out2, out3, atol=1e-4))

print(torch.abs(out1 - out2).sum())
print(torch.abs(out3 - out2).sum())
print(torch.abs(out1 - out3).sum())

---
True
True
True
tensor(0.0001, device='cuda:0')
tensor(4.6417e-05, device='cuda:0')
tensor(0.0001, device='cuda:0')
The numerical difference for float32 is small, no? That means your implementation is most likely correct.
For float16, can we try to see if the difference is less than 1e-3?
Also, let's generate some outputs with the refactored code. If the results look similar to before, it should be fine!
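Continuing the float16 snippet above, that check could look like the following. Note the printed numbers above are sums over all 1 * 4 * 8 * 64 = 2048 elements, so 0.1144 is only about 5.6e-5 per element; the per-element maximum is the quantity to compare against 1e-3:

print(torch.allclose(out1, out2, atol=1e-3))  # the tolerance suggested above
print(torch.abs(out1 - out2).max())           # worst-case per-element difference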
related to #5911

(Input / Output image comparison)

(Input / Output image comparison)
Could this be merged? Looking forward to using this.
@blx0102 me too. I am checking every day for this haha
Results look amazing here @okotaku! Would be great if we could try to re-use our fast …
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
@@ -496,11 +496,22 @@ def encode_image(self, image, device, num_images_per_prompt):
    image = self.feature_extractor(image, return_tensors="pt").pixel_values

    image = image.to(device=device, dtype=dtype)
    image_embeds = self.image_encoder(image).image_embeds
    image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
    if self.image_encoder.config.output_hidden_states:
instead of changing the config for image_encoder, let's:
- add an argument output_hidden_state to encode_image()
- inside the IP-Adapter-specific code in the pipelines (i.e. here, if ip_adapter_image is not None:), work out output_hidden_state and pass it to encode_image(), e.g. output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True

essentially we just want to keep encode_image more generic and put IP-Adapter-specific logic into the if ip_adapter_image is not None: ... branch, as in the sketch below.
@@ -1526,6 +1559,9 @@ def __call__(
    elif attn.norm_cross:
        encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

    if attn.concat_kv_input:
this step can be done outside of the attention processor
if attn.layer_norm is not None:
    hidden_states = attn.layer_norm(hidden_states)

can we do this outside of the attention_processor too?
if attn.concat_kv_input:
    encoder_hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=-2)

Suggested change: remove these lines (per the comments above, this can be done outside of the attention processor).
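Taken together, the two suggestions move the LayerNorm and the kv concatenation into the resampler block so the attention processor stays untouched. A minimal self-contained sketch of that shape (dims and names are illustrative, not the merged code):

import torch
import torch.nn as nn
from diffusers.models.attention_processor import Attention

dim, heads = 64, 4
norm_x, norm_latents = nn.LayerNorm(dim), nn.LayerNorm(dim)
attn = Attention(query_dim=dim, heads=heads, dim_head=dim // heads, out_bias=False)

x = torch.randn(1, 257, dim)       # image encoder hidden states
latents = torch.randn(1, 16, dim)  # learned latent queries

residual = latents
latents = norm_latents(latents)                     # was attn.layer_norm inside the processor
kv_input = torch.cat([norm_x(x), latents], dim=-2)  # was attn.concat_kv_input inside the processor
latents = attn(latents, encoder_hidden_states=kv_input) + residual
print(latents.shape)  # torch.Size([1, 16, 64])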
I ran it, but no changes were made.

cc @DN6 @patrickvonplaten
src/diffusers/models/embeddings.py (outdated)

def _get_ffn(self, embed_dims, ffn_ratio=4) -> nn.Sequential:
    """Get feedforward network."""
    inner_dim = int(embed_dims * ffn_ratio)
    return nn.Sequential(
        nn.LayerNorm(embed_dims),
        nn.Linear(embed_dims, inner_dim, bias=False),
        nn.GELU(),
        nn.Linear(inner_dim, embed_dims, bias=False),
    )
Can we replace this with a layer norm layer first and then the FeedForward class?

diffusers/src/diffusers/models/attention.py, line 493 in 7d4a257:
class FeedForward(nn.Module):
@@ -489,18 +489,29 @@ def encode_prompt(

    return prompt_embeds, negative_prompt_embeds

-   def encode_image(self, image, device, num_images_per_prompt):
+   def encode_image(self, image, device, num_images_per_prompt, output_hidden_states):
Suggested change:

-   def encode_image(self, image, device, num_images_per_prompt, output_hidden_states):
+   def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):

This is a public API, so we need to make sure it's backward compatible. Ideally we should also add docstrings here.
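A sketch of the kind of docstring being asked for (wording is illustrative, not the text that was merged):

def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
    """Encode an image for IP-Adapter conditioning.

    Args:
        image: input image (`PIL.Image.Image` or preprocessed tensor).
        device: device to run the image encoder on.
        num_images_per_prompt: the embeddings are repeated this many times along the batch dim.
        output_hidden_states: if `True`, return the penultimate hidden states of the
            image encoder (IP-Adapter Plus); otherwise return the pooled image
            embeds (base IP-Adapter).

    Returns:
        A tuple of (image embeddings, unconditional image embeddings).
    """
    ...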
Looking almost ready to be merged, just some minor comments. Let's try to re-use our FeedForward class here :-)
-   image_projection.load_state_dict(image_proj_state_dict)
+   image_projection.load_state_dict(new_sd)
+   del image_proj_state_dict
Ok for now, but let's make sure to factor this out into a conversion_... function later.
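For illustration, the kind of helper being suggested (the name and key mapping are hypothetical; the real renaming depends on the checkpoint layout):

def convert_ip_adapter_image_proj_to_diffusers(state_dict):
    """Rename original IP-Adapter resampler keys to diffusers module names."""
    new_sd = {}
    for key, value in state_dict.items():
        new_key = key.replace("proj_in", "ff_in").replace("proj_out", "ff_out")  # hypothetical mapping
        new_sd[new_key] = value
    return new_sd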
nn.Sequential(
    nn.LayerNorm(hidden_dims),
    FeedForward(hidden_dims, hidden_dims, activation_fn="gelu", mult=ffn_ratio, bias=False),
),
nice!
Great job @okotaku - super cool addition :-)
@patrickvonplaten @yiyixuxu Thank you for your reviews!
Hey, does it work with IP-Adapter Plus yet?
@vladmandic I think the face models need to use DDIM; see here: #5911 (comment)
cc @xiaohu2015 is this expected?
hi, for the face model, you should use a crop of the face image.
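For example, something like this before passing the image to the pipeline (a rough sketch; in practice a face detector would supply the bounding box):

from PIL import Image

img = Image.open("portrait.png")         # illustrative path
box = (100, 50, 400, 350)                # left, top, right, bottom from a face detector
face = img.crop(box).resize((224, 224))  # pass `face` as ip_adapter_image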
It seems ControlNet img2img does not support IP-Adapter Plus. @yiyixuxu
* Support IP-Adapter Plus
* fix format
* restore before black format
* restore before black format
* generic
* Refactor PerceiverAttention
* format
* fix test and refactor PerceiverAttention
* generic encode_image
* keep attention implementation
* merge tests
* encode_image backward compatible
* code quality
* fix controlnet inpaint pipeline
* refactor FFN
* refactor FFN

Co-authored-by: YiYi Xu <[email protected]>
What does this PR do?

Before submitting

documentation guidelines, and here are tips on formatting docstrings.

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.