Skip to content

Commit

Permalink
Merge branch 'main' into fix_text_to_image_lora_device
Browse files Browse the repository at this point in the history
  • Loading branch information
MohamadZeina authored Dec 6, 2023
2 parents ab365a2 + f90a513 commit 56f8df2
Show file tree
Hide file tree
Showing 10 changed files with 345 additions and 9 deletions.
3 changes: 3 additions & 0 deletions docs/source/en/api/attnprocessor.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ An attention processor is a class for applying different types of attention mech
## AttnProcessor2_0
[[autodoc]] models.attention_processor.AttnProcessor2_0

## FusedAttnProcessor2_0
[[autodoc]] models.attention_processor.FusedAttnProcessor2_0

## LoRAAttnProcessor
[[autodoc]] models.attention_processor.LoRAAttnProcessor

Expand Down
130 changes: 127 additions & 3 deletions src/diffusers/models/attention_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,14 @@ def __init__(
):
super().__init__()
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
self.query_dim = query_dim
self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
self.upcast_attention = upcast_attention
self.upcast_softmax = upcast_softmax
self.rescale_output_factor = rescale_output_factor
self.residual_connection = residual_connection
self.dropout = dropout
self.fused_projections = False
self.out_dim = out_dim if out_dim is not None else query_dim

# we make use of this private variable to know whether this class is loaded
Expand Down Expand Up @@ -180,6 +182,7 @@ def __init__(
else:
linear_cls = LoRACompatibleLinear

self.linear_cls = linear_cls
self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)

if not self.only_cross_attention:
Expand Down Expand Up @@ -692,6 +695,32 @@ def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> tor

return encoder_hidden_states

@torch.no_grad()
def fuse_projections(self, fuse=True):
is_cross_attention = self.cross_attention_dim != self.query_dim
device = self.to_q.weight.data.device
dtype = self.to_q.weight.data.dtype

if not is_cross_attention:
# fetch weight matrices.
concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
in_features = concatenated_weights.shape[1]
out_features = concatenated_weights.shape[0]

# create a new single projection layer and copy over the weights.
self.to_qkv = self.linear_cls(in_features, out_features, bias=False, device=device, dtype=dtype)
self.to_qkv.weight.copy_(concatenated_weights)

else:
concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
in_features = concatenated_weights.shape[1]
out_features = concatenated_weights.shape[0]

self.to_kv = self.linear_cls(in_features, out_features, bias=False, device=device, dtype=dtype)
self.to_kv.weight.copy_(concatenated_weights)

self.fused_projections = fuse


class AttnProcessor:
r"""
Expand Down Expand Up @@ -1184,9 +1213,6 @@ def __call__(
scale: float = 1.0,
) -> torch.FloatTensor:
residual = hidden_states

args = () if USE_PEFT_BACKEND else (scale,)

if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)

Expand Down Expand Up @@ -1253,6 +1279,103 @@ def __call__(
return hidden_states


class FusedAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
It uses fused projection layers. For self-attention modules, all projection matrices (i.e., query,
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is currently 🧪 experimental in nature and can change in future.
</Tip>
"""

def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"FusedAttnProcessor2_0 requires at least PyTorch 2.0, to use it. Please upgrade PyTorch to > 2.0."
)

def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
scale: float = 1.0,
) -> torch.FloatTensor:
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)

input_ndim = hidden_states.ndim

if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)

if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

args = () if USE_PEFT_BACKEND else (scale,)
if encoder_hidden_states is None:
qkv = attn.to_qkv(hidden_states, *args)
split_size = qkv.shape[-1] // 3
query, key, value = torch.split(qkv, split_size, dim=-1)
else:
if attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
query = attn.to_q(hidden_states, *args)

kv = attn.to_kv(encoder_hidden_states, *args)
split_size = kv.shape[-1] // 2
key, value = torch.split(kv, split_size, dim=-1)

inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads

query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)

hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)

# linear proj
hidden_states = attn.to_out[0](hidden_states, *args)
# dropout
hidden_states = attn.to_out[1](hidden_states)

if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

if attn.residual_connection:
hidden_states = hidden_states + residual

hidden_states = hidden_states / attn.rescale_output_factor

return hidden_states


