
[Feature] Support IP-Adapter Plus #5915

Merged · 18 commits · Dec 4, 2023
106 changes: 85 additions & 21 deletions src/diffusers/loaders/unet.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
-from collections import defaultdict
+from collections import OrderedDict, defaultdict
from contextlib import nullcontext
from typing import Callable, Dict, List, Optional, Union

@@ -21,7 +21,7 @@
import torch.nn.functional as F
from torch import nn

-from ..models.embeddings import ImageProjection
+from ..models.embeddings import ImageProjection, Resampler
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
from ..utils import (
DIFFUSERS_CACHE,
@@ -672,6 +672,17 @@ def _load_ip_adapter_weights(self, state_dict):
            IPAdapterAttnProcessor2_0,
        )

+        if "proj.weight" in state_dict["image_proj"]:
+            # IP-Adapter
+            num_image_text_embeds = 4
+        else:
+            # IP-Adapter Plus
+            num_image_text_embeds = state_dict["image_proj"]["latents"].shape[1]
+
+        # Set encoder_hid_proj after loading ip_adapter weights,
+        # because `Resampler` also has `attn_processors`.
+        self.encoder_hid_proj = None
+
        # set ip-adapter cross-attention processors & load state_dict
        attn_procs = {}
        key_id = 1
@@ -695,7 +706,10 @@ def _load_ip_adapter_weights(self, state_dict):
                    IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor
                )
                attn_procs[name] = attn_processor_class(
-                    hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0
+                    hidden_size=hidden_size,
+                    cross_attention_dim=cross_attention_dim,
+                    scale=1.0,
+                    num_tokens=num_image_text_embeds,
                ).to(dtype=self.dtype, device=self.device)

                value_dict = {}
@@ -708,26 +722,76 @@ def _load_ip_adapter_weights(self, state_dict):
        self.set_attn_processor(attn_procs)

        # create image projection layers.
-        clip_embeddings_dim = state_dict["image_proj"]["proj.weight"].shape[-1]
-        cross_attention_dim = state_dict["image_proj"]["proj.weight"].shape[0] // 4
-
-        image_projection = ImageProjection(
-            cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim, num_image_text_embeds=4
-        )
-        image_projection.to(dtype=self.dtype, device=self.device)
-
-        # load image projection layer weights
-        image_proj_state_dict = {}
-        image_proj_state_dict.update(
-            {
-                "image_embeds.weight": state_dict["image_proj"]["proj.weight"],
-                "image_embeds.bias": state_dict["image_proj"]["proj.bias"],
-                "norm.weight": state_dict["image_proj"]["norm.weight"],
-                "norm.bias": state_dict["image_proj"]["norm.bias"],
-            }
-        )
-
-        image_projection.load_state_dict(image_proj_state_dict)
+        if "proj.weight" in state_dict["image_proj"]:
+            # IP-Adapter
+            clip_embeddings_dim = state_dict["image_proj"]["proj.weight"].shape[-1]
+            cross_attention_dim = state_dict["image_proj"]["proj.weight"].shape[0] // 4
+
+            image_projection = ImageProjection(
+                cross_attention_dim=cross_attention_dim,
+                image_embed_dim=clip_embeddings_dim,
+                num_image_text_embeds=num_image_text_embeds,
+            )
+            image_projection.to(dtype=self.dtype, device=self.device)
+
+            # load image projection layer weights
+            image_proj_state_dict = {}
+            image_proj_state_dict.update(
+                {
+                    "image_embeds.weight": state_dict["image_proj"]["proj.weight"],
+                    "image_embeds.bias": state_dict["image_proj"]["proj.bias"],
+                    "norm.weight": state_dict["image_proj"]["norm.weight"],
+                    "norm.bias": state_dict["image_proj"]["norm.bias"],
+                }
+            )
+
+            image_projection.load_state_dict(image_proj_state_dict)
+
+        else:
+            # IP-Adapter Plus
+            embed_dims = state_dict["image_proj"]["proj_in.weight"].shape[1]
+            output_dims = state_dict["image_proj"]["proj_out.weight"].shape[0]
+            hidden_dims = state_dict["image_proj"]["latents"].shape[2]
+            heads = state_dict["image_proj"]["layers.0.0.to_q.weight"].shape[0] // 64
+
+            image_projection = Resampler(
+                embed_dims=embed_dims,
+                output_dims=output_dims,
+                hidden_dims=hidden_dims,
+                heads=heads,
+                num_queries=num_image_text_embeds,
+            )
+
+            image_proj_state_dict = state_dict["image_proj"]
+
+            new_sd = OrderedDict()
+            for k, v in image_proj_state_dict.items():
+                if "0.to" in k:
+                    k = k.replace("0.to", "2.to")
+                elif "1.0.weight" in k:
+                    k = k.replace("1.0.weight", "3.0.weight")
+                elif "1.0.bias" in k:
+                    k = k.replace("1.0.bias", "3.0.bias")
+                elif "1.1.weight" in k:
+                    k = k.replace("1.1.weight", "3.1.net.0.proj.weight")
+                elif "1.3.weight" in k:
+                    k = k.replace("1.3.weight", "3.1.net.2.weight")
+
+                if "norm1" in k:
+                    new_sd[k.replace("0.norm1", "0")] = v
+                elif "norm2" in k:
+                    new_sd[k.replace("0.norm2", "1")] = v
+                elif "to_kv" in k:
+                    v_chunk = v.chunk(2, dim=0)
+                    new_sd[k.replace("to_kv", "to_k")] = v_chunk[0]
+                    new_sd[k.replace("to_kv", "to_v")] = v_chunk[1]
+                elif "to_out" in k:
+                    new_sd[k.replace("to_out", "to_out.0")] = v
+                else:
+                    new_sd[k] = v
+
+            image_projection.load_state_dict(new_sd)
+            del image_proj_state_dict
Review comment (Contributor): OK for now, but let's make sure to factor this out into a conversion_... function later.

        self.encoder_hid_proj = image_projection.to(device=self.device, dtype=self.dtype)
        self.config.encoder_hid_dim_type = "ip_image_proj"
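
For context, `_load_ip_adapter_weights` is reached through the public `load_ip_adapter` API rather than called directly. A minimal usage sketch follows; the Hub repo id, weight filename, and image URL are illustrative assumptions (based on the checkpoints published by the IP-Adapter authors) and are not part of this diff:

from diffusers import StableDiffusionPipeline
from diffusers.utils import load_image

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# An IP-Adapter Plus checkpoint has no "proj.weight" key in its "image_proj"
# state dict, so the branch added above builds a Resampler and remaps the
# checkpoint keys instead of using the simple ImageProjection.
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.safetensors")

ip_image = load_image("https://example.com/reference.png")  # hypothetical reference image
out = pipe(prompt="a photo of a cat", ip_adapter_image=ip_image, num_inference_steps=30).images[0]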
2 changes: 2 additions & 0 deletions src/diffusers/models/__init__.py
@@ -34,6 +34,7 @@
_import_structure["controlnet"] = ["ControlNetModel"]
_import_structure["dual_transformer_2d"] = ["DualTransformer2DModel"]
_import_structure["modeling_utils"] = ["ModelMixin"]
+_import_structure["embeddings"] = ["ImageProjection"]
_import_structure["prior_transformer"] = ["PriorTransformer"]
_import_structure["t5_film_transformer"] = ["T5FilmDecoder"]
_import_structure["transformer_2d"] = ["Transformer2DModel"]
@@ -63,6 +64,7 @@
from .consistency_decoder_vae import ConsistencyDecoderVAE
from .controlnet import ControlNetModel
from .dual_transformer_2d import DualTransformer2DModel
+from .embeddings import ImageProjection
from .modeling_utils import ModelMixin
from .prior_transformer import PriorTransformer
from .t5_film_transformer import T5FilmDecoder
15 changes: 9 additions & 6 deletions src/diffusers/models/activations.py
@@ -55,11 +55,12 @@ class GELU(nn.Module):
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
        approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
+        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
    """

-    def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
+    def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True):
        super().__init__()
-        self.proj = nn.Linear(dim_in, dim_out)
+        self.proj = nn.Linear(dim_in, dim_out, bias=bias)
        self.approximate = approximate

    def gelu(self, gate: torch.Tensor) -> torch.Tensor:
@@ -81,13 +82,14 @@ class GEGLU(nn.Module):
    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
+        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
    """

-    def __init__(self, dim_in: int, dim_out: int):
+    def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
        super().__init__()
        linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear

-        self.proj = linear_cls(dim_in, dim_out * 2)
+        self.proj = linear_cls(dim_in, dim_out * 2, bias=bias)

    def gelu(self, gate: torch.Tensor) -> torch.Tensor:
        if gate.device.type != "mps":
@@ -109,11 +111,12 @@ class ApproximateGELU(nn.Module):
    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
+        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
    """

-    def __init__(self, dim_in: int, dim_out: int):
+    def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
        super().__init__()
-        self.proj = nn.Linear(dim_in, dim_out)
+        self.proj = nn.Linear(dim_in, dim_out, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.proj(x)
12 changes: 7 additions & 5 deletions src/diffusers/models/attention.py
@@ -501,6 +501,7 @@ class FeedForward(nn.Module):
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
+        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
    """

    def __init__(
@@ -511,28 +512,29 @@ def __init__(
        dropout: float = 0.0,
        activation_fn: str = "geglu",
        final_dropout: bool = False,
+        bias: bool = True,
    ):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim
        linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear

        if activation_fn == "gelu":
-            act_fn = GELU(dim, inner_dim)
+            act_fn = GELU(dim, inner_dim, bias=bias)
        if activation_fn == "gelu-approximate":
-            act_fn = GELU(dim, inner_dim, approximate="tanh")
+            act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
        elif activation_fn == "geglu":
-            act_fn = GEGLU(dim, inner_dim)
+            act_fn = GEGLU(dim, inner_dim, bias=bias)
        elif activation_fn == "geglu-approximate":
-            act_fn = ApproximateGELU(dim, inner_dim)
+            act_fn = ApproximateGELU(dim, inner_dim, bias=bias)

        self.net = nn.ModuleList([])
        # project in
        self.net.append(act_fn)
        # project dropout
        self.net.append(nn.Dropout(dropout))
        # project out
-        self.net.append(linear_cls(inner_dim, dim_out))
+        self.net.append(linear_cls(inner_dim, dim_out, bias=bias))
        # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
        if final_dropout:
            self.net.append(nn.Dropout(dropout))
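
A quick sketch of what the new `bias` flag does once these changes land: with `bias=False`, both the activation's input projection and the final output projection are created without bias terms, which is how the `Resampler` below constructs its feed-forward blocks. The dimensions here are illustrative.

import torch
from diffusers.models.attention import FeedForward

# bias=False is threaded through to GELU/GEGLU and the output projection,
# so no Linear layer in the block carries a bias parameter.
ff = FeedForward(dim=1280, dim_out=1280, mult=4, activation_fn="gelu", bias=False)
assert all(m.bias is None for m in ff.modules() if isinstance(m, torch.nn.Linear))

x = torch.randn(1, 16, 1280)
print(ff(x).shape)  # torch.Size([1, 16, 1280])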
89 changes: 89 additions & 0 deletions src/diffusers/models/embeddings.py
@@ -20,6 +20,7 @@

from ..utils import USE_PEFT_BACKEND
from .activations import get_activation
+from .attention_processor import Attention
from .lora import LoRACompatibleLinear


@@ -790,3 +791,91 @@ def forward(self, caption, force_drop_ids=None):
        hidden_states = self.act_1(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states


class Resampler(nn.Module):
    """Resampler of IP-Adapter Plus.

    Args:
    ----
        embed_dims (int): The feature dimension. Defaults to 768.
        output_dims (int): The number of output channels, which must match
            `unet.config.cross_attention_dim`. Defaults to 1024.
        hidden_dims (int): The number of hidden channels. Defaults to 1280.
        depth (int): The number of blocks. Defaults to 4.
        dim_head (int): The number of head channels. Defaults to 64.
        heads (int): Parallel attention heads. Defaults to 16.
        num_queries (int): The number of queries. Defaults to 8.
        ffn_ratio (float): The expansion ratio of feedforward network hidden
            layer channels. Defaults to 4.
    """

    def __init__(
        self,
        embed_dims: int = 768,
        output_dims: int = 1024,
        hidden_dims: int = 1280,
        depth: int = 4,
        dim_head: int = 64,
        heads: int = 16,
        num_queries: int = 8,
        ffn_ratio: float = 4,
    ) -> None:
        super().__init__()
        from .attention import FeedForward  # Lazy import to avoid circular import

        self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dims) / hidden_dims**0.5)

        self.proj_in = nn.Linear(embed_dims, hidden_dims)

        self.proj_out = nn.Linear(hidden_dims, output_dims)
        self.norm_out = nn.LayerNorm(output_dims)

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        nn.LayerNorm(hidden_dims),
                        nn.LayerNorm(hidden_dims),
                        Attention(
                            query_dim=hidden_dims,
                            dim_head=dim_head,
                            heads=heads,
                            out_bias=False,
                        ),
                        nn.Sequential(
                            nn.LayerNorm(hidden_dims),
                            FeedForward(hidden_dims, hidden_dims, activation_fn="gelu", mult=ffn_ratio, bias=False),
                        ),
Review comment (Contributor) on the nn.Sequential feed-forward block above: nice!

                    ]
                )
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Args:
        ----
            x (torch.Tensor): Input Tensor.

        Returns:
        -------
            torch.Tensor: Output Tensor.
        """
        latents = self.latents.repeat(x.size(0), 1, 1)

        x = self.proj_in(x)

        for ln0, ln1, attn, ff in self.layers:
            residual = latents

            encoder_hidden_states = ln0(x)
            latents = ln1(latents)
            encoder_hidden_states = torch.cat([encoder_hidden_states, latents], dim=-2)
            latents = attn(latents, encoder_hidden_states) + residual
            latents = ff(latents) + latents

        latents = self.proj_out(latents)
        return self.norm_out(latents)
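
To make the shapes concrete, here is a small sketch that instantiates the `Resampler` added above and runs a dummy batch through it. The dimensions are assumptions chosen to mirror the published IP-Adapter Plus SD 1.5 checkpoint (ViT-H image features of width 1280, 16 latent queries, cross-attention width 768); the loader in `unet.py` above derives all of these from the checkpoint at load time.

import torch
from diffusers.models.embeddings import Resampler

# Assumed IP-Adapter Plus SD 1.5 dimensions (illustrative, not canonical).
resampler = Resampler(
    embed_dims=1280,   # width of the CLIP vision hidden states fed in
    output_dims=768,   # unet.config.cross_attention_dim for SD 1.5
    hidden_dims=768,
    depth=4,
    dim_head=64,
    heads=12,
    num_queries=16,
)

# 257 = 256 patch tokens + 1 class token for a 224x224 ViT-H/14 input.
image_features = torch.randn(2, 257, 1280)
image_prompt_embeds = resampler(image_features)
print(image_prompt_embeds.shape)  # torch.Size([2, 16, 768])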