diff --git a/docs/source/en/api/attnprocessor.md b/docs/source/en/api/attnprocessor.md index f6ee09f124be..0ef49c3e0ec4 100644 --- a/docs/source/en/api/attnprocessor.md +++ b/docs/source/en/api/attnprocessor.md @@ -20,6 +20,9 @@ An attention processor is a class for applying different types of attention mech ## AttnProcessor2_0 [[autodoc]] models.attention_processor.AttnProcessor2_0 +## FusedAttnProcessor2_0 +[[autodoc]] models.attention_processor.FusedAttnProcessor2_0 + ## LoRAAttnProcessor [[autodoc]] models.attention_processor.LoRAAttnProcessor diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py index 2bf3cc8f7c9c..df5477d0d643 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py @@ -123,16 +123,26 @@ def save_model_card( """ trigger_str = f"You should use {instance_prompt} to trigger the image generation." + diffusers_imports_pivotal = "" + diffusers_example_pivotal = "" if train_text_encoder_ti: trigger_str = ( "To trigger image generation of trained concept(or concepts) replace each concept identifier " "in you prompt with the new inserted tokens:\n" ) + diffusers_imports_pivotal = """from huggingface_hub import hf_hub_download +from safetensors.torch import load_file + """ + diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id="{repo_id}", filename="embeddings.safetensors", repo_type="model") +state_dict = load_file(embedding_path) +pipeline.load_textual_inversion(state_dict["clip_l"], token=["", ""], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer) +pipeline.load_textual_inversion(state_dict["clip_g"], token=["", ""], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2) + """ if token_abstraction_dict: for key, value in token_abstraction_dict.items(): tokens = "".join(value) trigger_str += f""" -to trigger concept `{key}->` use `{tokens}` in your prompt \n +to trigger concept `{key}` → use `{tokens}` in your prompt \n """ yaml = f""" @@ -172,7 +182,21 @@ def save_model_card( {trigger_str} -## Download model +## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers) + +```py +from diffusers import AutoPipelineForText2Image +import torch +{diffusers_imports_pivotal} +pipeline = AutoPipelineForText2Image.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', torch_dtype=torch.float16).to('cuda') +pipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors') +{diffusers_example_pivotal} +image = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0] +``` + +For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) + +## Download model (use it with UIs such as AUTO1111, Comfy, SD.Next, Invoke) Weights for this model are available in Safetensors format. 
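For reference, a minimal sketch of what the updated model-card template expands to at inference time when pivotal tuning is used. The repo id `your-username/your-sdxl-lora` and the inserted tokens `<s0>`/`<s1>` are placeholders, and the pipeline handle is named consistently here (the template itself mixes `pipe` and `pipeline`):

```py
import torch
from diffusers import AutoPipelineForText2Image
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# Load the base SDXL pipeline and the trained LoRA weights.
pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights("your-username/your-sdxl-lora", weight_name="pytorch_lora_weights.safetensors")

# Load the textual-inversion embeddings the training script saves as embeddings.safetensors.
embedding_path = hf_hub_download(
    repo_id="your-username/your-sdxl-lora", filename="embeddings.safetensors", repo_type="model"
)
state_dict = load_file(embedding_path)
# "<s0>"/"<s1>" stand in for the new tokens inserted during pivotal tuning.
pipeline.load_textual_inversion(
    state_dict["clip_l"], token=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer
)
pipeline.load_textual_inversion(
    state_dict["clip_g"], token=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2
)

image = pipeline("a photo of <s0> <s1>").images[0]
```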
@@ -791,6 +815,12 @@ def __init__( instance_data_root, instance_prompt, class_prompt, + dataset_name, + dataset_config_name, + cache_dir, + image_column, + caption_column, + train_text_encoder_ti, class_data_root=None, class_num=None, token_abstraction_dict=None, # token mapping for textual inversion @@ -805,10 +835,10 @@ def __init__( self.custom_instance_prompts = None self.class_prompt = class_prompt self.token_abstraction_dict = token_abstraction_dict - + self.train_text_encoder_ti = train_text_encoder_ti # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory, # we load the training data using load_dataset - if args.dataset_name is not None: + if dataset_name is not None: try: from datasets import load_dataset except ImportError: @@ -821,26 +851,25 @@ def __init__( # See more about loading custom images at # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script dataset = load_dataset( - args.dataset_name, - args.dataset_config_name, - cache_dir=args.cache_dir, + dataset_name, + dataset_config_name, + cache_dir=cache_dir, ) # Preprocessing the datasets. column_names = dataset["train"].column_names # 6. Get the column names for input/target. - if args.image_column is None: + if image_column is None: image_column = column_names[0] logger.info(f"image column defaulting to {image_column}") else: - image_column = args.image_column if image_column not in column_names: raise ValueError( - f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + f"`--image_column` value '{image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" ) instance_images = dataset["train"][image_column] - if args.caption_column is None: + if caption_column is None: logger.info( "No caption column provided, defaulting to instance_prompt for all images. If your dataset " "contains captions/prompts for the images, make sure to specify the " @@ -848,11 +877,11 @@ def __init__( ) self.custom_instance_prompts = None else: - if args.caption_column not in column_names: + if caption_column not in column_names: raise ValueError( - f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + f"`--caption_column` value '{caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" ) - custom_instance_prompts = dataset["train"][args.caption_column] + custom_instance_prompts = dataset["train"][caption_column] # create final list of captions according to --repeats self.custom_instance_prompts = [] for caption in custom_instance_prompts: @@ -907,7 +936,7 @@ def __getitem__(self, index): if self.custom_instance_prompts: caption = self.custom_instance_prompts[index % self.num_instance_images] if caption: - if args.train_text_encoder_ti: + if self.train_text_encoder_ti: # replace instances of --token_abstraction in caption with the new tokens: "" etc. 
for token_abs, token_replacement in self.token_abstraction_dict.items(): caption = caption.replace(token_abs, "".join(token_replacement)) @@ -1093,10 +1122,10 @@ def main(args): if args.output_dir is not None: os.makedirs(args.output_dir, exist_ok=True) + model_id = args.hub_model_id or Path(args.output_dir).name + repo_id = None if args.push_to_hub: - repo_id = create_repo( - repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token - ).repo_id + repo_id = create_repo(repo_id=model_id, exist_ok=True, token=args.hub_token).repo_id # Load the tokenizers tokenizer_one = AutoTokenizer.from_pretrained( @@ -1464,6 +1493,12 @@ def load_model_hook(models, input_dir): instance_data_root=args.instance_data_dir, instance_prompt=args.instance_prompt, class_prompt=args.class_prompt, + dataset_name=args.dataset_name, + dataset_config_name=args.dataset_config_name, + cache_dir=args.cache_dir, + image_column=args.image_column, + train_text_encoder_ti=args.train_text_encoder_ti, + caption_column=args.caption_column, class_data_root=args.class_data_dir if args.with_prior_preservation else None, token_abstraction_dict=token_abstraction_dict if args.train_text_encoder_ti else None, class_num=args.num_class_images, @@ -2004,23 +2039,23 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): } ) - if args.push_to_hub: - if args.train_text_encoder_ti: - embedding_handler.save_embeddings( - f"{args.output_dir}/embeddings.safetensors", - ) - save_model_card( - repo_id, - images=images, - base_model=args.pretrained_model_name_or_path, - train_text_encoder=args.train_text_encoder, - train_text_encoder_ti=args.train_text_encoder_ti, - token_abstraction_dict=train_dataset.token_abstraction_dict, - instance_prompt=args.instance_prompt, - validation_prompt=args.validation_prompt, - repo_folder=args.output_dir, - vae_path=args.pretrained_vae_model_name_or_path, + if args.train_text_encoder_ti: + embedding_handler.save_embeddings( + f"{args.output_dir}/embeddings.safetensors", ) + save_model_card( + model_id if not args.push_to_hub else repo_id, + images=images, + base_model=args.pretrained_model_name_or_path, + train_text_encoder=args.train_text_encoder, + train_text_encoder_ti=args.train_text_encoder_ti, + token_abstraction_dict=train_dataset.token_abstraction_dict, + instance_prompt=args.instance_prompt, + validation_prompt=args.validation_prompt, + repo_folder=args.output_dir, + vae_path=args.pretrained_vae_model_name_or_path, + ) + if args.push_to_hub: upload_folder( repo_id=repo_id, folder_path=args.output_dir, diff --git a/examples/community/README.md b/examples/community/README.md index 9fad6ecbf690..98780edeccf7 100755 --- a/examples/community/README.md +++ b/examples/community/README.md @@ -2870,10 +2870,14 @@ The original repo can be found at [repo](https://github.com/PRIS-CV/DemoFusion). - `show_image` (`bool`, defaults to False): Determine whether to show intermediate results during generation. 
``` -from pipeline_demofusion_sdxl import DemoFusionSDXLPipeline +from diffusers import DiffusionPipeline -model_ckpt = "stabilityai/stable-diffusion-xl-base-1.0" -pipe = DemoFusionSDXLPipeline.from_pretrained(model_ckpt, torch_dtype=torch.float16) +pipe = DiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", + custom_pipeline="pipeline_demofusion_sdxl", + custom_revision="main", + torch_dtype=torch.float16, +) pipe = pipe.to("cuda") prompt = "Envision a portrait of an elderly woman, her face a canvas of time, framed by a headscarf with muted tones of rust and cream. Her eyes, blue like faded denim. Her attire, simple yet dignified." diff --git a/examples/community/pipeline_demofusion_sdxl.py b/examples/community/pipeline_demofusion_sdxl.py index 5a81320219a5..98508b7ff89c 100644 --- a/examples/community/pipeline_demofusion_sdxl.py +++ b/examples/community/pipeline_demofusion_sdxl.py @@ -36,7 +36,9 @@ if is_invisible_watermark_available(): - from .watermark import StableDiffusionXLWatermarker + from diffusers.pipelines.stable_diffusion_xl.watermark import ( + StableDiffusionXLWatermarker, + ) logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/setup.py b/setup.py index bf0d40f77abe..ddb2afd64c51 100644 --- a/setup.py +++ b/setup.py @@ -118,7 +118,7 @@ "pytest-timeout", "pytest-xdist", "python>=3.8.0", - "ruff>=0.1.5,<=0.2", + "ruff==0.1.5", "safetensors>=0.3.1", "sentencepiece>=0.1.91,!=0.1.92", "GitPython<3.1.19", diff --git a/src/diffusers/dependency_versions_table.py b/src/diffusers/dependency_versions_table.py index ab9212458895..7891984b0c5d 100644 --- a/src/diffusers/dependency_versions_table.py +++ b/src/diffusers/dependency_versions_table.py @@ -30,7 +30,7 @@ "pytest-timeout": "pytest-timeout", "pytest-xdist": "pytest-xdist", "python": "python>=3.8.0", - "ruff": "ruff>=0.1.5,<=0.2", + "ruff": "ruff==0.1.5", "safetensors": "safetensors>=0.3.1", "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92", "GitPython": "GitPython<3.1.19", diff --git a/src/diffusers/loaders/single_file.py b/src/diffusers/loaders/single_file.py index 8c63c4cf59a5..bf100e7f2c81 100644 --- a/src/diffusers/loaders/single_file.py +++ b/src/diffusers/loaders/single_file.py @@ -282,7 +282,7 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs): ) if torch_dtype is not None: - pipe.to(torch_dtype=torch_dtype) + pipe.to(dtype=torch_dtype) return pipe diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 40a335527ace..23a3e2bb3791 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -113,12 +113,14 @@ def __init__( ): super().__init__() self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.query_dim = query_dim self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim self.upcast_attention = upcast_attention self.upcast_softmax = upcast_softmax self.rescale_output_factor = rescale_output_factor self.residual_connection = residual_connection self.dropout = dropout + self.fused_projections = False self.out_dim = out_dim if out_dim is not None else query_dim # we make use of this private variable to know whether this class is loaded @@ -180,6 +182,7 @@ def __init__( else: linear_cls = LoRACompatibleLinear + self.linear_cls = linear_cls self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias) if not self.only_cross_attention: @@ -692,6 +695,32 @@ def 
norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> tor return encoder_hidden_states + @torch.no_grad() + def fuse_projections(self, fuse=True): + is_cross_attention = self.cross_attention_dim != self.query_dim + device = self.to_q.weight.data.device + dtype = self.to_q.weight.data.dtype + + if not is_cross_attention: + # fetch weight matrices. + concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data]) + in_features = concatenated_weights.shape[1] + out_features = concatenated_weights.shape[0] + + # create a new single projection layer and copy over the weights. + self.to_qkv = self.linear_cls(in_features, out_features, bias=False, device=device, dtype=dtype) + self.to_qkv.weight.copy_(concatenated_weights) + + else: + concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data]) + in_features = concatenated_weights.shape[1] + out_features = concatenated_weights.shape[0] + + self.to_kv = self.linear_cls(in_features, out_features, bias=False, device=device, dtype=dtype) + self.to_kv.weight.copy_(concatenated_weights) + + self.fused_projections = fuse + class AttnProcessor: r""" @@ -1184,9 +1213,6 @@ def __call__( scale: float = 1.0, ) -> torch.FloatTensor: residual = hidden_states - - args = () if USE_PEFT_BACKEND else (scale,) - if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) @@ -1253,6 +1279,103 @@ def __call__( return hidden_states +class FusedAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + It uses fused projection layers. For self-attention modules, all projection matrices (i.e., query, + key, value) are fused. For cross-attention modules, key and value projection matrices are fused. + + + + This API is currently 🧪 experimental in nature and can change in future. + + + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "FusedAttnProcessor2_0 requires at least PyTorch 2.0, to use it. Please upgrade PyTorch to > 2.0." 
+ ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + temb: Optional[torch.FloatTensor] = None, + scale: float = 1.0, + ) -> torch.FloatTensor: + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + args = () if USE_PEFT_BACKEND else (scale,) + if encoder_hidden_states is None: + qkv = attn.to_qkv(hidden_states, *args) + split_size = qkv.shape[-1] // 3 + query, key, value = torch.split(qkv, split_size, dim=-1) + else: + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + query = attn.to_q(hidden_states, *args) + + kv = attn.to_kv(encoder_hidden_states, *args) + split_size = kv.shape[-1] // 2 + key, value = torch.split(kv, split_size, dim=-1) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states, *args) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + class CustomDiffusionXFormersAttnProcessor(nn.Module): r""" Processor for implementing memory efficient attention using xFormers for the Custom Diffusion method. 
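The new `FusedAttnProcessor2_0` above is wired in through the pipeline-level `fuse_qkv_projections()` helper added later in this diff, which fuses the projection matrices and swaps the attention processors on the UNet (and optionally the VAE). A minimal sketch of how it might be exercised on SDXL with PyTorch 2.0+; the model id and prompt are just illustrative:

```py
import torch
from diffusers import StableDiffusionXLPipeline

# Any SDXL checkpoint with the standard UNet/VAE layout should work here.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Fuse to_q/to_k/to_v into a single projection for self-attention (to_k/to_v for
# cross-attention) and route attention through FusedAttnProcessor2_0.
pipe.fuse_qkv_projections()
image = pipe("an astronaut riding a horse on the moon").images[0]

# Restore the original, per-matrix projections and attention processors.
pipe.unfuse_qkv_projections()
```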
@@ -2251,6 +2374,7 @@ def __call__( AttentionProcessor = Union[ AttnProcessor, AttnProcessor2_0, + FusedAttnProcessor2_0, XFormersAttnProcessor, SlicedAttnProcessor, AttnAddedKVProcessor, diff --git a/src/diffusers/models/autoencoder_kl.py b/src/diffusers/models/autoencoder_kl.py index 464bff9189dd..8fa3574125f9 100644 --- a/src/diffusers/models/autoencoder_kl.py +++ b/src/diffusers/models/autoencoder_kl.py @@ -22,6 +22,7 @@ from .attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, CROSS_ATTENTION_PROCESSORS, + Attention, AttentionProcessor, AttnAddedKVProcessor, AttnProcessor, @@ -448,3 +449,41 @@ def forward( return (dec,) return DecoderOutput(sample=dec) + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, + key, value) are fused. For cross-attention modules, key and value projection matrices are fused. + + + + This API is 🧪 experimental. + + + """ + self.original_attn_processors = None + + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") + + self.original_attn_processors = self.attn_processors + + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + + + This API is 🧪 experimental. + + + + """ + if self.original_attn_processors is not None: + self.set_attn_processor(self.original_attn_processors) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index dd91d8007229..ddf533d3bd3b 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -25,6 +25,7 @@ from .attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, CROSS_ATTENTION_PROCESSORS, + Attention, AttentionProcessor, AttnAddedKVProcessor, AttnProcessor, @@ -794,6 +795,42 @@ def disable_freeu(self): if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None: setattr(upsample_block, k, None) + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, + key, value) are fused. For cross-attention modules, key and value projection matrices are fused. + + + + This API is 🧪 experimental. + + + """ + self.original_attn_processors = None + + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") + + self.original_attn_processors = self.attn_processors + + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) + + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + + + This API is 🧪 experimental. 
+ + + + """ + if self.original_attn_processors is not None: + self.set_attn_processor(self.original_attn_processors) + def forward( self, sample: torch.FloatTensor, diff --git a/tests/convert_kandinsky3_unet.py b/src/diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py similarity index 100% rename from tests/convert_kandinsky3_unet.py rename to src/diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py index 30eb8d2ceafc..761391189f8f 100644 --- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py @@ -447,7 +447,8 @@ def convert_ldm_unet_checkpoint( # Relevant to StableDiffusionUpscalePipeline if "num_class_embeds" in config: - new_checkpoint["class_embedding.weight"] = unet_state_dict["label_emb.weight"] + if (config["num_class_embeds"] is not None) and ("label_emb.weight" in unet_state_dict): + new_checkpoint["class_embedding.weight"] = unet_state_dict["label_emb.weight"] new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 12d52aa076d4..c8c6247960af 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -34,6 +34,7 @@ from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...models.attention_processor import ( AttnProcessor2_0, + FusedAttnProcessor2_0, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor, XFormersAttnProcessor, @@ -681,7 +682,6 @@ def _get_add_time_ids( add_time_ids = torch.tensor([add_time_ids], dtype=dtype) return add_time_ids - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae def upcast_vae(self): dtype = self.vae.dtype self.vae.to(dtype=torch.float32) @@ -692,6 +692,7 @@ def upcast_vae(self): XFormersAttnProcessor, LoRAXFormersAttnProcessor, LoRAAttnProcessor2_0, + FusedAttnProcessor2_0, ), ) # if xformers or torch_2_0 is used attention block does not need @@ -729,6 +730,65 @@ def disable_freeu(self): """Disables the FreeU mechanism if enabled.""" self.unet.disable_freeu() + def fuse_qkv_projections(self, unet: bool = True, vae: bool = True): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, + key, value) are fused. For cross-attention modules, key and value projection matrices are fused. + + + + This API is 🧪 experimental. + + + + Args: + unet (`bool`, defaults to `True`): To apply fusion on the UNet. + vae (`bool`, defaults to `True`): To apply fusion on the VAE. 
+ """ + self.fusing_unet = False + self.fusing_vae = False + + if unet: + self.fusing_unet = True + self.unet.fuse_qkv_projections() + self.unet.set_attn_processor(FusedAttnProcessor2_0()) + + if vae: + if not isinstance(self.vae, AutoencoderKL): + raise ValueError("`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.") + + self.fusing_vae = True + self.vae.fuse_qkv_projections() + self.vae.set_attn_processor(FusedAttnProcessor2_0()) + + def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True): + """Disable QKV projection fusion if enabled. + + + + This API is 🧪 experimental. + + + + Args: + unet (`bool`, defaults to `True`): To apply fusion on the UNet. + vae (`bool`, defaults to `True`): To apply fusion on the VAE. + + """ + if unet: + if not self.fusing_unet: + logger.warning("The UNet was not initially fused for QKV projections. Doing nothing.") + else: + self.unet.unfuse_qkv_projections() + self.fusing_unet = False + + if vae: + if not self.fusing_vae: + logger.warning("The VAE was not initially fused for QKV projections. Doing nothing.") + else: + self.vae.unfuse_qkv_projections() + self.fusing_vae = False + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32): """ diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py index b14c746f9962..644948ddc0d3 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py @@ -24,6 +24,7 @@ from ...models import AutoencoderKL, UNet2DConditionModel from ...models.attention_processor import ( AttnProcessor2_0, + FusedAttnProcessor2_0, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor, XFormersAttnProcessor, @@ -610,6 +611,7 @@ def upcast_vae(self): XFormersAttnProcessor, LoRAXFormersAttnProcessor, LoRAAttnProcessor2_0, + FusedAttnProcessor2_0, ), ) # if xformers or torch_2_0 is used attention block does not need diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index 0a2f1ca17cb0..8ac63636df86 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -10,10 +10,10 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...models import ModelMixin from ...models.activations import get_activation -from ...models.attention import Attention from ...models.attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, CROSS_ATTENTION_PROCESSORS, + Attention, AttentionProcessor, AttnAddedKVProcessor, AttnAddedKVProcessor2_0, @@ -1000,6 +1000,42 @@ def disable_freeu(self): if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None: setattr(upsample_block, k, None) + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, + key, value) are fused. For cross-attention modules, key and value projection matrices are fused. + + + + This API is 🧪 experimental. 
+ + + """ + self.original_attn_processors = None + + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") + + self.original_attn_processors = self.attn_processors + + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) + + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + + + This API is 🧪 experimental. + + + + """ + if self.original_attn_processors is not None: + self.set_attn_processor(self.original_attn_processors) + def forward( self, sample: torch.FloatTensor, diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index 53dc2ae15432..cef2c4113a48 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -191,10 +191,11 @@ def __init__( @property def init_noise_sigma(self): # standard deviation of the initial noise distribution + max_sigma = max(self.sigmas) if isinstance(self.sigmas, list) else self.sigmas.max() if self.config.timestep_spacing in ["linspace", "trailing"]: - return self.sigmas.max() + return max_sigma - return (self.sigmas.max() ** 2 + 1) ** 0.5 + return (max_sigma**2 + 1) ** 0.5 @property def step_index(self): @@ -289,6 +290,8 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic self.timesteps = torch.from_numpy(timesteps.astype(np.float32)).to(device=device) self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) + if sigmas.device.type == "cuda": + self.sigmas = self.sigmas.tolist() self._step_index = None def _sigma_to_t(self, sigma, log_sigmas): diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index 2998f7dc429e..14b89c3cd3b9 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -17,7 +17,7 @@ from distutils.util import strtobool from io import BytesIO, StringIO from pathlib import Path -from typing import List, Optional, Union +from typing import Callable, Dict, List, Optional, Union import numpy as np import PIL.Image @@ -58,6 +58,17 @@ if is_torch_available(): import torch + # Set a backend environment variable for any extra module import required for a custom accelerator + if "DIFFUSERS_TEST_BACKEND" in os.environ: + backend = os.environ["DIFFUSERS_TEST_BACKEND"] + try: + _ = importlib.import_module(backend) + except ModuleNotFoundError as e: + raise ModuleNotFoundError( + f"Failed to import `DIFFUSERS_TEST_BACKEND` '{backend}'! 
This should be the name of an installed module \ + to enable a specified backend.):\n{e}" + ) from e + if "DIFFUSERS_TEST_DEVICE" in os.environ: torch_device = os.environ["DIFFUSERS_TEST_DEVICE"] try: @@ -210,6 +221,36 @@ def require_torch_gpu(test_case): ) +# These decorators are for accelerator-specific behaviours that are not GPU-specific +def require_torch_accelerator(test_case): + """Decorator marking a test that requires an accelerator backend and PyTorch.""" + return unittest.skipUnless(is_torch_available() and torch_device != "cpu", "test requires accelerator+PyTorch")( + test_case + ) + + +def require_torch_accelerator_with_fp16(test_case): + """Decorator marking a test that requires an accelerator with support for the FP16 data type.""" + return unittest.skipUnless(_is_torch_fp16_available(torch_device), "test requires accelerator with fp16 support")( + test_case + ) + + +def require_torch_accelerator_with_fp64(test_case): + """Decorator marking a test that requires an accelerator with support for the FP64 data type.""" + return unittest.skipUnless(_is_torch_fp64_available(torch_device), "test requires accelerator with fp64 support")( + test_case + ) + + +def require_torch_accelerator_with_training(test_case): + """Decorator marking a test that requires an accelerator with support for training.""" + return unittest.skipUnless( + is_torch_available() and backend_supports_training(torch_device), + "test requires accelerator with training support", + )(test_case) + + def skip_mps(test_case): """Decorator marking a test to skip if torch_device is 'mps'""" return unittest.skipUnless(torch_device != "mps", "test requires non 'mps' device")(test_case) @@ -766,3 +807,139 @@ def disable_full_determinism(): os.environ["CUDA_LAUNCH_BLOCKING"] = "0" os.environ["CUBLAS_WORKSPACE_CONFIG"] = "" torch.use_deterministic_algorithms(False) + + +# Utils for custom and alternative accelerator devices +def _is_torch_fp16_available(device): + if not is_torch_available(): + return False + + import torch + + device = torch.device(device) + + try: + x = torch.zeros((2, 2), dtype=torch.float16).to(device) + _ = x @ x + except Exception as e: + if device.type == "cuda": + raise ValueError( + f"You have passed a device of type 'cuda' which should work with 'fp16', but 'cuda' does not seem to be correctly installed on your machine: {e}" + ) + + return False + + +def _is_torch_fp64_available(device): + if not is_torch_available(): + return False + + import torch + + try: + x = torch.zeros((2, 2), dtype=torch.float64).to(device) + _ = x @ x + except Exception as e: + if device.type == "cuda": + raise ValueError( + f"You have passed a device of type 'cuda' which should work with 'fp64', but 'cuda' does not seem to be correctly installed on your machine: {e}" + ) + + return False + + +# Guard these lookups for when Torch is not used - alternative accelerator support is for PyTorch +if is_torch_available(): + # Behaviour flags + BACKEND_SUPPORTS_TRAINING = {"cuda": True, "cpu": True, "mps": False, "default": True} + + # Function definitions + BACKEND_EMPTY_CACHE = {"cuda": torch.cuda.empty_cache, "cpu": None, "mps": None, "default": None} + BACKEND_DEVICE_COUNT = {"cuda": torch.cuda.device_count, "cpu": lambda: 0, "mps": lambda: 0, "default": 0} + BACKEND_MANUAL_SEED = {"cuda": torch.cuda.manual_seed, "cpu": torch.manual_seed, "default": torch.manual_seed} + + +# This dispatches a defined function according to the accelerator from the function definitions. 
+def _device_agnostic_dispatch(device: str, dispatch_table: Dict[str, Callable], *args, **kwargs): + if device not in dispatch_table: + return dispatch_table["default"](*args, **kwargs) + + fn = dispatch_table[device] + + # Some device agnostic functions return values. Need to guard against 'None' instead at + # user level + if fn is None: + return None + + return fn(*args, **kwargs) + + +# These are callables which automatically dispatch the function specific to the accelerator +def backend_manual_seed(device: str, seed: int): + return _device_agnostic_dispatch(device, BACKEND_MANUAL_SEED, seed) + + +def backend_empty_cache(device: str): + return _device_agnostic_dispatch(device, BACKEND_EMPTY_CACHE) + + +def backend_device_count(device: str): + return _device_agnostic_dispatch(device, BACKEND_DEVICE_COUNT) + + +# These are callables which return boolean behaviour flags and can be used to specify some +# device agnostic alternative where the feature is unsupported. +def backend_supports_training(device: str): + if not is_torch_available(): + return False + + if device not in BACKEND_SUPPORTS_TRAINING: + device = "default" + + return BACKEND_SUPPORTS_TRAINING[device] + + +# Guard for when Torch is not available +if is_torch_available(): + # Update device function dict mapping + def update_mapping_from_spec(device_fn_dict: Dict[str, Callable], attribute_name: str): + try: + # Try to import the function directly + spec_fn = getattr(device_spec_module, attribute_name) + device_fn_dict[torch_device] = spec_fn + except AttributeError as e: + # If the function doesn't exist, and there is no default, throw an error + if "default" not in device_fn_dict: + raise AttributeError( + f"`{attribute_name}` not found in '{device_spec_path}' and no default fallback function found." + ) from e + + if "DIFFUSERS_TEST_DEVICE_SPEC" in os.environ: + device_spec_path = os.environ["DIFFUSERS_TEST_DEVICE_SPEC"] + if not Path(device_spec_path).is_file(): + raise ValueError(f"Specified path to device specification file is not found. Received {device_spec_path}") + + try: + import_name = device_spec_path[: device_spec_path.index(".py")] + except ValueError as e: + raise ValueError(f"Provided device spec file is not a Python file! Received {device_spec_path}") from e + + device_spec_module = importlib.import_module(import_name) + + try: + device_name = device_spec_module.DEVICE_NAME + except AttributeError: + raise AttributeError("Device spec file did not contain `DEVICE_NAME`") + + if "DIFFUSERS_TEST_DEVICE" in os.environ and torch_device != device_name: + msg = f"Mismatch between environment variable `DIFFUSERS_TEST_DEVICE` '{torch_device}' and device found in spec '{device_name}'\n" + msg += "Either unset `DIFFUSERS_TEST_DEVICE` or ensure it matches device spec name." + raise ValueError(msg) + + torch_device = device_name + + # Add one entry here for each `BACKEND_*` dictionary. 
+ update_mapping_from_spec(BACKEND_MANUAL_SEED, "MANUAL_SEED_FN") + update_mapping_from_spec(BACKEND_EMPTY_CACHE, "EMPTY_CACHE_FN") + update_mapping_from_spec(BACKEND_DEVICE_COUNT, "DEVICE_COUNT_FN") + update_mapping_from_spec(BACKEND_SUPPORTS_TRAINING, "SUPPORTS_TRAINING") diff --git a/tests/models/test_layers_utils.py b/tests/models/test_layers_utils.py index 9d45d810f653..c6e3e19d4cc3 100644 --- a/tests/models/test_layers_utils.py +++ b/tests/models/test_layers_utils.py @@ -25,7 +25,11 @@ from diffusers.models.lora import LoRACompatibleLinear from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D from diffusers.models.transformer_2d import Transformer2DModel -from diffusers.utils.testing_utils import torch_device +from diffusers.utils.testing_utils import ( + backend_manual_seed, + require_torch_accelerator_with_fp64, + torch_device, +) class EmbeddingsTests(unittest.TestCase): @@ -315,8 +319,7 @@ def test_restnet_with_kernel_sde_vp(self): class Transformer2DModelTests(unittest.TestCase): def test_spatial_transformer_default(self): torch.manual_seed(0) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(0) + backend_manual_seed(torch_device, 0) sample = torch.randn(1, 32, 64, 64).to(torch_device) spatial_transformer_block = Transformer2DModel( @@ -339,8 +342,7 @@ def test_spatial_transformer_default(self): def test_spatial_transformer_cross_attention_dim(self): torch.manual_seed(0) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(0) + backend_manual_seed(torch_device, 0) sample = torch.randn(1, 64, 64, 64).to(torch_device) spatial_transformer_block = Transformer2DModel( @@ -363,8 +365,7 @@ def test_spatial_transformer_cross_attention_dim(self): def test_spatial_transformer_timestep(self): torch.manual_seed(0) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(0) + backend_manual_seed(torch_device, 0) num_embeds_ada_norm = 5 @@ -401,8 +402,7 @@ def test_spatial_transformer_timestep(self): def test_spatial_transformer_dropout(self): torch.manual_seed(0) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(0) + backend_manual_seed(torch_device, 0) sample = torch.randn(1, 32, 64, 64).to(torch_device) spatial_transformer_block = ( @@ -427,11 +427,10 @@ def test_spatial_transformer_dropout(self): ) assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3) - @unittest.skipIf(torch_device == "mps", "MPS does not support float64") + @require_torch_accelerator_with_fp64 def test_spatial_transformer_discrete(self): torch.manual_seed(0) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(0) + backend_manual_seed(torch_device, 0) num_embed = 5 diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 961147839461..5ea0d910f3a3 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -35,6 +35,7 @@ CaptureLogger, require_python39_or_higher, require_torch_2, + require_torch_accelerator_with_training, require_torch_gpu, run_test_in_subprocess, torch_device, @@ -536,7 +537,7 @@ def test_model_from_pretrained(self): self.assertEqual(output_1.shape, output_2.shape) - @unittest.skipIf(torch_device == "mps", "Training is not supported in mps") + @require_torch_accelerator_with_training def test_training(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() @@ -553,7 +554,7 @@ def test_training(self): loss = torch.nn.functional.mse_loss(output, noise) loss.backward() - @unittest.skipIf(torch_device == "mps", "Training is not 
supported in mps") + @require_torch_accelerator_with_training def test_ema_training(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() @@ -624,7 +625,7 @@ def recursive_check(tuple_object, dict_object): recursive_check(outputs_tuple, outputs_dict) - @unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS") + @require_torch_accelerator_with_training def test_enable_disable_gradient_checkpointing(self): if not self.model_class._supports_gradient_checkpointing: return # Skip test if model does not support gradient checkpointing diff --git a/tests/models/test_models_prior.py b/tests/models/test_models_prior.py index 9b02de463ecd..ca27b6ff057f 100644 --- a/tests/models/test_models_prior.py +++ b/tests/models/test_models_prior.py @@ -21,7 +21,14 @@ from parameterized import parameterized from diffusers import PriorTransformer -from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, slow, torch_all_close, torch_device +from diffusers.utils.testing_utils import ( + backend_empty_cache, + enable_full_determinism, + floats_tensor, + slow, + torch_all_close, + torch_device, +) from .test_modeling_common import ModelTesterMixin @@ -157,7 +164,7 @@ def tearDown(self): # clean up the VRAM after each test super().tearDown() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache() @parameterized.expand( [ diff --git a/tests/models/test_models_unet_1d.py b/tests/models/test_models_unet_1d.py index 5803e5bfda2a..aad496416508 100644 --- a/tests/models/test_models_unet_1d.py +++ b/tests/models/test_models_unet_1d.py @@ -18,7 +18,12 @@ import torch from diffusers import UNet1DModel -from diffusers.utils.testing_utils import floats_tensor, slow, torch_device +from diffusers.utils.testing_utils import ( + backend_manual_seed, + floats_tensor, + slow, + torch_device, +) from .test_modeling_common import ModelTesterMixin, UNetTesterMixin @@ -103,8 +108,7 @@ def test_from_pretrained_hub(self): def test_output_pretrained(self): model = UNet1DModel.from_pretrained("bglick13/hopper-medium-v2-value-function-hor32", subfolder="unet") torch.manual_seed(0) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(0) + backend_manual_seed(torch_device, 0) num_features = model.config.in_channels seq_len = 16 @@ -244,8 +248,7 @@ def test_output_pretrained(self): "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="value_function" ) torch.manual_seed(0) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(0) + backend_manual_seed(torch_device, 0) num_features = value_function.config.in_channels seq_len = 14 diff --git a/tests/models/test_models_unet_2d.py b/tests/models/test_models_unet_2d.py index 4fd991b3fc46..2be343e9d627 100644 --- a/tests/models/test_models_unet_2d.py +++ b/tests/models/test_models_unet_2d.py @@ -24,6 +24,7 @@ from diffusers.utils.testing_utils import ( enable_full_determinism, floats_tensor, + require_torch_accelerator, slow, torch_all_close, torch_device, @@ -153,7 +154,7 @@ def test_from_pretrained_hub(self): assert image is not None, "Make sure output is not None" - @unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU") + @require_torch_accelerator def test_from_pretrained_accelerate(self): model, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True) model.to(torch_device) @@ -161,7 +162,7 @@ def test_from_pretrained_accelerate(self): assert image is not None, "Make sure output is not None" - 
@unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU") + @require_torch_accelerator def test_from_pretrained_accelerate_wont_change_results(self): # by defautl model loading will use accelerate as `low_cpu_mem_usage=True` model_accelerate, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True) diff --git a/tests/models/test_models_unet_2d_condition.py b/tests/models/test_models_unet_2d_condition.py index 9ccd78f1fe47..80f59734b5ce 100644 --- a/tests/models/test_models_unet_2d_condition.py +++ b/tests/models/test_models_unet_2d_condition.py @@ -30,10 +30,15 @@ from diffusers.utils import logging from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.testing_utils import ( + backend_empty_cache, enable_full_determinism, floats_tensor, load_hf_numpy, + require_torch_accelerator, + require_torch_accelerator_with_fp16, + require_torch_accelerator_with_training, require_torch_gpu, + skip_mps, slow, torch_all_close, torch_device, @@ -280,7 +285,7 @@ def test_xformers_enable_works(self): == "XFormersAttnProcessor" ), "xformers is not enabled" - @unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS") + @require_torch_accelerator_with_training def test_gradient_checkpointing(self): # enable deterministic behavior for gradient checkpointing init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() @@ -864,7 +869,7 @@ def tearDown(self): # clean up the VRAM after each test super().tearDown() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache() def get_latents(self, seed=0, shape=(4, 4, 64, 64), fp16=False): dtype = torch.float16 if fp16 else torch.float32 @@ -882,6 +887,7 @@ def get_unet_model(self, fp16=False, model_id="CompVis/stable-diffusion-v1-4"): return model + @require_torch_gpu def test_set_attention_slice_auto(self): torch.cuda.empty_cache() torch.cuda.reset_max_memory_allocated() @@ -901,6 +907,7 @@ def test_set_attention_slice_auto(self): assert mem_bytes < 5 * 10**9 + @require_torch_gpu def test_set_attention_slice_max(self): torch.cuda.empty_cache() torch.cuda.reset_max_memory_allocated() @@ -920,6 +927,7 @@ def test_set_attention_slice_max(self): assert mem_bytes < 5 * 10**9 + @require_torch_gpu def test_set_attention_slice_int(self): torch.cuda.empty_cache() torch.cuda.reset_max_memory_allocated() @@ -939,6 +947,7 @@ def test_set_attention_slice_int(self): assert mem_bytes < 5 * 10**9 + @require_torch_gpu def test_set_attention_slice_list(self): torch.cuda.empty_cache() torch.cuda.reset_max_memory_allocated() @@ -975,7 +984,7 @@ def get_encoder_hidden_states(self, seed=0, shape=(4, 77, 768), fp16=False): # fmt: on ] ) - @require_torch_gpu + @require_torch_accelerator_with_fp16 def test_compvis_sd_v1_4(self, seed, timestep, expected_slice): model = self.get_unet_model(model_id="CompVis/stable-diffusion-v1-4") latents = self.get_latents(seed) @@ -1003,7 +1012,7 @@ def test_compvis_sd_v1_4(self, seed, timestep, expected_slice): # fmt: on ] ) - @require_torch_gpu + @require_torch_accelerator_with_fp16 def test_compvis_sd_v1_4_fp16(self, seed, timestep, expected_slice): model = self.get_unet_model(model_id="CompVis/stable-diffusion-v1-4", fp16=True) latents = self.get_latents(seed, fp16=True) @@ -1031,7 +1040,8 @@ def test_compvis_sd_v1_4_fp16(self, seed, timestep, expected_slice): # fmt: on ] ) - @require_torch_gpu + @require_torch_accelerator + @skip_mps def test_compvis_sd_v1_5(self, seed, timestep, expected_slice): model = 
self.get_unet_model(model_id="runwayml/stable-diffusion-v1-5") latents = self.get_latents(seed) @@ -1059,7 +1069,7 @@ def test_compvis_sd_v1_5(self, seed, timestep, expected_slice): # fmt: on ] ) - @require_torch_gpu + @require_torch_accelerator_with_fp16 def test_compvis_sd_v1_5_fp16(self, seed, timestep, expected_slice): model = self.get_unet_model(model_id="runwayml/stable-diffusion-v1-5", fp16=True) latents = self.get_latents(seed, fp16=True) @@ -1087,7 +1097,8 @@ def test_compvis_sd_v1_5_fp16(self, seed, timestep, expected_slice): # fmt: on ] ) - @require_torch_gpu + @require_torch_accelerator + @skip_mps def test_compvis_sd_inpaint(self, seed, timestep, expected_slice): model = self.get_unet_model(model_id="runwayml/stable-diffusion-inpainting") latents = self.get_latents(seed, shape=(4, 9, 64, 64)) @@ -1115,7 +1126,7 @@ def test_compvis_sd_inpaint(self, seed, timestep, expected_slice): # fmt: on ] ) - @require_torch_gpu + @require_torch_accelerator_with_fp16 def test_compvis_sd_inpaint_fp16(self, seed, timestep, expected_slice): model = self.get_unet_model(model_id="runwayml/stable-diffusion-inpainting", fp16=True) latents = self.get_latents(seed, shape=(4, 9, 64, 64), fp16=True) @@ -1143,7 +1154,7 @@ def test_compvis_sd_inpaint_fp16(self, seed, timestep, expected_slice): # fmt: on ] ) - @require_torch_gpu + @require_torch_accelerator_with_fp16 def test_stabilityai_sd_v2_fp16(self, seed, timestep, expected_slice): model = self.get_unet_model(model_id="stabilityai/stable-diffusion-2", fp16=True) latents = self.get_latents(seed, shape=(4, 4, 96, 96), fp16=True) diff --git a/tests/models/test_models_vae.py b/tests/models/test_models_vae.py index aa755e387b61..df34a48da3aa 100644 --- a/tests/models/test_models_vae.py +++ b/tests/models/test_models_vae.py @@ -31,10 +31,15 @@ from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.loading_utils import load_image from diffusers.utils.testing_utils import ( + backend_empty_cache, enable_full_determinism, floats_tensor, load_hf_numpy, + require_torch_accelerator, + require_torch_accelerator_with_fp16, + require_torch_accelerator_with_training, require_torch_gpu, + skip_mps, slow, torch_all_close, torch_device, @@ -157,7 +162,7 @@ def test_forward_signature(self): def test_training(self): pass - @unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS") + @require_torch_accelerator_with_training def test_gradient_checkpointing(self): # enable deterministic behavior for gradient checkpointing init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() @@ -213,10 +218,12 @@ def test_output_pretrained(self): model = model.to(torch_device) model.eval() - if torch_device == "mps": - generator = torch.manual_seed(0) + # Keep generator on CPU for non-CUDA devices to compare outputs with CPU result tensors + generator_device = "cpu" if not torch_device.startswith("cuda") else "cuda" + if torch_device != "mps": + generator = torch.Generator(device=generator_device).manual_seed(0) else: - generator = torch.Generator(device=torch_device).manual_seed(0) + generator = torch.manual_seed(0) image = torch.randn( 1, @@ -247,7 +254,7 @@ def test_output_pretrained(self): -9.8644e-03, ] ) - elif torch_device == "cpu": + elif generator_device == "cpu": expected_output_slice = torch.tensor( [ -0.1352, @@ -478,7 +485,7 @@ def tearDown(self): # clean up the VRAM after each test super().tearDown() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache() def get_file_format(self, seed, shape): return 
f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy" @@ -558,7 +565,7 @@ def tearDown(self): # clean up the VRAM after each test super().tearDown() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache() def get_sd_image(self, seed=0, shape=(4, 3, 512, 512), fp16=False): dtype = torch.float16 if fp16 else torch.float32 @@ -580,9 +587,10 @@ def get_sd_vae_model(self, model_id="CompVis/stable-diffusion-v1-4", fp16=False) return model def get_generator(self, seed=0): - if torch_device == "mps": - return torch.manual_seed(seed) - return torch.Generator(device=torch_device).manual_seed(seed) + generator_device = "cpu" if not torch_device.startswith("cuda") else "cuda" + if torch_device != "mps": + return torch.Generator(device=generator_device).manual_seed(seed) + return torch.manual_seed(seed) @parameterized.expand( [ @@ -623,7 +631,7 @@ def test_stable_diffusion(self, seed, expected_slice, expected_slice_mps): # fmt: on ] ) - @require_torch_gpu + @require_torch_accelerator_with_fp16 def test_stable_diffusion_fp16(self, seed, expected_slice): model = self.get_sd_vae_model(fp16=True) image = self.get_sd_image(seed, fp16=True) @@ -677,7 +685,8 @@ def test_stable_diffusion_mode(self, seed, expected_slice, expected_slice_mps): # fmt: on ] ) - @require_torch_gpu + @require_torch_accelerator + @skip_mps def test_stable_diffusion_decode(self, seed, expected_slice): model = self.get_sd_vae_model() encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64)) @@ -700,7 +709,7 @@ def test_stable_diffusion_decode(self, seed, expected_slice): # fmt: on ] ) - @require_torch_gpu + @require_torch_accelerator_with_fp16 def test_stable_diffusion_decode_fp16(self, seed, expected_slice): model = self.get_sd_vae_model(fp16=True) encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64), fp16=True) @@ -811,7 +820,7 @@ def tearDown(self): # clean up the VRAM after each test super().tearDown() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache() def get_sd_image(self, seed=0, shape=(4, 3, 512, 512), fp16=False): dtype = torch.float16 if fp16 else torch.float32 @@ -832,9 +841,10 @@ def get_sd_vae_model(self, model_id="cross-attention/asymmetric-autoencoder-kl-x return model def get_generator(self, seed=0): - if torch_device == "mps": - return torch.manual_seed(seed) - return torch.Generator(device=torch_device).manual_seed(seed) + generator_device = "cpu" if not torch_device.startswith("cuda") else "cuda" + if torch_device != "mps": + return torch.Generator(device=generator_device).manual_seed(seed) + return torch.manual_seed(seed) @parameterized.expand( [ @@ -905,7 +915,8 @@ def test_stable_diffusion_mode(self, seed, expected_slice, expected_slice_mps): # fmt: on ] ) - @require_torch_gpu + @require_torch_accelerator + @skip_mps def test_stable_diffusion_decode(self, seed, expected_slice): model = self.get_sd_vae_model() encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64)) diff --git a/tests/models/test_models_vq.py b/tests/models/test_models_vq.py index c7b9363b5d5f..a5a9288d6462 100644 --- a/tests/models/test_models_vq.py +++ b/tests/models/test_models_vq.py @@ -18,7 +18,12 @@ import torch from diffusers import VQModel -from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, torch_device +from diffusers.utils.testing_utils import ( + backend_manual_seed, + enable_full_determinism, + floats_tensor, + torch_device, +) from .test_modeling_common import ModelTesterMixin, UNetTesterMixin @@ -80,8 +85,7 @@ def test_output_pretrained(self): 
model.to(torch_device).eval() torch.manual_seed(0) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(0) + backend_manual_seed(torch_device, 0) image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size) image = image.to(torch_device) diff --git a/tests/models/test_unet_blocks_common.py b/tests/models/test_unet_blocks_common.py index 4c399fdb74fa..9d1ddc2457e3 100644 --- a/tests/models/test_unet_blocks_common.py +++ b/tests/models/test_unet_blocks_common.py @@ -12,12 +12,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import unittest from typing import Tuple import torch -from diffusers.utils.testing_utils import floats_tensor, require_torch, torch_all_close, torch_device +from diffusers.utils.testing_utils import ( + floats_tensor, + require_torch, + require_torch_accelerator_with_training, + torch_all_close, + torch_device, +) from diffusers.utils.torch_utils import randn_tensor @@ -104,7 +109,7 @@ def test_output(self, expected_slice): expected_slice = torch.tensor(expected_slice).to(torch_device) assert torch_all_close(output_slice.flatten(), expected_slice, atol=5e-3) - @unittest.skipIf(torch_device == "mps", "Training is not supported in mps") + @require_torch_accelerator_with_training def test_training(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() model = self.block_class(**init_dict) diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py index ed295f792f99..7459d5a6b617 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py @@ -34,11 +34,14 @@ ) from diffusers.utils.testing_utils import ( CaptureLogger, + backend_empty_cache, enable_full_determinism, load_numpy, nightly, numpy_cosine_similarity_distance, + require_torch_accelerator, require_torch_gpu, + skip_mps, slow, torch_device, ) @@ -128,10 +131,12 @@ def get_dummy_components(self): return components def get_dummy_inputs(self, device, seed=0): - if str(device).startswith("mps"): - generator = torch.manual_seed(seed) + generator_device = "cpu" if not device.startswith("cuda") else "cuda" + if not str(device).startswith("mps"): + generator = torch.Generator(device=generator_device).manual_seed(seed) else: - generator = torch.Generator(device=device).manual_seed(seed) + generator = torch.manual_seed(seed) + inputs = { "prompt": "A painting of a squirrel eating a burger", "generator": generator, @@ -299,15 +304,21 @@ def test_inference_batch_single_identical(self): @slow -@require_torch_gpu +@require_torch_accelerator +@skip_mps class StableDiffusion2PipelineSlowTests(unittest.TestCase): def tearDown(self): super().tearDown() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache() def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0): - generator = torch.Generator(device=generator_device).manual_seed(seed) + _generator_device = "cpu" if not generator_device.startswith("cuda") else "cuda" + if not str(device).startswith("mps"): + generator = torch.Generator(device=_generator_device).manual_seed(seed) + else: + generator = torch.manual_seed(seed) + latents = np.random.RandomState(seed).standard_normal((1, 4, 64, 64)) latents = torch.from_numpy(latents).to(device=device, dtype=dtype) inputs = { @@ -361,6 +372,7 @@ def 
test_stable_diffusion_k_lms(self): expected_slice = np.array([0.10440, 0.13115, 0.11100, 0.10141, 0.11440, 0.07215, 0.11332, 0.09693, 0.10006]) assert np.abs(image_slice - expected_slice).max() < 3e-3 + @require_torch_gpu def test_stable_diffusion_attention_slicing(self): torch.cuda.reset_peak_memory_stats() pipe = StableDiffusionPipeline.from_pretrained( @@ -432,6 +444,7 @@ def callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None: assert callback_fn.has_been_called assert number_of_steps == inputs["num_inference_steps"] + @require_torch_gpu def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self): torch.cuda.empty_cache() torch.cuda.reset_max_memory_allocated() @@ -452,6 +465,7 @@ def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self): # make sure that less than 2.8 GB is allocated assert mem_bytes < 2.8 * 10**9 + @require_torch_gpu def test_stable_diffusion_pipeline_with_model_offloading(self): torch.cuda.empty_cache() torch.cuda.reset_max_memory_allocated() @@ -511,15 +525,21 @@ def test_stable_diffusion_pipeline_with_model_offloading(self): @nightly -@require_torch_gpu +@require_torch_accelerator +@skip_mps class StableDiffusion2PipelineNightlyTests(unittest.TestCase): def tearDown(self): super().tearDown() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache() def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0): - generator = torch.Generator(device=generator_device).manual_seed(seed) + _generator_device = "cpu" if not generator_device.startswith("cuda") else "cuda" + if not str(device).startswith("mps"): + generator = torch.Generator(device=_generator_device).manual_seed(seed) + else: + generator = torch.manual_seed(seed) + latents = np.random.RandomState(seed).standard_normal((1, 4, 64, 64)) latents = torch.from_numpy(latents).to(device=device, dtype=dtype) inputs = { diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py index 59f0c0151d3a..280030d94b7c 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py @@ -938,6 +938,37 @@ def test_stable_diffusion_xl_save_from_pretrained(self): assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3 + def test_stable_diffusion_xl_with_fused_qkv_projections(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + sd_pipe = StableDiffusionXLPipeline(**components) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = sd_pipe(**inputs).images + original_image_slice = image[0, -3:, -3:, -1] + + sd_pipe.fuse_qkv_projections() + inputs = self.get_dummy_inputs(device) + image = sd_pipe(**inputs).images + image_slice_fused = image[0, -3:, -3:, -1] + + sd_pipe.unfuse_qkv_projections() + inputs = self.get_dummy_inputs(device) + image = sd_pipe(**inputs).images + image_slice_disabled = image[0, -3:, -3:, -1] + + assert np.allclose( + original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2 + ), "Fusion of QKV projections shouldn't affect the outputs." + assert np.allclose( + image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2 + ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled." 
+ assert np.allclose( + original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2 + ), "Original outputs should match when fused QKV projections are disabled." + @slow class StableDiffusionXLPipelineIntegrationTests(unittest.TestCase):
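Tying together the device-agnostic testing utilities introduced earlier in this diff (`backend_manual_seed`, `backend_empty_cache`, and the `require_torch_accelerator*` decorators), here is a hypothetical sketch of how a test might use them; the test class and test body are illustrative, not part of the PR:

```py
import unittest

import torch

from diffusers.utils.testing_utils import (
    backend_empty_cache,
    backend_manual_seed,
    require_torch_accelerator_with_fp16,
    torch_device,
)


class DeviceAgnosticSmokeTests(unittest.TestCase):  # illustrative, not part of the PR
    @require_torch_accelerator_with_fp16
    def test_fp16_matmul(self):
        # Seeds via torch.cuda.manual_seed on CUDA, torch.manual_seed on CPU,
        # or whatever a custom DIFFUSERS_TEST_DEVICE_SPEC registers.
        backend_manual_seed(torch_device, 0)

        x = torch.randn(2, 2, dtype=torch.float16, device=torch_device)
        _ = x @ x

        # torch.cuda.empty_cache on CUDA; a no-op (returns None) on CPU/MPS.
        backend_empty_cache(torch_device)
```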