class CustomDiffusionXFormersAttnProcessor(nn.Module):
r"""
Processor for implementing memory efficient attention using xFormers for the Custom Diffusion method.
Expand Down Expand Up @@ -2251,6 +2374,7 @@ def __call__(
AttentionProcessor = Union[
AttnProcessor,
AttnProcessor2_0,
FusedAttnProcessor2_0,
XFormersAttnProcessor,
SlicedAttnProcessor,
AttnAddedKVProcessor,
Expand Down
39 changes: 39 additions & 0 deletions src/diffusers/models/autoencoder_kl.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from .attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
Attention,
AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
Expand Down Expand Up @@ -448,3 +449,41 @@ def forward(
return (dec,)

return DecoderOutput(sample=dec)

# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
def fuse_qkv_projections(self):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
self.original_attn_processors = None

for _, attn_processor in self.attn_processors.items():
if "Added" in str(attn_processor.__class__.__name__):
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")

self.original_attn_processors = self.attn_processors

for module in self.modules():
if isinstance(module, Attention):
module.fuse_projections(fuse=True)

# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)
37 changes: 37 additions & 0 deletions src/diffusers/models/unet_2d_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from .attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
Attention,
AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
Expand Down Expand Up @@ -794,6 +795,42 @@ def disable_freeu(self):
if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
setattr(upsample_block, k, None)

def fuse_qkv_projections(self):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
self.original_attn_processors = None

for _, attn_processor in self.attn_processors.items():
if "Added" in str(attn_processor.__class__.__name__):
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")

self.original_attn_processors = self.attn_processors

for module in self.modules():
if isinstance(module, Attention):
module.fuse_projections(fuse=True)

def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)

def forward(
self,
sample: torch.FloatTensor,
Expand Down
5 changes: 3 additions & 2 deletions src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,8 +446,9 @@ def convert_ldm_unet_checkpoint(
new_checkpoint["add_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]

# Relevant to StableDiffusionUpscalePipeline
if (config["num_class_embeds"] is not None) and ("label_emb.weight" in unet_state_dict):
new_checkpoint["class_embedding.weight"] = unet_state_dict["label_emb.weight"]
if "num_class_embeds" in config:
if (config["num_class_embeds"] is not None) and ("label_emb.weight" in unet_state_dict):
new_checkpoint["class_embedding.weight"] = unet_state_dict["label_emb.weight"]

new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
from ...models.attention_processor import (
AttnProcessor2_0,
FusedAttnProcessor2_0,
LoRAAttnProcessor2_0,
LoRAXFormersAttnProcessor,
XFormersAttnProcessor,
Expand Down Expand Up @@ -681,7 +682,6 @@ def _get_add_time_ids(
add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
return add_time_ids

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
def upcast_vae(self):
dtype = self.vae.dtype
self.vae.to(dtype=torch.float32)
Expand All @@ -692,6 +692,7 @@ def upcast_vae(self):
XFormersAttnProcessor,
LoRAXFormersAttnProcessor,
LoRAAttnProcessor2_0,
FusedAttnProcessor2_0,
),
)
# if xformers or torch_2_0 is used attention block does not need
Expand Down Expand Up @@ -729,6 +730,65 @@ def disable_freeu(self):
"""Disables the FreeU mechanism if enabled."""
self.unet.disable_freeu()

def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
Args:
unet (`bool`, defaults to `True`): To apply fusion on the UNet.
vae (`bool`, defaults to `True`): To apply fusion on the VAE.
"""
self.fusing_unet = False
self.fusing_vae = False

if unet:
self.fusing_unet = True
self.unet.fuse_qkv_projections()
self.unet.set_attn_processor(FusedAttnProcessor2_0())

if vae:
if not isinstance(self.vae, AutoencoderKL):
raise ValueError("`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.")

self.fusing_vae = True
self.vae.fuse_qkv_projections()
self.vae.set_attn_processor(FusedAttnProcessor2_0())

def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
"""Disable QKV projection fusion if enabled.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
Args:
unet (`bool`, defaults to `True`): To apply fusion on the UNet.
vae (`bool`, defaults to `True`): To apply fusion on the VAE.
"""
if unet:
if not self.fusing_unet:
logger.warning("The UNet was not initially fused for QKV projections. Doing nothing.")
else:
self.unet.unfuse_qkv_projections()
self.fusing_unet = False

if vae:
if not self.fusing_vae:
logger.warning("The VAE was not initially fused for QKV projections. Doing nothing.")
else:
self.vae.unfuse_qkv_projections()
self.fusing_vae = False

# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
"""
Expand Down
Loading

0 comments on commit 56f8df2

Please sign in to comment.