
[Feature] Support IP-Adapter Plus #5915

Merged · 18 commits · Dec 4, 2023
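
For anyone reviewing by trying it out: a minimal usage sketch of what this PR enables. The repo id `h94/IP-Adapter` and the weight name `ip-adapter-plus_sd15.bin` are assumptions based on the upstream IP-Adapter release, not taken from this diff; the point is that `load_ip_adapter` routes Plus checkpoints through the new `Resampler` path automatically.

```python
import torch
from diffusers import StableDiffusionPipeline
from diffusers.utils import load_image

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# A Plus checkpoint stores a Resampler under "image_proj" (no "proj.weight" key),
# which _load_ip_adapter_weights below detects and loads.
pipe.load_ip_adapter(
    "h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin"
)

image = load_image("ip_adapter_reference.png")  # any local or remote reference image
result = pipe(prompt="best quality, high quality", ip_adapter_image=image).images[0]
```

If I read the loader right, `load_ip_adapter` also pulls the matching CLIP image encoder from the repo's `image_encoder` subfolder; worth double-checking against the docs for Plus checkpoints.
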
src/diffusers/loaders/unet.py (94 changes: 73 additions, 21 deletions)

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
-from collections import defaultdict
+from collections import OrderedDict, defaultdict
 from contextlib import nullcontext
 from typing import Callable, Dict, List, Optional, Union

@@ -21,7 +21,7 @@
 import torch.nn.functional as F
 from torch import nn

-from ..models.embeddings import ImageProjection
+from ..models.embeddings import ImageProjection, Resampler
 from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
 from ..utils import (
     DIFFUSERS_CACHE,
@@ -672,6 +672,17 @@ def _load_ip_adapter_weights(self, state_dict):
             IPAdapterAttnProcessor2_0,
         )

+        if "proj.weight" in state_dict["image_proj"]:
+            # IP-Adapter
+            num_image_text_embeds = 4
+        else:
+            # IP-Adapter Plus
+            num_image_text_embeds = state_dict["image_proj"]["latents"].shape[1]
+
+        # Set encoder_hid_proj after loading ip_adapter weights,
+        # because `Resampler` also has `attn_processors`.
+        self.encoder_hid_proj = None
+
         # set ip-adapter cross-attention processors & load state_dict
         attn_procs = {}
         key_id = 1
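
An aside on the branch above: the two checkpoint formats are distinguished purely by state-dict keys, and the Plus token count falls out of the shape of the learned latents, `(1, num_queries, hidden_dims)`. A toy sketch with made-up shapes (not from a real checkpoint):

```python
import torch

# Plain IP-Adapter: "image_proj" holds a single linear projection.
plain_sd = {"image_proj": {"proj.weight": torch.randn(3072, 1024)}}
# IP-Adapter Plus: "image_proj" holds Resampler weights, including learned latents.
plus_sd = {"image_proj": {"latents": torch.randn(1, 16, 768)}}

def infer_num_image_text_embeds(state_dict):  # hypothetical helper, mirrors the diff
    if "proj.weight" in state_dict["image_proj"]:
        return 4  # plain IP-Adapter always produces 4 image tokens
    return state_dict["image_proj"]["latents"].shape[1]

assert infer_num_image_text_embeds(plain_sd) == 4
assert infer_num_image_text_embeds(plus_sd) == 16
```
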
@@ -695,7 +706,10 @@ def _load_ip_adapter_weights(self, state_dict):
                 IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor
             )
             attn_procs[name] = attn_processor_class(
-                hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0
+                hidden_size=hidden_size,
+                cross_attention_dim=cross_attention_dim,
+                scale=1.0,
+                num_tokens=num_image_text_embeds,
             ).to(dtype=self.dtype, device=self.device)

             value_dict = {}
@@ -708,26 +722,64 @@ def _load_ip_adapter_weights(self, state_dict):
         self.set_attn_processor(attn_procs)

         # create image projection layers.
-        clip_embeddings_dim = state_dict["image_proj"]["proj.weight"].shape[-1]
-        cross_attention_dim = state_dict["image_proj"]["proj.weight"].shape[0] // 4
-
-        image_projection = ImageProjection(
-            cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim, num_image_text_embeds=4
-        )
-        image_projection.to(dtype=self.dtype, device=self.device)
-
-        # load image projection layer weights
-        image_proj_state_dict = {}
-        image_proj_state_dict.update(
-            {
-                "image_embeds.weight": state_dict["image_proj"]["proj.weight"],
-                "image_embeds.bias": state_dict["image_proj"]["proj.bias"],
-                "norm.weight": state_dict["image_proj"]["norm.weight"],
-                "norm.bias": state_dict["image_proj"]["norm.bias"],
-            }
-        )
-
-        image_projection.load_state_dict(image_proj_state_dict)
+        if "proj.weight" in state_dict["image_proj"]:
+            # IP-Adapter
+            clip_embeddings_dim = state_dict["image_proj"]["proj.weight"].shape[-1]
+            cross_attention_dim = state_dict["image_proj"]["proj.weight"].shape[0] // 4
+
+            image_projection = ImageProjection(
+                cross_attention_dim=cross_attention_dim,
+                image_embed_dim=clip_embeddings_dim,
+                num_image_text_embeds=num_image_text_embeds,
+            )
+            image_projection.to(dtype=self.dtype, device=self.device)
+
+            # load image projection layer weights
+            image_proj_state_dict = {}
+            image_proj_state_dict.update(
+                {
+                    "image_embeds.weight": state_dict["image_proj"]["proj.weight"],
+                    "image_embeds.bias": state_dict["image_proj"]["proj.bias"],
+                    "norm.weight": state_dict["image_proj"]["norm.weight"],
+                    "norm.bias": state_dict["image_proj"]["norm.bias"],
+                }
+            )
+
+            image_projection.load_state_dict(image_proj_state_dict)
+
+        else:
+            # IP-Adapter Plus
+            embed_dims = state_dict["image_proj"]["proj_in.weight"].shape[1]
+            output_dims = state_dict["image_proj"]["proj_out.weight"].shape[0]
+            hidden_dims = state_dict["image_proj"]["latents"].shape[2]
+            heads = state_dict["image_proj"]["layers.0.0.to_q.weight"].shape[0] // 64
+
+            image_projection = Resampler(
+                embed_dims=embed_dims,
+                output_dims=output_dims,
+                hidden_dims=hidden_dims,
+                heads=heads,
+                num_queries=num_image_text_embeds,
+            )
+
+            image_proj_state_dict = state_dict["image_proj"]
+            new_sd = OrderedDict()
+            for k, v in image_proj_state_dict.items():
+                if "norm1" in k:
+                    new_sd[k.replace("norm1", "norm_cross")] = v
+                elif "norm2" in k:
+                    new_sd[k.replace("norm2", "layer_norm")] = v
+                elif "to_kv" in k:
+                    v_chunk = v.chunk(2, dim=0)
+                    new_sd[k.replace("to_kv", "to_k")] = v_chunk[0]
+                    new_sd[k.replace("to_kv", "to_v")] = v_chunk[1]
+                elif "to_out" in k:
+                    new_sd[k.replace("to_out", "to_out.0")] = v
+                else:
+                    new_sd[k] = v
+
+            image_projection.load_state_dict(new_sd)
+            del image_proj_state_dict
Review comment (Contributor): Ok for now, but let's make sure to factor this out into a `conversion_...` function later.

         self.encoder_hid_proj = image_projection.to(device=self.device, dtype=self.dtype)
         self.config.encoder_hid_dim_type = "ip_image_proj"
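
Two inferences in the Plus branch are worth flagging. `heads` is derived under the assumption of a fixed per-head dimension of 64, and the checkpoint fuses the key/value projections into a single `to_kv` matrix that the remapping loop splits with `chunk(2, dim=0)`. A standalone sketch of both, with illustrative shapes loosely modeled on an SD 1.5 Plus checkpoint:

```python
import torch

inner_dim, embed_dims = 768, 1280                # hypothetical Resampler dims
to_kv = torch.randn(2 * inner_dim, embed_dims)   # fused projection; rows are [K rows; V rows]

to_k, to_v = to_kv.chunk(2, dim=0)               # split along the output dimension
assert to_k.shape == to_v.shape == (inner_dim, embed_dims)

to_q = torch.randn(inner_dim, 768)               # query projection: hidden_dims -> inner_dim
heads = to_q.shape[0] // 64                      # fixed head_dim of 64 gives 12 heads here
```
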
src/diffusers/models/__init__.py (2 changes: 2 additions, 0 deletions)

@@ -28,6 +28,7 @@
     _import_structure["controlnet"] = ["ControlNetModel"]
     _import_structure["dual_transformer_2d"] = ["DualTransformer2DModel"]
     _import_structure["modeling_utils"] = ["ModelMixin"]
+    _import_structure["embeddings"] = ["ImageProjection"]
     _import_structure["prior_transformer"] = ["PriorTransformer"]
     _import_structure["t5_film_transformer"] = ["T5FilmDecoder"]
     _import_structure["transformer_2d"] = ["Transformer2DModel"]
@@ -55,6 +56,7 @@
     from .consistency_decoder_vae import ConsistencyDecoderVAE
     from .controlnet import ControlNetModel
     from .dual_transformer_2d import DualTransformer2DModel
+    from .embeddings import ImageProjection
     from .modeling_utils import ModelMixin
     from .prior_transformer import PriorTransformer
     from .t5_film_transformer import T5FilmDecoder
src/diffusers/models/attention_processor.py (48 changes: 48 additions, 0 deletions)

@@ -84,6 +84,10 @@ class Attention(nn.Module):
         processor (`AttnProcessor`, *optional*, defaults to `None`):
             The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and
             `AttnProcessor` otherwise.
+        query_layer_norm (`bool`, defaults to `False`):
+            Set to `True` to use layer norm for the query.
+        concat_kv_input (`bool`, defaults to `False`):
+            Set to `True` to concatenate the hidden_states and encoder_hidden_states for kv inputs.
Review comment (Collaborator), suggested change: remove the `query_layer_norm` and `concat_kv_input` docstring entries.

"""

def __init__(
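
For context on the two new flags: in the Resampler's Perceiver-style attention, the learned latents act as queries (optionally layer-normed via `query_layer_norm`), while keys and values are projected from the image features concatenated with those latents (`concat_kv_input`). A minimal sketch of the data flow, independent of the `Attention` class; all shapes are illustrative:

```python
import torch
from torch import nn

dim, batch = 768, 2
latents = torch.randn(batch, 16, dim)        # learned queries (hidden_states)
image_feats = torch.randn(batch, 257, dim)   # encoder_hidden_states (image features)

query_norm = nn.LayerNorm(dim)
q_input = query_norm(latents)                # what query_layer_norm enables

# What concat_kv_input enables: keys/values see the image features AND the latents.
kv_input = torch.cat([image_feats, latents], dim=-2)   # (batch, 257 + 16, dim)
```
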
@@ -109,6 +113,8 @@ def __init__(
         residual_connection: bool = False,
         _from_deprecated_attn_block: bool = False,
         processor: Optional["AttnProcessor"] = None,
+        query_layer_norm: bool = False,
+        concat_kv_input: bool = False,
Review comment (Collaborator), suggested change: remove the `query_layer_norm` and `concat_kv_input` parameters.

     ):
         super().__init__()
         self.inner_dim = dim_head * heads
@@ -118,6 +124,7 @@ def __init__(
         self.rescale_output_factor = rescale_output_factor
         self.residual_connection = residual_connection
         self.dropout = dropout
+        self.concat_kv_input = concat_kv_input
Review comment (Collaborator), suggested change: remove `self.concat_kv_input = concat_kv_input`.


         # we make use of this private variable to know whether this class is loaded
         # with an deprecated state dict so that we can convert it on the fly
@@ -150,6 +157,11 @@ def __init__(
         else:
             self.spatial_norm = None

+        if query_layer_norm:
+            self.layer_norm = nn.LayerNorm(query_dim)
+        else:
+            self.layer_norm = None
+
         if cross_attention_norm is None:
             self.norm_cross = None
         elif cross_attention_norm == "layer_norm":
@@ -726,13 +738,19 @@ def __call__(
         if attn.group_norm is not None:
             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

+        if attn.layer_norm is not None:
+            hidden_states = attn.layer_norm(hidden_states)
+
         query = attn.to_q(hidden_states, *args)

         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
         elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

+        if attn.concat_kv_input:
+            encoder_hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=-2)
+
         key = attn.to_k(encoder_hidden_states, *args)
         value = attn.to_v(encoder_hidden_states, *args)
@@ -1127,13 +1145,19 @@ def __call__(
         if attn.group_norm is not None:
             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

+        if attn.layer_norm is not None:
+            hidden_states = attn.layer_norm(hidden_states)
+
Review comment (Collaborator), suggested change: remove the `layer_norm` block.

         query = attn.to_q(hidden_states, *args)

         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
         elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

+        if attn.concat_kv_input:
+            encoder_hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=-2)
+

Review comment (Collaborator), suggested change: remove the `concat_kv_input` block.

         key = attn.to_k(encoder_hidden_states, *args)
         value = attn.to_v(encoder_hidden_states, *args)

@@ -1207,6 +1231,9 @@ def __call__(
         if attn.group_norm is not None:
             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

+        if attn.layer_norm is not None:
+            hidden_states = attn.layer_norm(hidden_states)
+
Review comment (Collaborator), suggested change: remove the `layer_norm` block.

         args = () if USE_PEFT_BACKEND else (scale,)
         query = attn.to_q(hidden_states, *args)

@@ -1215,6 +1242,9 @@
         elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

+        if attn.concat_kv_input:
+            encoder_hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=-2)
+
Review comment (Collaborator), suggested change: remove the `concat_kv_input` block.

         key = attn.to_k(encoder_hidden_states, *args)
         value = attn.to_v(encoder_hidden_states, *args)

@@ -1517,6 +1547,9 @@ def __call__(
         if attn.group_norm is not None:
             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

+        if attn.layer_norm is not None:
+            hidden_states = attn.layer_norm(hidden_states)
+
Review comment (Collaborator), suggested change: remove the `layer_norm` block. Can we do this outside of the attention processor too?

         query = attn.to_q(hidden_states)
         dim = query.shape[-1]
         query = attn.head_to_batch_dim(query)
@@ -1526,6 +1559,9 @@
         elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

+        if attn.concat_kv_input:
+            encoder_hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=-2)

Review comment (Collaborator): this step can be done outside of the attention processor.

Review comment (Collaborator), suggested change: remove the `concat_kv_input` block.

         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
         key = attn.head_to_batch_dim(key)
@@ -2031,13 +2067,19 @@ def __call__(
         if attn.group_norm is not None:
             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

+        if attn.layer_norm is not None:
+            hidden_states = attn.layer_norm(hidden_states)
+
         query = attn.to_q(hidden_states)

         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
         elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

+        if attn.concat_kv_input:
+            encoder_hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=-2)
+
         # split hidden states
         end_pos = encoder_hidden_states.shape[1] - self.num_tokens
         encoder_hidden_states, ip_hidden_states = (
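
The hunk above is truncated by the diff view, but the split it begins is the heart of the IP-Adapter processors: the pipeline appends the projected image tokens after the text tokens, so the last `num_tokens` entries of `encoder_hidden_states` are peeled off for a second attention pass. A sketch of that split, assuming the text-first layout used upstream:

```python
import torch

num_tokens = 16                                    # image tokens (4 plain, 16 for this Plus example)
encoder_hidden_states = torch.randn(2, 77 + num_tokens, 768)

end_pos = encoder_hidden_states.shape[1] - num_tokens
text_states = encoder_hidden_states[:, :end_pos, :]   # (2, 77, 768): text tokens
ip_states = encoder_hidden_states[:, end_pos:, :]     # (2, 16, 768): image tokens
# the processor then attends to each set separately and adds scale * image-attention output
```
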
@@ -2151,13 +2193,19 @@ def __call__(
         if attn.group_norm is not None:
             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

+        if attn.layer_norm is not None:
+            hidden_states = attn.layer_norm(hidden_states)
+
         query = attn.to_q(hidden_states)

         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
         elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

+        if attn.concat_kv_input:
+            encoder_hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=-2)
+
         # split hidden states
         end_pos = encoder_hidden_states.shape[1] - self.num_tokens
         encoder_hidden_states, ip_hidden_states = (