[Feature] Support IP-Adapter Plus #5915

Merged
18 commits merged on Dec 4, 2023
72 changes: 52 additions & 20 deletions src/diffusers/loaders/unet.py
@@ -21,7 +21,7 @@
import torch.nn.functional as F
from torch import nn

from ..models.embeddings import ImageProjection
from ..models.embeddings import ImageProjection, Resampler
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
from ..utils import (
DIFFUSERS_CACHE,
@@ -672,6 +672,13 @@ def _load_ip_adapter_weights(self, state_dict):
IPAdapterAttnProcessor2_0,
)

if "proj.weight" in state_dict["image_proj"]:
# IP-Adapter
num_image_text_embeds = 4
else:
# IP-Adapter Plus
num_image_text_embeds = state_dict["image_proj"]["latents"].shape[1]

# set ip-adapter cross-attention processors & load state_dict
attn_procs = {}
key_id = 1
@@ -695,7 +702,10 @@ def _load_ip_adapter_weights(self, state_dict):
IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor
)
attn_procs[name] = attn_processor_class(
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
scale=1.0,
num_tokens=num_image_text_embeds,
).to(dtype=self.dtype, device=self.device)

value_dict = {}
@@ -708,26 +718,48 @@ def _load_ip_adapter_weights(self, state_dict):
self.set_attn_processor(attn_procs)

# create image projection layers.
clip_embeddings_dim = state_dict["image_proj"]["proj.weight"].shape[-1]
cross_attention_dim = state_dict["image_proj"]["proj.weight"].shape[0] // 4
if "proj.weight" in state_dict["image_proj"]:
# IP-Adapter
clip_embeddings_dim = state_dict["image_proj"]["proj.weight"].shape[-1]
cross_attention_dim = state_dict["image_proj"]["proj.weight"].shape[0] // 4

image_projection = ImageProjection(
cross_attention_dim=cross_attention_dim,
image_embed_dim=clip_embeddings_dim,
num_image_text_embeds=num_image_text_embeds,
)
image_projection.to(dtype=self.dtype, device=self.device)

image_projection = ImageProjection(
cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim, num_image_text_embeds=4
)
image_projection.to(dtype=self.dtype, device=self.device)

# load image projection layer weights
image_proj_state_dict = {}
image_proj_state_dict.update(
{
"image_embeds.weight": state_dict["image_proj"]["proj.weight"],
"image_embeds.bias": state_dict["image_proj"]["proj.bias"],
"norm.weight": state_dict["image_proj"]["norm.weight"],
"norm.bias": state_dict["image_proj"]["norm.bias"],
}
)
# load image projection layer weights
image_proj_state_dict = {}
image_proj_state_dict.update(
{
"image_embeds.weight": state_dict["image_proj"]["proj.weight"],
"image_embeds.bias": state_dict["image_proj"]["proj.bias"],
"norm.weight": state_dict["image_proj"]["norm.weight"],
"norm.bias": state_dict["image_proj"]["norm.bias"],
}
)

image_projection.load_state_dict(image_proj_state_dict)

else:
# IP-Adapter Plus
embed_dims = state_dict["image_proj"]["proj_in.weight"].shape[1]
output_dims = state_dict["image_proj"]["proj_out.weight"].shape[0]
hidden_dims = state_dict["image_proj"]["latents"].shape[2]
num_heads = state_dict["image_proj"]["layers.0.0.to_q.weight"].shape[0] // 64

image_projection = Resampler(
embed_dims=embed_dims,
output_dims=output_dims,
hidden_dims=hidden_dims,
num_heads=num_heads,
num_queries=num_image_text_embeds,
)

image_projection.load_state_dict(image_proj_state_dict)
image_proj_state_dict = state_dict["image_proj"]
image_projection.load_state_dict(image_proj_state_dict)

self.encoder_hid_proj = image_projection.to(device=self.device, dtype=self.dtype)
self.config.encoder_hid_dim_type = "ip_image_proj"
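
With these loader changes, pointing load_ip_adapter at a "plus" weight file routes state_dict["image_proj"] through the Resampler branch above instead of ImageProjection. A minimal usage sketch (the repository, subfolder, and weight file name follow the common h94/IP-Adapter layout and are illustrative, not taken from this diff):

import torch
from diffusers import StableDiffusionPipeline
from diffusers.utils import load_image

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
# The plus checkpoint stores "latents" instead of "proj.weight" under image_proj,
# so _load_ip_adapter_weights builds a Resampler rather than an ImageProjection.
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin")

ref = load_image("https://example.com/reference.png")  # any reference image
image = pipe("best quality, high quality", ip_adapter_image=ref, num_inference_steps=30).images[0]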
2 changes: 2 additions & 0 deletions src/diffusers/models/__init__.py
@@ -28,6 +28,7 @@
_import_structure["controlnet"] = ["ControlNetModel"]
_import_structure["dual_transformer_2d"] = ["DualTransformer2DModel"]
_import_structure["modeling_utils"] = ["ModelMixin"]
_import_structure["embeddings"] = ["ImageProjection"]
_import_structure["prior_transformer"] = ["PriorTransformer"]
_import_structure["t5_film_transformer"] = ["T5FilmDecoder"]
_import_structure["transformer_2d"] = ["Transformer2DModel"]
@@ -54,6 +55,7 @@
from .consistency_decoder_vae import ConsistencyDecoderVAE
from .controlnet import ControlNetModel
from .dual_transformer_2d import DualTransformer2DModel
from .embeddings import ImageProjection
from .modeling_utils import ModelMixin
from .prior_transformer import PriorTransformer
from .t5_film_transformer import T5FilmDecoder
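
With this export in place, the projection class used above can be imported directly from the models subpackage, e.g.:

from diffusers.models import ImageProjection  # used by the pipelines below to detect the plain IP-Adapter branch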
152 changes: 152 additions & 0 deletions src/diffusers/models/embeddings.py
@@ -790,3 +790,155 @@ def forward(self, caption, force_drop_ids=None):
hidden_states = self.act_1(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states


class PerceiverAttention(nn.Module):
Collaborator:
should we refactor this using Attention?
cc @patrickvonplaten

Contributor:
Yes, it would be important to use the attention class here IMO to make sure it can be used with torch's scaled dot-product attention.

Contributor Author (@okotaku, Nov 30, 2023):
@patrickvonplaten @yiyixuxu
When I refactored PerceiverAttention, I noticed that there are still some differences between out1 and out2. What are your thoughts on this?

https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/resampler.py#L72

import torch
import math
import torch.nn.functional as F

scale = 1 / math.sqrt(64)
scale2 = 1 / math.sqrt(math.sqrt(64))

query = torch.rand(1, 4, 8, 64, device="cuda", dtype=torch.float16)
key = torch.rand(1, 4, 8, 64, device="cuda", dtype=torch.float16)
value = torch.rand(1, 4, 8, 64, device="cuda", dtype=torch.float16)
out1 = F.scaled_dot_product_attention(query,key,value, scale=scale)

weight = (query * scale2) @ (key * scale2).transpose(-2, -1)
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
out2 = weight @ value
print(torch.allclose(out1, out2, atol=1e-4))

weight2 = query @ key.transpose(-2, -1) * scale
weight2 = torch.softmax(weight2.float(), dim=-1).type(weight2.dtype)
out3 = weight2 @ value
print(torch.allclose(out1, out3, atol=1e-4))
print(torch.allclose(out2, out3, atol=1e-4))

print(torch.abs(out1 - out2).sum())
print(torch.abs(out3 - out2).sum())
print(torch.abs(out1 - out3).sum())

---
False
False
False
tensor(0.1144, device='cuda:0', dtype=torch.float16)
tensor(0.0535, device='cuda:0', dtype=torch.float16)
tensor(0.1212, device='cuda:0', dtype=torch.float16)

Contributor Author:
import torch
import math
import torch.nn.functional as F

scale = 1 / math.sqrt(64)
scale2 = 1 / math.sqrt(math.sqrt(64))

query = torch.rand(1, 4, 8, 64, device="cuda", dtype=torch.float32)
key = torch.rand(1, 4, 8, 64, device="cuda", dtype=torch.float32)
value = torch.rand(1, 4, 8, 64, device="cuda", dtype=torch.float32)
out1 = F.scaled_dot_product_attention(query,key,value, scale=scale)

weight = (query * scale2) @ (key * scale2).transpose(-2, -1)
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
out2 = weight @ value
print(torch.allclose(out1, out2, atol=1e-4))

weight2 = query @ key.transpose(-2, -1) * scale
weight2 = torch.softmax(weight2.float(), dim=-1).type(weight2.dtype)
out3 = weight2 @ value
print(torch.allclose(out1, out3, atol=1e-4))
print(torch.allclose(out2, out3, atol=1e-4))

print(torch.abs(out1 - out2).sum())
print(torch.abs(out3 - out2).sum())
print(torch.abs(out1 - out3).sum())

---
True
True
True
tensor(0.0001, device='cuda:0')
tensor(4.6417e-05, device='cuda:0')
tensor(0.0001, device='cuda:0')

Collaborator:
The numerical difference for float32 is small, no? That means your implementation is most likely correct.
For float16, can we try to see whether the difference is less than 1e-3?
Also, let's generate some outputs with the refactored code; if the results look similar to before, it should be fine!
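
For reference, the requested fp16 tolerance check could look like this (a sketch reusing out1 and out2 from the fp16 snippet above):

print(torch.allclose(out1, out2, atol=1e-3))  # expected to pass if the gap is only fp16 rounding
print(torch.abs(out1 - out2).max())           # worst-case elementwise difference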

Contributor Author:
The output image shows no problems.

[demo image]

"""PerceiverAttention of IP-Adapter Plus.

Args:
----
embed_dims (int): The feature dimension.
head_dims (int): The number of head channels. Defaults to 64.
num_heads (int): Parallel attention heads. Defaults to 16.
"""

def __init__(self, embed_dims: int, head_dims=64, num_heads: int = 16) -> None:
super().__init__()
self.head_dims = head_dims
self.num_heads = num_heads
inner_dim = head_dims * num_heads

self.norm1 = nn.LayerNorm(embed_dims)
self.norm2 = nn.LayerNorm(embed_dims)

self.to_q = nn.Linear(embed_dims, inner_dim, bias=False)
self.to_kv = nn.Linear(embed_dims, inner_dim * 2, bias=False)
self.to_out = nn.Linear(inner_dim, embed_dims, bias=False)

def _reshape_tensor(self, x, heads) -> torch.Tensor:
"""Reshape tensor."""
bs, length, _ = x.shape
# (bs, length, width) --> (bs, length, n_heads, dim_per_head)
x = x.view(bs, length, heads, -1)
# (bs, length, n_heads, dim_per_head) -->
# (bs, n_heads, length, dim_per_head)
x = x.transpose(1, 2)
# (bs, n_heads, length, dim_per_head) -->
# (bs*n_heads, length, dim_per_head)
return x.reshape(bs, heads, length, -1)

def forward(self, x: torch.Tensor, latents: torch.Tensor) -> torch.Tensor:
"""Forward pass.

Args:
----
x (torch.Tensor): image features
shape (b, n1, D)
latents (torch.Tensor): latent features
shape (b, n2, D).
"""
x = self.norm1(x)
latents = self.norm2(latents)

b, len_latents, _ = latents.shape

q = self.to_q(latents)
kv_input = torch.cat((x, latents), dim=-2)
k, v = self.to_kv(kv_input).chunk(2, dim=-1)

q = self._reshape_tensor(q, self.num_heads)
k = self._reshape_tensor(k, self.num_heads)
v = self._reshape_tensor(v, self.num_heads)

# attention
scale = 1 / math.sqrt(math.sqrt(self.head_dims))
# More stable with f16 than dividing afterwards
weight = (q * scale) @ (k * scale).transpose(-2, -1)
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
out = weight @ v

out = out.permute(0, 2, 1, 3).reshape(b, len_latents, -1)

return self.to_out(out)
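
For orientation, a minimal shape check of this block (sizes are illustrative only; assumes torch and the class above are in scope):

attn = PerceiverAttention(embed_dims=1280, head_dims=64, num_heads=16)
x = torch.randn(2, 257, 1280)        # image features (b, n1, D), e.g. 257 CLIP patch tokens
latents = torch.randn(2, 16, 1280)   # learned latent queries (b, n2, D)
print(attn(x, latents).shape)        # torch.Size([2, 16, 1280])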


class Resampler(nn.Module):
    """Resampler of IP-Adapter Plus.

    Args:
    ----
        embed_dims (int): The feature dimension. Defaults to 768.
        output_dims (int): The number of output channels, which must match
            `unet.config.cross_attention_dim`. Defaults to 1024.
        hidden_dims (int): The number of hidden channels. Defaults to 1280.
        depth (int): The number of blocks. Defaults to 4.
        head_dims (int): The number of head channels. Defaults to 64.
        num_heads (int): Parallel attention heads. Defaults to 16.
        num_queries (int): The number of queries. Defaults to 8.
        ffn_ratio (float): The expansion ratio of the feedforward network's hidden
            layer channels. Defaults to 4.
    """

    def __init__(
        self,
        embed_dims: int = 768,
        output_dims: int = 1024,
        hidden_dims: int = 1280,
        depth: int = 4,
        head_dims: int = 64,
        num_heads: int = 16,
        num_queries: int = 8,
        ffn_ratio: float = 4,
    ) -> None:
        super().__init__()

        self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dims) / hidden_dims**0.5)

        self.proj_in = nn.Linear(embed_dims, hidden_dims)

        self.proj_out = nn.Linear(hidden_dims, output_dims)
        self.norm_out = nn.LayerNorm(output_dims)

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        PerceiverAttention(embed_dims=hidden_dims, head_dims=head_dims, num_heads=num_heads),
                        self._get_ffn(embed_dims=hidden_dims, ffn_ratio=ffn_ratio),
                    ]
                )
            )

    def _get_ffn(self, embed_dims: int, ffn_ratio: float = 4) -> nn.Sequential:
        """Get feedforward network."""
        inner_dim = int(embed_dims * ffn_ratio)
        return nn.Sequential(
            nn.LayerNorm(embed_dims),
            nn.Linear(embed_dims, inner_dim, bias=False),
            nn.GELU(),
            nn.Linear(inner_dim, embed_dims, bias=False),
        )
Contributor:
Suggested change:

    def _get_ffn(self, embed_dims, ffn_ratio=4) -> nn.Sequential:
        """Get feedforward network."""
        inner_dim = int(embed_dims * ffn_ratio)
        return nn.Sequential(
            nn.LayerNorm(embed_dims),
            nn.Linear(embed_dims, inner_dim, bias=False),
            nn.GELU(),
            nn.Linear(inner_dim, embed_dims, bias=False),
        )

Can we replace this with a layer norm layer first and then class FeedForward(nn.Module)?
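
One possible shape of that refactor, as a non-authoritative sketch (it assumes diffusers' FeedForward with activation_fn="gelu", whose inner linear layers carry biases, unlike the bias-free linears above):

from diffusers.models.attention import FeedForward

# Hypothetical replacement inside Resampler.__init__, reusing hidden_dims / ffn_ratio from that scope.
ffn = nn.Sequential(
    nn.LayerNorm(hidden_dims),
    FeedForward(hidden_dims, mult=ffn_ratio, activation_fn="gelu"),
)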


    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Args:
        ----
            x (torch.Tensor): Input tensor.

        Returns:
        -------
            torch.Tensor: Output tensor.
        """
        latents = self.latents.repeat(x.size(0), 1, 1)

        x = self.proj_in(x)

        for attn, ff in self.layers:
            latents = attn(x, latents) + latents
            latents = ff(latents) + latents

        latents = self.proj_out(latents)
        return self.norm_out(latents)
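
As a rough usage sketch (dimensions are illustrative, chosen to resemble the SD 1.5 "plus" checkpoint, not read from this diff):

resampler = Resampler(embed_dims=1280, output_dims=768, hidden_dims=768, num_heads=12, num_queries=16)
clip_hidden_states = torch.randn(1, 257, 1280)   # penultimate hidden states of a ViT-H image encoder
print(resampler(clip_hidden_states).shape)       # torch.Size([1, 16, 768]), consumed by the UNet cross-attention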
19 changes: 14 additions & 5 deletions src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py
@@ -22,7 +22,7 @@
from ...configuration_utils import FrozenDict
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
@@ -456,10 +456,19 @@ def encode_image(self, image, device, num_images_per_prompt):
image = self.feature_extractor(image, return_tensors="pt").pixel_values

image = image.to(device=device, dtype=dtype)
image_embeds = self.image_encoder(image).image_embeds
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)

uncond_image_embeds = torch.zeros_like(image_embeds)
if isinstance(self.unet.encoder_hid_proj, ImageProjection):
# IP-Adapter
image_embeds = self.image_encoder(image).image_embeds
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_embeds = torch.zeros_like(image_embeds)
else:
# IP-Adapter Plus
image_embeds = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_embeds = self.image_encoder(torch.zeros_like(image), output_hidden_states=True).hidden_states[
-2
]
uncond_image_embeds = uncond_image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
return image_embeds, uncond_image_embeds
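
For intuition, the two branches produce image embeddings of very different shapes. A small shape check (the checkpoint path is the ViT-H image encoder commonly paired with these adapters; purely illustrative):

import torch
from transformers import CLIPVisionModelWithProjection

encoder = CLIPVisionModelWithProjection.from_pretrained("h94/IP-Adapter", subfolder="models/image_encoder")
pixels = torch.randn(1, 3, 224, 224)
print(encoder(pixels).image_embeds.shape)                                  # (1, 1024): pooled embedding -> ImageProjection branch
print(encoder(pixels, output_hidden_states=True).hidden_states[-2].shape)  # (1, 257, 1280): patch tokens -> Resampler branch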

def run_safety_checker(self, image, device, dtype):
@@ -24,7 +24,7 @@
from ...configuration_utils import FrozenDict
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
@@ -463,10 +463,19 @@ def encode_image(self, image, device, num_images_per_prompt):
image = self.feature_extractor(image, return_tensors="pt").pixel_values

image = image.to(device=device, dtype=dtype)
image_embeds = self.image_encoder(image).image_embeds
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)

uncond_image_embeds = torch.zeros_like(image_embeds)
if isinstance(self.unet.encoder_hid_proj, ImageProjection):
# IP-Adapter
image_embeds = self.image_encoder(image).image_embeds
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_embeds = torch.zeros_like(image_embeds)
else:
# IP-Adapter Plus
image_embeds = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_embeds = self.image_encoder(torch.zeros_like(image), output_hidden_states=True).hidden_states[
-2
]
uncond_image_embeds = uncond_image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
return image_embeds, uncond_image_embeds

def run_safety_checker(self, image, device, dtype):
19 changes: 14 additions & 5 deletions src/diffusers/pipelines/animatediff/pipeline_animatediff.py
@@ -22,7 +22,7 @@

from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel, UNetMotionModel
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...models.unet_motion_model import MotionAdapter
from ...schedulers import (
@@ -327,10 +327,19 @@ def encode_image(self, image, device, num_images_per_prompt):
image = self.feature_extractor(image, return_tensors="pt").pixel_values

image = image.to(device=device, dtype=dtype)
image_embeds = self.image_encoder(image).image_embeds
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)

uncond_image_embeds = torch.zeros_like(image_embeds)
if isinstance(self.unet.encoder_hid_proj, ImageProjection):
# IP-Adapter
image_embeds = self.image_encoder(image).image_embeds
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_embeds = torch.zeros_like(image_embeds)
else:
# IP-Adapter Plus
image_embeds = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_embeds = self.image_encoder(torch.zeros_like(image), output_hidden_states=True).hidden_states[
-2
]
uncond_image_embeds = uncond_image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
return image_embeds, uncond_image_embeds

# Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents