[Feature] Support IP-Adapter Plus #5915
Changes from 9 commits
@@ -84,6 +84,10 @@ class Attention(nn.Module):
        processor (`AttnProcessor`, *optional*, defaults to `None`):
            The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and
            `AttnProcessor` otherwise.
        query_layer_norm (`bool`, *optional*, defaults to `False`):
            Set to `True` to apply a layer norm to the query input.
        concat_kv_input (`bool`, *optional*, defaults to `False`):
            Set to `True` to concatenate `hidden_states` onto `encoder_hidden_states` for the key/value inputs.
    """

    def __init__(
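For reference, a minimal sketch (plain PyTorch, made-up shapes, not taken from the PR) of what the new `concat_kv_input` flag means for the key/value inputs: the conditioning tokens and the latent tokens are joined along the sequence dimension before `to_k`/`to_v` are applied.

```python
import torch

# Illustrative shapes only
batch, latent_tokens, cond_tokens, dim = 2, 4096, 77, 768

hidden_states = torch.randn(batch, latent_tokens, dim)        # query input
encoder_hidden_states = torch.randn(batch, cond_tokens, dim)  # cross-attention conditioning

# concat_kv_input=True: keys/values see the conditioning tokens *and* the
# latent tokens, concatenated along the sequence dimension (dim=-2).
# This assumes both tensors share the same channel dimension.
kv_input = torch.cat([encoder_hidden_states, hidden_states], dim=-2)
print(kv_input.shape)  # torch.Size([2, 4173, 768])
```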
@@ -109,6 +113,8 @@ def __init__(
        residual_connection: bool = False,
        _from_deprecated_attn_block: bool = False,
        processor: Optional["AttnProcessor"] = None,
        query_layer_norm: bool = False,
        concat_kv_input: bool = False,
    ):
        super().__init__()
        self.inner_dim = dim_head * heads
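The two flags appear to exist so that this `Attention` class can stand in for the perceiver-style attention used by IP-Adapter Plus's image-projection resampler, which layer-norms its inputs and concatenates the latents onto the image features before computing keys/values. A hypothetical instantiation (the import path matches current diffusers, but every argument value below is made up for illustration):

```python
from diffusers.models.attention_processor import Attention

attn = Attention(
    query_dim=768,           # illustrative value
    heads=12,                # illustrative value
    dim_head=64,             # illustrative value
    query_layer_norm=True,   # new: adds nn.LayerNorm(query_dim) on the query input
    concat_kv_input=True,    # new: appends hidden_states to encoder_hidden_states for k/v
)
```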
@@ -118,6 +124,7 @@ def __init__(
        self.rescale_output_factor = rescale_output_factor
        self.residual_connection = residual_connection
        self.dropout = dropout
        self.concat_kv_input = concat_kv_input

        # we make use of this private variable to know whether this class is loaded
        # with an deprecated state dict so that we can convert it on the fly
@@ -150,6 +157,11 @@ def __init__(
        else:
            self.spatial_norm = None

        if query_layer_norm:
            self.layer_norm = nn.LayerNorm(query_dim)
        else:
            self.layer_norm = None

        if cross_attention_norm is None:
            self.norm_cross = None
        elif cross_attention_norm == "layer_norm":
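A quick illustration of what the new `self.layer_norm` does at runtime (plain PyTorch, illustrative shapes): it normalizes each token of the query input over its channel dimension before the query projection.

```python
import torch
import torch.nn as nn

query_dim = 768
layer_norm = nn.LayerNorm(query_dim)

hidden_states = torch.randn(2, 4096, query_dim)  # (batch, seq_len, query_dim)
normed = layer_norm(hidden_states)               # per-token normalization over the last dim
print(normed.shape)                              # torch.Size([2, 4096, 768])
```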
@@ -726,13 +738,19 @@ def __call__(
        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        if attn.layer_norm is not None:
            hidden_states = attn.layer_norm(hidden_states)

        query = attn.to_q(hidden_states, *args)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        if attn.concat_kv_input:
            encoder_hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=-2)

        key = attn.to_k(encoder_hidden_states, *args)
        value = attn.to_v(encoder_hidden_states, *args)
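Putting the two additions together, here is a self-contained toy module (not diffusers code) showing where the layer norm and the k/v concatenation sit relative to the existing projections; the real processors additionally handle attention masks, head reshaping, LoRA scales, and dropout.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MiniAttention(nn.Module):
    # Toy stand-in for diffusers' Attention, reduced to what this diff touches.
    def __init__(self, dim, query_layer_norm=False, concat_kv_input=False):
        super().__init__()
        self.layer_norm = nn.LayerNorm(dim) if query_layer_norm else None
        self.concat_kv_input = concat_kv_input
        self.to_q = nn.Linear(dim, dim)
        self.to_k = nn.Linear(dim, dim)
        self.to_v = nn.Linear(dim, dim)

    def forward(self, hidden_states, encoder_hidden_states=None):
        if self.layer_norm is not None:
            hidden_states = self.layer_norm(hidden_states)
        query = self.to_q(hidden_states)
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        if self.concat_kv_input:
            encoder_hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=-2)
        key = self.to_k(encoder_hidden_states)
        value = self.to_v(encoder_hidden_states)
        # single-head attention for brevity; the real processors reshape into heads
        return F.scaled_dot_product_attention(query, key, value)

x = torch.randn(1, 64, 320)    # latent tokens
ctx = torch.randn(1, 77, 320)  # conditioning tokens
out = MiniAttention(320, query_layer_norm=True, concat_kv_input=True)(x, ctx)
print(out.shape)               # torch.Size([1, 64, 320]) – output keeps the query length
```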
@@ -1127,13 +1145,19 @@ def __call__(
        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        if attn.layer_norm is not None:
            hidden_states = attn.layer_norm(hidden_states)

        query = attn.to_q(hidden_states, *args)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        if attn.concat_kv_input:
            encoder_hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=-2)

        key = attn.to_k(encoder_hidden_states, *args)
        value = attn.to_v(encoder_hidden_states, *args)
@@ -1207,6 +1231,9 @@ def __call__(
        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        if attn.layer_norm is not None:
            hidden_states = attn.layer_norm(hidden_states)

        args = () if USE_PEFT_BACKEND else (scale,)
        query = attn.to_q(hidden_states, *args)
@@ -1215,6 +1242,9 @@ def __call__(
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        if attn.concat_kv_input:
            encoder_hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=-2)

        key = attn.to_k(encoder_hidden_states, *args)
        value = attn.to_v(encoder_hidden_states, *args)
@@ -1517,6 +1547,9 @@ def __call__(
        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        if attn.layer_norm is not None:
            hidden_states = attn.layer_norm(hidden_states)

Review comment: can we do this outside of the attention_processor too?

        query = attn.to_q(hidden_states)
        dim = query.shape[-1]
        query = attn.head_to_batch_dim(query)
@@ -1526,6 +1559,9 @@ def __call__(
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        if attn.concat_kv_input:
            encoder_hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=-2)

Review comment: this step can be done outside of the attention processor

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        key = attn.head_to_batch_dim(key)
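Regarding the two review suggestions above: one hypothetical way to move both steps out of the processors (none of the names below come from the PR) is for the caller that owns the `Attention` module to prepare the inputs itself and pass the already-concatenated tensor as `encoder_hidden_states`:

```python
import torch

def call_attention_with_prepared_inputs(attn, hidden_states, encoder_hidden_states, norm=None):
    # Hypothetical caller-side variant: the processor never needs to know about
    # query_layer_norm or concat_kv_input; the caller prepares both inputs.
    if norm is not None:  # stands in for the processor-side attn.layer_norm step
        hidden_states = norm(hidden_states)
    kv = torch.cat([encoder_hidden_states, hidden_states], dim=-2)  # stands in for concat_kv_input
    return attn(hidden_states, encoder_hidden_states=kv)
```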
@@ -2031,13 +2067,19 @@ def __call__(
        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        if attn.layer_norm is not None:
            hidden_states = attn.layer_norm(hidden_states)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        if attn.concat_kv_input:
            encoder_hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=-2)

        # split hidden states
        end_pos = encoder_hidden_states.shape[1] - self.num_tokens
        encoder_hidden_states, ip_hidden_states = (
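The hunk above is truncated mid-assignment; for context, a small sketch of the split it performs (plain PyTorch, illustrative shapes, ignoring the `concat_kv_input` path): the last `num_tokens` entries of `encoder_hidden_states` are the image-prompt embeddings, everything before them is the text embedding.

```python
import torch

num_tokens = 4   # illustrative; the base IP-Adapter uses 4 image tokens, the Plus variant uses more
text_tokens = 77
encoder_hidden_states = torch.randn(2, text_tokens + num_tokens, 768)

end_pos = encoder_hidden_states.shape[1] - num_tokens
text_embeds = encoder_hidden_states[:, :end_pos, :]  # goes to the regular cross-attention
ip_embeds = encoder_hidden_states[:, end_pos:, :]    # goes to the IP-Adapter to_k_ip / to_v_ip branch
print(text_embeds.shape, ip_embeds.shape)            # torch.Size([2, 77, 768]) torch.Size([2, 4, 768])
```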
@@ -2151,13 +2193,19 @@ def __call__(
        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        if attn.layer_norm is not None:
            hidden_states = attn.layer_norm(hidden_states)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        if attn.concat_kv_input:
            encoder_hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=-2)

        # split hidden states
        end_pos = encoder_hidden_states.shape[1] - self.num_tokens
        encoder_hidden_states, ip_hidden_states = (
Review comment: Ok for now, but let's make sure to factor this out with a conversion_... function later.