diff --git a/docs/source/en/api/attnprocessor.md b/docs/source/en/api/attnprocessor.md
index f6ee09f124be..0ef49c3e0ec4 100644
--- a/docs/source/en/api/attnprocessor.md
+++ b/docs/source/en/api/attnprocessor.md
@@ -20,6 +20,9 @@ An attention processor is a class for applying different types of attention mech
## AttnProcessor2_0
[[autodoc]] models.attention_processor.AttnProcessor2_0
+## FusedAttnProcessor2_0
+[[autodoc]] models.attention_processor.FusedAttnProcessor2_0
+
## LoRAAttnProcessor
[[autodoc]] models.attention_processor.LoRAAttnProcessor
diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 40a335527ace..23a3e2bb3791 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -113,12 +113,14 @@ def __init__(
):
super().__init__()
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
+ self.query_dim = query_dim
self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
self.upcast_attention = upcast_attention
self.upcast_softmax = upcast_softmax
self.rescale_output_factor = rescale_output_factor
self.residual_connection = residual_connection
self.dropout = dropout
+ self.fused_projections = False
self.out_dim = out_dim if out_dim is not None else query_dim
# we make use of this private variable to know whether this class is loaded
@@ -180,6 +182,7 @@ def __init__(
else:
linear_cls = LoRACompatibleLinear
+ self.linear_cls = linear_cls
self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)
if not self.only_cross_attention:
@@ -692,6 +695,32 @@ def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> tor
return encoder_hidden_states
+ @torch.no_grad()
+ def fuse_projections(self, fuse=True):
+ is_cross_attention = self.cross_attention_dim != self.query_dim
+ device = self.to_q.weight.data.device
+ dtype = self.to_q.weight.data.dtype
+
+ if not is_cross_attention:
+ # fetch weight matrices.
+ concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
+ in_features = concatenated_weights.shape[1]
+ out_features = concatenated_weights.shape[0]
+
+ # create a new single projection layer and copy over the weights.
+ self.to_qkv = self.linear_cls(in_features, out_features, bias=False, device=device, dtype=dtype)
+ self.to_qkv.weight.copy_(concatenated_weights)
+
+ else:
+ concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
+ in_features = concatenated_weights.shape[1]
+ out_features = concatenated_weights.shape[0]
+
+ self.to_kv = self.linear_cls(in_features, out_features, bias=False, device=device, dtype=dtype)
+ self.to_kv.weight.copy_(concatenated_weights)
+
+ self.fused_projections = fuse
+
class AttnProcessor:
r"""
@@ -1184,9 +1213,6 @@ def __call__(
scale: float = 1.0,
) -> torch.FloatTensor:
residual = hidden_states
-
- args = () if USE_PEFT_BACKEND else (scale,)
-
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
@@ -1253,6 +1279,103 @@ def __call__(
return hidden_states
+class FusedAttnProcessor2_0:
+ r"""
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+ It uses fused projection layers. For self-attention modules, all projection matrices (i.e., query,
+ key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+
+ <Tip warning={true}>
+
+ This API is currently 🧪 experimental in nature and can change in the future.
+
+ </Tip>
+ """
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+ "FusedAttnProcessor2_0 requires at least PyTorch 2.0, to use it. Please upgrade PyTorch to > 2.0."
+ )
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ temb: Optional[torch.FloatTensor] = None,
+ scale: float = 1.0,
+ ) -> torch.FloatTensor:
+ residual = hidden_states
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ args = () if USE_PEFT_BACKEND else (scale,)
+ if encoder_hidden_states is None:
+ qkv = attn.to_qkv(hidden_states, *args)
+ split_size = qkv.shape[-1] // 3
+ query, key, value = torch.split(qkv, split_size, dim=-1)
+ else:
+ if attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+ query = attn.to_q(hidden_states, *args)
+
+ kv = attn.to_kv(encoder_hidden_states, *args)
+ split_size = kv.shape[-1] // 2
+ key, value = torch.split(kv, split_size, dim=-1)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states, *args)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+
class CustomDiffusionXFormersAttnProcessor(nn.Module):
r"""
Processor for implementing memory efficient attention using xFormers for the Custom Diffusion method.
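To make the tensor bookkeeping in `FusedAttnProcessor2_0.__call__` concrete, here is a standalone sketch (shapes illustrative) of the fused-output split, the per-head reshape that `F.scaled_dot_product_attention` expects, and the merge back:

```python
import torch
import torch.nn.functional as F

batch, seq, heads, head_dim = 2, 16, 4, 8

# Stand-in for the output of a fused to_qkv projection: (batch, seq, 3 * heads * head_dim).
qkv = torch.randn(batch, seq, 3 * heads * head_dim)
query, key, value = torch.split(qkv, qkv.shape[-1] // 3, dim=-1)

# (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim).
query = query.view(batch, -1, heads, head_dim).transpose(1, 2)
key = key.view(batch, -1, heads, head_dim).transpose(1, 2)
value = value.view(batch, -1, heads, head_dim).transpose(1, 2)

out = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)

# (batch, heads, seq, head_dim) -> (batch, seq, heads * head_dim).
out = out.transpose(1, 2).reshape(batch, -1, heads * head_dim)
assert out.shape == (batch, seq, heads * head_dim)
```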
@@ -2251,6 +2374,7 @@ def __call__(
AttentionProcessor = Union[
AttnProcessor,
AttnProcessor2_0,
+ FusedAttnProcessor2_0,
XFormersAttnProcessor,
SlicedAttnProcessor,
AttnAddedKVProcessor,
diff --git a/src/diffusers/models/autoencoder_kl.py b/src/diffusers/models/autoencoder_kl.py
index 464bff9189dd..8fa3574125f9 100644
--- a/src/diffusers/models/autoencoder_kl.py
+++ b/src/diffusers/models/autoencoder_kl.py
@@ -22,6 +22,7 @@
from .attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
+ Attention,
AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
@@ -448,3 +449,41 @@ def forward(
return (dec,)
return DecoderOutput(sample=dec)
+
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
+ def fuse_qkv_projections(self):
+ """
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
+ key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+
+ <Tip warning={true}>
+
+ This API is 🧪 experimental.
+
+ </Tip>
+ """
+ self.original_attn_processors = None
+
+ for _, attn_processor in self.attn_processors.items():
+ if "Added" in str(attn_processor.__class__.__name__):
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+ self.original_attn_processors = self.attn_processors
+
+ for module in self.modules():
+ if isinstance(module, Attention):
+ module.fuse_projections(fuse=True)
+
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+ def unfuse_qkv_projections(self):
+ """Disables the fused QKV projection if enabled.
+
+ <Tip warning={true}>
+
+ This API is 🧪 experimental.
+
+ </Tip>
+
+ """
+ if self.original_attn_processors is not None:
+ self.set_attn_processor(self.original_attn_processors)
diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py
index dd91d8007229..ddf533d3bd3b 100644
--- a/src/diffusers/models/unet_2d_condition.py
+++ b/src/diffusers/models/unet_2d_condition.py
@@ -25,6 +25,7 @@
from .attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
+ Attention,
AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
@@ -794,6 +795,42 @@ def disable_freeu(self):
if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
setattr(upsample_block, k, None)
+ def fuse_qkv_projections(self):
+ """
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
+ key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+
+ <Tip warning={true}>
+
+ This API is 🧪 experimental.
+
+ </Tip>
+ """
+ self.original_attn_processors = None
+
+ for _, attn_processor in self.attn_processors.items():
+ if "Added" in str(attn_processor.__class__.__name__):
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+ self.original_attn_processors = self.attn_processors
+
+ for module in self.modules():
+ if isinstance(module, Attention):
+ module.fuse_projections(fuse=True)
+
+ def unfuse_qkv_projections(self):
+ """Disables the fused QKV projection if enabled.
+
+ <Tip warning={true}>
+
+ This API is 🧪 experimental.
+
+ </Tip>
+
+ """
+ if self.original_attn_processors is not None:
+ self.set_attn_processor(self.original_attn_processors)
+
def forward(
self,
sample: torch.FloatTensor,
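A usage sketch for the model-level API added above; the checkpoint id is illustrative, and `set_attn_processor` is called explicitly because `fuse_qkv_projections()` only rewrites the projection layers (the pipeline-level helper further below does both steps):

```python
import torch
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import FusedAttnProcessor2_0

# Illustrative checkpoint; any UNet without added KV projections should work.
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet", torch_dtype=torch.float16
)

unet.fuse_qkv_projections()                       # fuse the projection weights
unet.set_attn_processor(FusedAttnProcessor2_0())  # route attention through them

# ... run inference ...

unet.unfuse_qkv_projections()  # restores the processors captured at fuse time
```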
diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
index 6960ba6c4516..761391189f8f 100644
--- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
+++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
@@ -446,8 +446,9 @@ def convert_ldm_unet_checkpoint(
new_checkpoint["add_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
# Relevant to StableDiffusionUpscalePipeline
- if (config["num_class_embeds"] is not None) and ("label_emb.weight" in unet_state_dict):
- new_checkpoint["class_embedding.weight"] = unet_state_dict["label_emb.weight"]
+ if "num_class_embeds" in config:
+ if (config["num_class_embeds"] is not None) and ("label_emb.weight" in unet_state_dict):
+ new_checkpoint["class_embedding.weight"] = unet_state_dict["label_emb.weight"]
new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
index 12d52aa076d4..c8c6247960af 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
@@ -34,6 +34,7 @@
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
from ...models.attention_processor import (
AttnProcessor2_0,
+ FusedAttnProcessor2_0,
LoRAAttnProcessor2_0,
LoRAXFormersAttnProcessor,
XFormersAttnProcessor,
@@ -681,7 +682,6 @@ def _get_add_time_ids(
add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
return add_time_ids
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
def upcast_vae(self):
dtype = self.vae.dtype
self.vae.to(dtype=torch.float32)
@@ -692,6 +692,7 @@ def upcast_vae(self):
XFormersAttnProcessor,
LoRAXFormersAttnProcessor,
LoRAAttnProcessor2_0,
+ FusedAttnProcessor2_0,
),
)
# if xformers or torch_2_0 is used attention block does not need
@@ -729,6 +730,65 @@ def disable_freeu(self):
"""Disables the FreeU mechanism if enabled."""
self.unet.disable_freeu()
+ def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):
+ """
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
+ key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+
+ <Tip warning={true}>
+
+ This API is 🧪 experimental.
+
+ </Tip>
+
+ Args:
+ unet (`bool`, defaults to `True`): Whether to apply fusion on the UNet.
+ vae (`bool`, defaults to `True`): Whether to apply fusion on the VAE.
+ """
+ self.fusing_unet = False
+ self.fusing_vae = False
+
+ if unet:
+ self.fusing_unet = True
+ self.unet.fuse_qkv_projections()
+ self.unet.set_attn_processor(FusedAttnProcessor2_0())
+
+ if vae:
+ if not isinstance(self.vae, AutoencoderKL):
+ raise ValueError("`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.")
+
+ self.fusing_vae = True
+ self.vae.fuse_qkv_projections()
+ self.vae.set_attn_processor(FusedAttnProcessor2_0())
+
+ def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
+ """Disable QKV projection fusion if enabled.
+
+ <Tip warning={true}>
+
+ This API is 🧪 experimental.
+
+ </Tip>
+
+ Args:
+ unet (`bool`, defaults to `True`): Whether to unfuse the UNet's QKV projections.
+ vae (`bool`, defaults to `True`): Whether to unfuse the VAE's QKV projections.
+
+ """
+ if unet:
+ if not self.fusing_unet:
+ logger.warning("The UNet was not initially fused for QKV projections. Doing nothing.")
+ else:
+ self.unet.unfuse_qkv_projections()
+ self.fusing_unet = False
+
+ if vae:
+ if not self.fusing_vae:
+ logger.warning("The VAE was not initially fused for QKV projections. Doing nothing.")
+ else:
+ self.vae.unfuse_qkv_projections()
+ self.fusing_vae = False
+
# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
"""
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py
index b14c746f9962..644948ddc0d3 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py
@@ -24,6 +24,7 @@
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.attention_processor import (
AttnProcessor2_0,
+ FusedAttnProcessor2_0,
LoRAAttnProcessor2_0,
LoRAXFormersAttnProcessor,
XFormersAttnProcessor,
@@ -610,6 +611,7 @@ def upcast_vae(self):
XFormersAttnProcessor,
LoRAXFormersAttnProcessor,
LoRAAttnProcessor2_0,
+ FusedAttnProcessor2_0,
),
)
# if xformers or torch_2_0 is used attention block does not need
diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
index 0a2f1ca17cb0..8ac63636df86 100644
--- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
+++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
@@ -10,10 +10,10 @@
from ...configuration_utils import ConfigMixin, register_to_config
from ...models import ModelMixin
from ...models.activations import get_activation
-from ...models.attention import Attention
from ...models.attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
+ Attention,
AttentionProcessor,
AttnAddedKVProcessor,
AttnAddedKVProcessor2_0,
@@ -1000,6 +1000,42 @@ def disable_freeu(self):
if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
setattr(upsample_block, k, None)
+ def fuse_qkv_projections(self):
+ """
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
+ key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+
+ <Tip warning={true}>
+
+ This API is 🧪 experimental.
+
+ </Tip>
+ """
+ self.original_attn_processors = None
+
+ for _, attn_processor in self.attn_processors.items():
+ if "Added" in str(attn_processor.__class__.__name__):
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+ self.original_attn_processors = self.attn_processors
+
+ for module in self.modules():
+ if isinstance(module, Attention):
+ module.fuse_projections(fuse=True)
+
+ def unfuse_qkv_projections(self):
+ """Disables the fused QKV projection if enabled.
+
+ <Tip warning={true}>
+
+ This API is 🧪 experimental.
+
+ </Tip>
+
+ """
+ if self.original_attn_processors is not None:
+ self.set_attn_processor(self.original_attn_processors)
+
def forward(
self,
sample: torch.FloatTensor,
diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py
index 53dc2ae15432..cef2c4113a48 100644
--- a/src/diffusers/schedulers/scheduling_euler_discrete.py
+++ b/src/diffusers/schedulers/scheduling_euler_discrete.py
@@ -191,10 +191,11 @@ def __init__(
@property
def init_noise_sigma(self):
# standard deviation of the initial noise distribution
+ max_sigma = max(self.sigmas) if isinstance(self.sigmas, list) else self.sigmas.max()
if self.config.timestep_spacing in ["linspace", "trailing"]:
- return self.sigmas.max()
+ return max_sigma
- return (self.sigmas.max() ** 2 + 1) ** 0.5
+ return (max_sigma**2 + 1) ** 0.5
@property
def step_index(self):
@@ -289,6 +290,8 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
self.timesteps = torch.from_numpy(timesteps.astype(np.float32)).to(device=device)
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
+ if sigmas.device.type == "cuda":
+ self.sigmas = self.sigmas.tolist()
self._step_index = None
def _sigma_to_t(self, sigma, log_sigmas):
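The scheduler change above appears to be a host/device optimization: once `self.sigmas` is a Python list, per-step lookups such as `self.sigmas[self.step_index]` read host floats rather than triggering a device sync on every access. A rough sketch of the idea (values illustrative, and the rationale is inferred rather than stated in the diff):

```python
import torch

sigmas = torch.linspace(14.6, 0.03, 50)  # imagine this tensor lives on a CUDA device
sigmas_list = sigmas.tolist()            # one transfer up front, host floats afterwards

sigma = sigmas_list[10]                  # plain float indexing, no device round-trip
init_noise_sigma = (max(sigmas_list) ** 2 + 1) ** 0.5  # mirrors init_noise_sigma above
print(sigma, init_noise_sigma)
```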
diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
index 59f0c0151d3a..280030d94b7c 100644
--- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
+++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
@@ -938,6 +938,37 @@ def test_stable_diffusion_xl_save_from_pretrained(self):
assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
+ def test_stable_diffusion_xl_with_fused_qkv_projections(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ components = self.get_dummy_components()
+ sd_pipe = StableDiffusionXLPipeline(**components)
+ sd_pipe = sd_pipe.to(device)
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image = sd_pipe(**inputs).images
+ original_image_slice = image[0, -3:, -3:, -1]
+
+ sd_pipe.fuse_qkv_projections()
+ inputs = self.get_dummy_inputs(device)
+ image = sd_pipe(**inputs).images
+ image_slice_fused = image[0, -3:, -3:, -1]
+
+ sd_pipe.unfuse_qkv_projections()
+ inputs = self.get_dummy_inputs(device)
+ image = sd_pipe(**inputs).images
+ image_slice_disabled = image[0, -3:, -3:, -1]
+
+ assert np.allclose(
+ original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2
+ ), "Fusion of QKV projections shouldn't affect the outputs."
+ assert np.allclose(
+ image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2
+ ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ assert np.allclose(
+ original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
+ ), "Original outputs should match when fused QKV projections are disabled."
+
@slow
class StableDiffusionXLPipelineIntegrationTests(unittest.TestCase):