From a2bc2e14b93b584bea6ba17bdaf7486c6ab808e2 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 6 Dec 2023 07:33:26 +0530 Subject: [PATCH 1/2] [feat] allow SDXL pipeline to run with fused QKV projections (#6030) * debug * from step * print * turn sigma a list * make str * init_noise_sigma * comment * remove prints * feat: introduce fused projections * change to a better name * no grad * device. * device * dtype * okay * print * more print * fix: unbind -> split * fix: qkv >-> k * enable disable * apply attention processor within the method * attn processors * _enable_fused_qkv_projections * remove print * add fused projection to vae * add todos. * add: documentation and cleanups. * add: test for qkv projection fusion. * relax assertions. * relax further * fix: docs * fix-copies * correct error message. * Empty-Commit * better conditioning on disable_fused_qkv_projections * check * check processor * bfloat16 computation. * check latent dtype * style * remove copy temporarily * cast latent to bfloat16 * fix: vae -> self.vae * remove print. * add _change_to_group_norm_32 * comment out stuff that didn't work * Apply suggestions from code review Co-authored-by: Patrick von Platen * reflect patrick's suggestions. * fix imports * fix: disable call. * fix more * fix device and dtype * fix conditions. * fix more * Apply suggestions from code review Co-authored-by: Patrick von Platen --------- Co-authored-by: Patrick von Platen --- docs/source/en/api/attnprocessor.md | 3 + src/diffusers/models/attention_processor.py | 130 +++++++++++++++++- src/diffusers/models/autoencoder_kl.py | 39 ++++++ src/diffusers/models/unet_2d_condition.py | 37 +++++ .../pipeline_stable_diffusion_xl.py | 62 ++++++++- ...ne_stable_diffusion_xl_instruct_pix2pix.py | 2 + .../versatile_diffusion/modeling_text_unet.py | 38 ++++- .../schedulers/scheduling_euler_discrete.py | 7 +- .../test_stable_diffusion_xl.py | 31 +++++ 9 files changed, 342 insertions(+), 7 deletions(-) diff --git a/docs/source/en/api/attnprocessor.md b/docs/source/en/api/attnprocessor.md index f6ee09f124be..0ef49c3e0ec4 100644 --- a/docs/source/en/api/attnprocessor.md +++ b/docs/source/en/api/attnprocessor.md @@ -20,6 +20,9 @@ An attention processor is a class for applying different types of attention mech ## AttnProcessor2_0 [[autodoc]] models.attention_processor.AttnProcessor2_0 +## FusedAttnProcessor2_0 +[[autodoc]] models.attention_processor.FusedAttnProcessor2_0 + ## LoRAAttnProcessor [[autodoc]] models.attention_processor.LoRAAttnProcessor diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 40a335527ace..23a3e2bb3791 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -113,12 +113,14 @@ def __init__( ): super().__init__() self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.query_dim = query_dim self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim self.upcast_attention = upcast_attention self.upcast_softmax = upcast_softmax self.rescale_output_factor = rescale_output_factor self.residual_connection = residual_connection self.dropout = dropout + self.fused_projections = False self.out_dim = out_dim if out_dim is not None else query_dim # we make use of this private variable to know whether this class is loaded @@ -180,6 +182,7 @@ def __init__( else: linear_cls = LoRACompatibleLinear + self.linear_cls = linear_cls self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias) 
if not self.only_cross_attention: @@ -692,6 +695,32 @@ def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> tor return encoder_hidden_states + @torch.no_grad() + def fuse_projections(self, fuse=True): + is_cross_attention = self.cross_attention_dim != self.query_dim + device = self.to_q.weight.data.device + dtype = self.to_q.weight.data.dtype + + if not is_cross_attention: + # fetch weight matrices. + concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data]) + in_features = concatenated_weights.shape[1] + out_features = concatenated_weights.shape[0] + + # create a new single projection layer and copy over the weights. + self.to_qkv = self.linear_cls(in_features, out_features, bias=False, device=device, dtype=dtype) + self.to_qkv.weight.copy_(concatenated_weights) + + else: + concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data]) + in_features = concatenated_weights.shape[1] + out_features = concatenated_weights.shape[0] + + self.to_kv = self.linear_cls(in_features, out_features, bias=False, device=device, dtype=dtype) + self.to_kv.weight.copy_(concatenated_weights) + + self.fused_projections = fuse + class AttnProcessor: r""" @@ -1184,9 +1213,6 @@ def __call__( scale: float = 1.0, ) -> torch.FloatTensor: residual = hidden_states - - args = () if USE_PEFT_BACKEND else (scale,) - if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) @@ -1253,6 +1279,103 @@ def __call__( return hidden_states +class FusedAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + It uses fused projection layers. For self-attention modules, all projection matrices (i.e., query, + key, value) are fused. For cross-attention modules, key and value projection matrices are fused. + + + + This API is currently 🧪 experimental in nature and can change in future. + + + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "FusedAttnProcessor2_0 requires at least PyTorch 2.0, to use it. Please upgrade PyTorch to > 2.0." 
+ ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + temb: Optional[torch.FloatTensor] = None, + scale: float = 1.0, + ) -> torch.FloatTensor: + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + args = () if USE_PEFT_BACKEND else (scale,) + if encoder_hidden_states is None: + qkv = attn.to_qkv(hidden_states, *args) + split_size = qkv.shape[-1] // 3 + query, key, value = torch.split(qkv, split_size, dim=-1) + else: + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + query = attn.to_q(hidden_states, *args) + + kv = attn.to_kv(encoder_hidden_states, *args) + split_size = kv.shape[-1] // 2 + key, value = torch.split(kv, split_size, dim=-1) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states, *args) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + class CustomDiffusionXFormersAttnProcessor(nn.Module): r""" Processor for implementing memory efficient attention using xFormers for the Custom Diffusion method. 
@@ -2251,6 +2374,7 @@ def __call__( AttentionProcessor = Union[ AttnProcessor, AttnProcessor2_0, + FusedAttnProcessor2_0, XFormersAttnProcessor, SlicedAttnProcessor, AttnAddedKVProcessor, diff --git a/src/diffusers/models/autoencoder_kl.py b/src/diffusers/models/autoencoder_kl.py index 464bff9189dd..8fa3574125f9 100644 --- a/src/diffusers/models/autoencoder_kl.py +++ b/src/diffusers/models/autoencoder_kl.py @@ -22,6 +22,7 @@ from .attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, CROSS_ATTENTION_PROCESSORS, + Attention, AttentionProcessor, AttnAddedKVProcessor, AttnProcessor, @@ -448,3 +449,41 @@ def forward( return (dec,) return DecoderOutput(sample=dec) + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, + key, value) are fused. For cross-attention modules, key and value projection matrices are fused. + + + + This API is 🧪 experimental. + + + """ + self.original_attn_processors = None + + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") + + self.original_attn_processors = self.attn_processors + + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + + + This API is 🧪 experimental. + + + + """ + if self.original_attn_processors is not None: + self.set_attn_processor(self.original_attn_processors) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index dd91d8007229..ddf533d3bd3b 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -25,6 +25,7 @@ from .attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, CROSS_ATTENTION_PROCESSORS, + Attention, AttentionProcessor, AttnAddedKVProcessor, AttnProcessor, @@ -794,6 +795,42 @@ def disable_freeu(self): if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None: setattr(upsample_block, k, None) + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, + key, value) are fused. For cross-attention modules, key and value projection matrices are fused. + + + + This API is 🧪 experimental. + + + """ + self.original_attn_processors = None + + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") + + self.original_attn_processors = self.attn_processors + + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) + + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + + + This API is 🧪 experimental. 
+ + + + """ + if self.original_attn_processors is not None: + self.set_attn_processor(self.original_attn_processors) + def forward( self, sample: torch.FloatTensor, diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 12d52aa076d4..c8c6247960af 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -34,6 +34,7 @@ from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...models.attention_processor import ( AttnProcessor2_0, + FusedAttnProcessor2_0, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor, XFormersAttnProcessor, @@ -681,7 +682,6 @@ def _get_add_time_ids( add_time_ids = torch.tensor([add_time_ids], dtype=dtype) return add_time_ids - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae def upcast_vae(self): dtype = self.vae.dtype self.vae.to(dtype=torch.float32) @@ -692,6 +692,7 @@ def upcast_vae(self): XFormersAttnProcessor, LoRAXFormersAttnProcessor, LoRAAttnProcessor2_0, + FusedAttnProcessor2_0, ), ) # if xformers or torch_2_0 is used attention block does not need @@ -729,6 +730,65 @@ def disable_freeu(self): """Disables the FreeU mechanism if enabled.""" self.unet.disable_freeu() + def fuse_qkv_projections(self, unet: bool = True, vae: bool = True): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, + key, value) are fused. For cross-attention modules, key and value projection matrices are fused. + + + + This API is 🧪 experimental. + + + + Args: + unet (`bool`, defaults to `True`): To apply fusion on the UNet. + vae (`bool`, defaults to `True`): To apply fusion on the VAE. + """ + self.fusing_unet = False + self.fusing_vae = False + + if unet: + self.fusing_unet = True + self.unet.fuse_qkv_projections() + self.unet.set_attn_processor(FusedAttnProcessor2_0()) + + if vae: + if not isinstance(self.vae, AutoencoderKL): + raise ValueError("`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.") + + self.fusing_vae = True + self.vae.fuse_qkv_projections() + self.vae.set_attn_processor(FusedAttnProcessor2_0()) + + def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True): + """Disable QKV projection fusion if enabled. + + + + This API is 🧪 experimental. + + + + Args: + unet (`bool`, defaults to `True`): To apply fusion on the UNet. + vae (`bool`, defaults to `True`): To apply fusion on the VAE. + + """ + if unet: + if not self.fusing_unet: + logger.warning("The UNet was not initially fused for QKV projections. Doing nothing.") + else: + self.unet.unfuse_qkv_projections() + self.fusing_unet = False + + if vae: + if not self.fusing_vae: + logger.warning("The VAE was not initially fused for QKV projections. 
Doing nothing.") + else: + self.vae.unfuse_qkv_projections() + self.fusing_vae = False + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32): """ diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py index b14c746f9962..644948ddc0d3 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py @@ -24,6 +24,7 @@ from ...models import AutoencoderKL, UNet2DConditionModel from ...models.attention_processor import ( AttnProcessor2_0, + FusedAttnProcessor2_0, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor, XFormersAttnProcessor, @@ -610,6 +611,7 @@ def upcast_vae(self): XFormersAttnProcessor, LoRAXFormersAttnProcessor, LoRAAttnProcessor2_0, + FusedAttnProcessor2_0, ), ) # if xformers or torch_2_0 is used attention block does not need diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index 0a2f1ca17cb0..8ac63636df86 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -10,10 +10,10 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...models import ModelMixin from ...models.activations import get_activation -from ...models.attention import Attention from ...models.attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, CROSS_ATTENTION_PROCESSORS, + Attention, AttentionProcessor, AttnAddedKVProcessor, AttnAddedKVProcessor2_0, @@ -1000,6 +1000,42 @@ def disable_freeu(self): if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None: setattr(upsample_block, k, None) + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, + key, value) are fused. For cross-attention modules, key and value projection matrices are fused. + + + + This API is 🧪 experimental. + + + """ + self.original_attn_processors = None + + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") + + self.original_attn_processors = self.attn_processors + + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) + + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + + + This API is 🧪 experimental. 
+ + + + """ + if self.original_attn_processors is not None: + self.set_attn_processor(self.original_attn_processors) + def forward( self, sample: torch.FloatTensor, diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index 53dc2ae15432..cef2c4113a48 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -191,10 +191,11 @@ def __init__( @property def init_noise_sigma(self): # standard deviation of the initial noise distribution + max_sigma = max(self.sigmas) if isinstance(self.sigmas, list) else self.sigmas.max() if self.config.timestep_spacing in ["linspace", "trailing"]: - return self.sigmas.max() + return max_sigma - return (self.sigmas.max() ** 2 + 1) ** 0.5 + return (max_sigma**2 + 1) ** 0.5 @property def step_index(self): @@ -289,6 +290,8 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic self.timesteps = torch.from_numpy(timesteps.astype(np.float32)).to(device=device) self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) + if sigmas.device.type == "cuda": + self.sigmas = self.sigmas.tolist() self._step_index = None def _sigma_to_t(self, sigma, log_sigmas): diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py index 59f0c0151d3a..280030d94b7c 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py @@ -938,6 +938,37 @@ def test_stable_diffusion_xl_save_from_pretrained(self): assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3 + def test_stable_diffusion_xl_with_fused_qkv_projections(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + sd_pipe = StableDiffusionXLPipeline(**components) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = sd_pipe(**inputs).images + original_image_slice = image[0, -3:, -3:, -1] + + sd_pipe.fuse_qkv_projections() + inputs = self.get_dummy_inputs(device) + image = sd_pipe(**inputs).images + image_slice_fused = image[0, -3:, -3:, -1] + + sd_pipe.unfuse_qkv_projections() + inputs = self.get_dummy_inputs(device) + image = sd_pipe(**inputs).images + image_slice_disabled = image[0, -3:, -3:, -1] + + assert np.allclose( + original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2 + ), "Fusion of QKV projections shouldn't affect the outputs." + assert np.allclose( + image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2 + ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled." + assert np.allclose( + original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2 + ), "Original outputs should match when fused QKV projections are disabled." 
+ @slow class StableDiffusionXLPipelineIntegrationTests(unittest.TestCase): From f90a5139a244d440b9a85b419a52cc0155f32f98 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Wed, 6 Dec 2023 06:03:58 +0000 Subject: [PATCH 2/2] fix --- .../pipelines/stable_diffusion/convert_from_ckpt.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py index 6960ba6c4516..761391189f8f 100644 --- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py @@ -446,8 +446,9 @@ def convert_ldm_unet_checkpoint( new_checkpoint["add_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] # Relevant to StableDiffusionUpscalePipeline - if (config["num_class_embeds"] is not None) and ("label_emb.weight" in unet_state_dict): - new_checkpoint["class_embedding.weight"] = unet_state_dict["label_emb.weight"] + if "num_class_embeds" in config: + if (config["num_class_embeds"] is not None) and ("label_emb.weight" in unet_state_dict): + new_checkpoint["class_embedding.weight"] = unet_state_dict["label_emb.weight"] new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
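The patch exposes the new fused-projection API at three levels: `Attention.fuse_projections()` concatenates the query/key/value weight matrices (key/value only for cross-attention) into a single linear layer, `UNet2DConditionModel.fuse_qkv_projections()` and `AutoencoderKL.fuse_qkv_projections()` apply that fusion to every `Attention` module in the model, and `StableDiffusionXLPipeline.fuse_qkv_projections()` additionally switches the attention processors to `FusedAttnProcessor2_0`. A minimal usage sketch follows, assuming a diffusers build that contains this patch; the checkpoint id, dtype, device, and prompt are illustrative and not mandated by the patch:

    import torch
    from diffusers import StableDiffusionXLPipeline

    # Illustrative checkpoint and dtype; any SDXL checkpoint loadable by
    # StableDiffusionXLPipeline should behave the same way.
    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
    ).to("cuda")

    # Fuse self-attention Q/K/V (and cross-attention K/V) projections in both the
    # UNet and the VAE, and switch their processors to FusedAttnProcessor2_0.
    pipe.fuse_qkv_projections()  # equivalent to fuse_qkv_projections(unet=True, vae=True)

    image = pipe("an astronaut riding a horse on the moon").images[0]

    # Restore the original, unfused projections and attention processors.
    pipe.unfuse_qkv_projections()

Because the fused layer's weight is simply the concatenation of the original projection weights along the output dimension, one larger matmul followed by `torch.split` into thirds (or halves for cross-attention) reproduces the separate projections up to floating-point differences, which is why the new pipeline test only checks that the original, fused, and unfused outputs agree within atol=1e-2 / rtol=1e-2.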