-
Notifications
You must be signed in to change notification settings - Fork 5.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Feature] Support IP-Adapter Plus #5915
Merged
+683
−123
Merged
Changes from 4 commits
Commits
Show all changes
18 commits
Select commit
Hold shift + click to select a range
5e6e351
Support IP-Adapter Plus
okotaku d39d763
fix format
okotaku 07dbb43
restore before black format
okotaku 9096c37
restore before black format
okotaku 9893127
Merge branch 'main' into ip_adapter_plus
yiyixuxu 21e6275
generic
okotaku 87aa9d5
Refactor PerceiverAttention
okotaku 775487f
format
okotaku 69ea705
fix test and refactor PerceiverAttention
okotaku 59012a6
generic encode_image
okotaku 90096f2
keep attention implementation
okotaku fe7d232
merge tests
okotaku bc9053d
Merge branch 'main' into ip_adapter_plus
yiyixuxu 11cacbd
encode_image backward compatible
okotaku 8cb48a5
code quality
okotaku 24ed74a
fix controlnet inpaint pipeline
okotaku e27b641
refactor FFN
okotaku 390f3c0
refactor FFN
okotaku File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -790,3 +790,155 @@ def forward(self, caption, force_drop_ids=None): | |||||||||||||||||||||||||||||||||||||||
hidden_states = self.act_1(hidden_states) | ||||||||||||||||||||||||||||||||||||||||
hidden_states = self.linear_2(hidden_states) | ||||||||||||||||||||||||||||||||||||||||
return hidden_states | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
class PerceiverAttention(nn.Module): | ||||||||||||||||||||||||||||||||||||||||
"""PerceiverAttention of IP-Adapter Plus. | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
Args: | ||||||||||||||||||||||||||||||||||||||||
---- | ||||||||||||||||||||||||||||||||||||||||
embed_dims (int): The feature dimension. | ||||||||||||||||||||||||||||||||||||||||
head_dims (int): The number of head channels. Defaults to 64. | ||||||||||||||||||||||||||||||||||||||||
num_heads (int): Parallel attention heads. Defaults to 16. | ||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
def __init__(self, embed_dims: int, head_dims=64, num_heads: int = 16) -> None: | ||||||||||||||||||||||||||||||||||||||||
super().__init__() | ||||||||||||||||||||||||||||||||||||||||
self.head_dims = head_dims | ||||||||||||||||||||||||||||||||||||||||
self.num_heads = num_heads | ||||||||||||||||||||||||||||||||||||||||
inner_dim = head_dims * num_heads | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
self.norm1 = nn.LayerNorm(embed_dims) | ||||||||||||||||||||||||||||||||||||||||
self.norm2 = nn.LayerNorm(embed_dims) | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
self.to_q = nn.Linear(embed_dims, inner_dim, bias=False) | ||||||||||||||||||||||||||||||||||||||||
self.to_kv = nn.Linear(embed_dims, inner_dim * 2, bias=False) | ||||||||||||||||||||||||||||||||||||||||
self.to_out = nn.Linear(inner_dim, embed_dims, bias=False) | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
def _reshape_tensor(self, x, heads) -> torch.Tensor: | ||||||||||||||||||||||||||||||||||||||||
"""Reshape tensor.""" | ||||||||||||||||||||||||||||||||||||||||
bs, length, _ = x.shape | ||||||||||||||||||||||||||||||||||||||||
# (bs, length, width) --> (bs, length, n_heads, dim_per_head) | ||||||||||||||||||||||||||||||||||||||||
x = x.view(bs, length, heads, -1) | ||||||||||||||||||||||||||||||||||||||||
# (bs, length, n_heads, dim_per_head) --> | ||||||||||||||||||||||||||||||||||||||||
# (bs, n_heads, length, dim_per_head) | ||||||||||||||||||||||||||||||||||||||||
x = x.transpose(1, 2) | ||||||||||||||||||||||||||||||||||||||||
# (bs, n_heads, length, dim_per_head) --> | ||||||||||||||||||||||||||||||||||||||||
# (bs*n_heads, length, dim_per_head) | ||||||||||||||||||||||||||||||||||||||||
return x.reshape(bs, heads, length, -1) | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
def forward(self, x: torch.Tensor, latents: torch.Tensor) -> torch.Tensor: | ||||||||||||||||||||||||||||||||||||||||
"""Forward pass. | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
Args: | ||||||||||||||||||||||||||||||||||||||||
---- | ||||||||||||||||||||||||||||||||||||||||
x (torch.Tensor): image features | ||||||||||||||||||||||||||||||||||||||||
shape (b, n1, D) | ||||||||||||||||||||||||||||||||||||||||
latents (torch.Tensor): latent features | ||||||||||||||||||||||||||||||||||||||||
shape (b, n2, D). | ||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||
x = self.norm1(x) | ||||||||||||||||||||||||||||||||||||||||
latents = self.norm2(latents) | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
b, len_latents, _ = latents.shape | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
q = self.to_q(latents) | ||||||||||||||||||||||||||||||||||||||||
kv_input = torch.cat((x, latents), dim=-2) | ||||||||||||||||||||||||||||||||||||||||
k, v = self.to_kv(kv_input).chunk(2, dim=-1) | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
q = self._reshape_tensor(q, self.num_heads) | ||||||||||||||||||||||||||||||||||||||||
k = self._reshape_tensor(k, self.num_heads) | ||||||||||||||||||||||||||||||||||||||||
v = self._reshape_tensor(v, self.num_heads) | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
# attention | ||||||||||||||||||||||||||||||||||||||||
scale = 1 / math.sqrt(math.sqrt(self.head_dims)) | ||||||||||||||||||||||||||||||||||||||||
# More stable with f16 than dividing afterwards | ||||||||||||||||||||||||||||||||||||||||
weight = (q * scale) @ (k * scale).transpose(-2, -1) | ||||||||||||||||||||||||||||||||||||||||
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) | ||||||||||||||||||||||||||||||||||||||||
out = weight @ v | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
out = out.permute(0, 2, 1, 3).reshape(b, len_latents, -1) | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
return self.to_out(out) | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
class Resampler(nn.Module): | ||||||||||||||||||||||||||||||||||||||||
"""Resampler of IP-Adapter Plus. | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
Args: | ||||||||||||||||||||||||||||||||||||||||
---- | ||||||||||||||||||||||||||||||||||||||||
embed_dims (int): The feature dimension. Defaults to 768. | ||||||||||||||||||||||||||||||||||||||||
output_dims (int): The number of output channels, that is the same | ||||||||||||||||||||||||||||||||||||||||
number of the channels in the | ||||||||||||||||||||||||||||||||||||||||
`unet.config.cross_attention_dim`. Defaults to 1024. | ||||||||||||||||||||||||||||||||||||||||
hidden_dims (int): The number of hidden channels. Defaults to 1280. | ||||||||||||||||||||||||||||||||||||||||
depth (int): The number of blocks. Defaults to 8. | ||||||||||||||||||||||||||||||||||||||||
head_dims (int): The number of head channels. Defaults to 64. | ||||||||||||||||||||||||||||||||||||||||
num_heads (int): Parallel attention heads. Defaults to 16. | ||||||||||||||||||||||||||||||||||||||||
num_queries (int): The number of queries. Defaults to 8. | ||||||||||||||||||||||||||||||||||||||||
ffn_ratio (float): The expansion ratio of feedforward network hidden | ||||||||||||||||||||||||||||||||||||||||
layer channels. Defaults to 4. | ||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
def __init__( | ||||||||||||||||||||||||||||||||||||||||
self, | ||||||||||||||||||||||||||||||||||||||||
embed_dims: int = 768, | ||||||||||||||||||||||||||||||||||||||||
output_dims: int = 1024, | ||||||||||||||||||||||||||||||||||||||||
hidden_dims: int = 1280, | ||||||||||||||||||||||||||||||||||||||||
depth: int = 4, | ||||||||||||||||||||||||||||||||||||||||
head_dims: int = 64, | ||||||||||||||||||||||||||||||||||||||||
num_heads: int = 16, | ||||||||||||||||||||||||||||||||||||||||
num_queries: int = 8, | ||||||||||||||||||||||||||||||||||||||||
ffn_ratio: float = 4, | ||||||||||||||||||||||||||||||||||||||||
) -> None: | ||||||||||||||||||||||||||||||||||||||||
super().__init__() | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dims) / hidden_dims**0.5) | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
self.proj_in = nn.Linear(embed_dims, hidden_dims) | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
self.proj_out = nn.Linear(hidden_dims, output_dims) | ||||||||||||||||||||||||||||||||||||||||
self.norm_out = nn.LayerNorm(output_dims) | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
self.layers = nn.ModuleList([]) | ||||||||||||||||||||||||||||||||||||||||
for _ in range(depth): | ||||||||||||||||||||||||||||||||||||||||
self.layers.append( | ||||||||||||||||||||||||||||||||||||||||
nn.ModuleList( | ||||||||||||||||||||||||||||||||||||||||
[ | ||||||||||||||||||||||||||||||||||||||||
PerceiverAttention(embed_dims=hidden_dims, head_dims=head_dims, num_heads=num_heads), | ||||||||||||||||||||||||||||||||||||||||
self._get_ffn(embed_dims=hidden_dims, ffn_ratio=ffn_ratio), | ||||||||||||||||||||||||||||||||||||||||
] | ||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
def _get_ffn(self, embed_dims, ffn_ratio=4) -> nn.Sequential: | ||||||||||||||||||||||||||||||||||||||||
"""Get feedforward network.""" | ||||||||||||||||||||||||||||||||||||||||
inner_dim = int(embed_dims * ffn_ratio) | ||||||||||||||||||||||||||||||||||||||||
return nn.Sequential( | ||||||||||||||||||||||||||||||||||||||||
nn.LayerNorm(embed_dims), | ||||||||||||||||||||||||||||||||||||||||
nn.Linear(embed_dims, inner_dim, bias=False), | ||||||||||||||||||||||||||||||||||||||||
nn.GELU(), | ||||||||||||||||||||||||||||||||||||||||
nn.Linear(inner_dim, embed_dims, bias=False), | ||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Can we replace this with a layer norm layer diffusers/src/diffusers/models/attention.py Line 493 in 7d4a257
|
||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
def forward(self, x: torch.Tensor) -> torch.Tensor: | ||||||||||||||||||||||||||||||||||||||||
"""Forward pass. | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
Args: | ||||||||||||||||||||||||||||||||||||||||
---- | ||||||||||||||||||||||||||||||||||||||||
x (torch.Tensor): Input Tensor. | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
Returns: | ||||||||||||||||||||||||||||||||||||||||
------- | ||||||||||||||||||||||||||||||||||||||||
torch.Tensor: Output Tensor. | ||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||
latents = self.latents.repeat(x.size(0), 1, 1) | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
x = self.proj_in(x) | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
for attn, ff in self.layers: | ||||||||||||||||||||||||||||||||||||||||
latents = attn(x, latents) + latents | ||||||||||||||||||||||||||||||||||||||||
latents = ff(latents) + latents | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
latents = self.proj_out(latents) | ||||||||||||||||||||||||||||||||||||||||
return self.norm_out(latents) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should we refactor this using
Attention
?cc @patrickvonplaten
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes it would be important to use the attention class here IMO to make sure it can be used with torch's scale-dot product attention
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@patrickvonplaten @yiyixuxu
When I refactored
PerceiverAttention
, I noticed that there are still some differences between out and out2. What are your thoughts on this?https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/resampler.py#L72
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The numerical difference for float32 is small, no? that means your implementation is most likely correct
for float16 can we try to see if the difference is less than 1e-3?
Also, let's generate some outputs with the refactored code? if the results look similar to before it should be fine!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Output image is no problems.