diff --git a/.github/ISSUE_TEMPLATE/bug-report.yml b/.github/ISSUE_TEMPLATE/bug-report.yml index 0073cb793698..b865f6c33d51 100644 --- a/.github/ISSUE_TEMPLATE/bug-report.yml +++ b/.github/ISSUE_TEMPLATE/bug-report.yml @@ -66,32 +66,32 @@ body: Questions on DiffusionPipeline (Saving, Loading, From pretrained, ...): Questions on pipelines: - - Stable Diffusion @yiyixuxu @DN6 @sayakpaul @patrickvonplaten - - Stable Diffusion XL @yiyixuxu @sayakpaul @DN6 @patrickvonplaten - - Kandinsky @yiyixuxu @patrickvonplaten - - ControlNet @sayakpaul @yiyixuxu @DN6 @patrickvonplaten - - T2I Adapter @sayakpaul @yiyixuxu @DN6 @patrickvonplaten - - IF @DN6 @patrickvonplaten - - Text-to-Video / Video-to-Video @DN6 @sayakpaul @patrickvonplaten - - Wuerstchen @DN6 @patrickvonplaten + - Stable Diffusion @yiyixuxu @DN6 @sayakpaul + - Stable Diffusion XL @yiyixuxu @sayakpaul @DN6 + - Kandinsky @yiyixuxu + - ControlNet @sayakpaul @yiyixuxu @DN6 + - T2I Adapter @sayakpaul @yiyixuxu @DN6 + - IF @DN6 + - Text-to-Video / Video-to-Video @DN6 @sayakpaul + - Wuerstchen @DN6 - Other: @yiyixuxu @DN6 Questions on models: - - UNet @DN6 @yiyixuxu @sayakpaul @patrickvonplaten - - VAE @sayakpaul @DN6 @yiyixuxu @patrickvonplaten - - Transformers/Attention @DN6 @yiyixuxu @sayakpaul @DN6 @patrickvonplaten + - UNet @DN6 @yiyixuxu @sayakpaul + - VAE @sayakpaul @DN6 @yiyixuxu + - Transformers/Attention @DN6 @yiyixuxu @sayakpaul @DN6 - Questions on Schedulers: @yiyixuxu @patrickvonplaten + Questions on Schedulers: @yiyixuxu - Questions on LoRA: @sayakpaul @patrickvonplaten + Questions on LoRA: @sayakpaul - Questions on Textual Inversion: @sayakpaul @patrickvonplaten + Questions on Textual Inversion: @sayakpaul Questions on Training: - - DreamBooth @sayakpaul @patrickvonplaten - - Text-to-Image Fine-tuning @sayakpaul @patrickvonplaten - - Textual Inversion @sayakpaul @patrickvonplaten - - ControlNet @sayakpaul @patrickvonplaten + - DreamBooth @sayakpaul + - Text-to-Image Fine-tuning @sayakpaul + - Textual Inversion @sayakpaul + - ControlNet @sayakpaul Questions on Tests: @DN6 @sayakpaul @yiyixuxu @@ -99,7 +99,7 @@ body: Questions on JAX- and MPS-related things: @pcuenca - Questions on audio pipelines: @DN6 @patrickvonplaten + Questions on audio pipelines: @DN6 diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 9ddb51753a59..a0337eaaaac5 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -38,13 +38,13 @@ members/contributors who may be interested in your PR. Core library: -- Schedulers: @yiyixuxu and @patrickvonplaten -- Pipelines: @patrickvonplaten and @sayakpaul -- Training examples: @sayakpaul and @patrickvonplaten -- Docs: @stevhliu and @yiyixuxu +- Schedulers: @yiyixuxu +- Pipelines: @sayakpaul @yiyixuxu @DN6 +- Training examples: @sayakpaul +- Docs: @stevhliu and @sayakpaul - JAX and MPS: @pcuenca - Audio: @sanchit-gandhi -- General functionalities: @patrickvonplaten and @sayakpaul +- General functionalities: @sayakpaul @yiyixuxu @DN6 Integrations: diff --git a/examples/community/README.md b/examples/community/README.md index d3223654c23a..f69a81c59baf 100755 --- a/examples/community/README.md +++ b/examples/community/README.md @@ -3561,14 +3561,17 @@ pipe.disable_style_aligned() This pipeline adds experimental support for the image-to-video task using AnimateDiff. Refer to [this](https://github.com/huggingface/diffusers/pull/6328) PR for more examples and results. 
+This pipeline relies on a "hack" discovered by the community that allows the generation of videos given an input image with AnimateDiff. It works by creating a copy of the image `num_frames` times and progressively adding more noise to the image based on the strength and latent interpolation method. + ```py import torch from diffusers import MotionAdapter, DiffusionPipeline, DDIMScheduler from diffusers.utils import export_to_gif, load_image +model_id = "SG161222/Realistic_Vision_V5.1_noVAE" adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2") -pipe = DiffusionPipeline.from_pretrained("SG161222/Realistic_Vision_V5.1_noVAE", motion_adapter=adapter, custom_pipeline="pipeline_animatediff_img2video").to("cuda") -pipe.scheduler = DDIMScheduler(beta_schedule="linear", steps_offset=1, clip_sample=False, timespace_spacing="linspace") +pipe = DiffusionPipeline.from_pretrained(model_id, motion_adapter=adapter, custom_pipeline="pipeline_animatediff_img2video").to("cuda") +pipe.scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", beta_schedule="linear", steps_offset=1) image = load_image("snail.png") output = pipe( diff --git a/examples/community/pipeline_animatediff_controlnet.py b/examples/community/pipeline_animatediff_controlnet.py index 1285e7c97a9b..5873ceaa8d70 100644 --- a/examples/community/pipeline_animatediff_controlnet.py +++ b/examples/community/pipeline_animatediff_controlnet.py @@ -24,7 +24,7 @@ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor from diffusers.loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin -from diffusers.models import AutoencoderKL, ControlNetModel, UNet2DConditionModel, UNetMotionModel +from diffusers.models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel, UNetMotionModel from diffusers.models.lora import adjust_lora_scale_text_encoder from diffusers.models.unets.unet_motion_model import MotionAdapter from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel @@ -382,6 +382,41 @@ def encode_image(self, image, device, num_images_per_prompt): uncond_image_embeds = torch.zeros_like(image_embeds) return image_embeds, uncond_image_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt + ): + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." 
+ ) + + image_embeds = [] + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0) + single_negative_image_embeds = torch.stack( + [single_negative_image_embeds] * num_images_per_prompt, dim=0 + ) + + if self.do_classifier_free_guidance: + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) + single_image_embeds = single_image_embeds.to(device) + + image_embeds.append(single_image_embeds) + else: + image_embeds = ip_adapter_image_embeds + return image_embeds + # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents def decode_latents(self, latents): latents = 1 / self.vae.config.scaling_factor * latents @@ -767,6 +802,7 @@ def __call__( prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[PipelineImageInput] = None, conditioning_frames: Optional[List[PipelineImageInput]] = None, output_type: Optional[str] = "pil", return_dict: bool = True, @@ -821,6 +857,9 @@ def __call__( not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. ip_adapter_image (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. conditioning_frames (`List[PipelineImageInput]`, *optional*): The ControlNet input condition to provide guidance to the `unet` for generation. If multiple ControlNets are specified, images must be passed as a list such that each element of the list can be correctly @@ -965,9 +1004,9 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_videos_per_prompt) - if self.do_classifier_free_guidance: - image_embeds = torch.cat([negative_image_embeds, image_embeds]) + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_videos_per_prompt + ) if isinstance(controlnet, ControlNetModel): conditioning_frames = self.prepare_image( @@ -1023,7 +1062,11 @@ def __call__( extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7. 
Add image embeds for IP-Adapter - added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None + added_cond_kwargs = ( + {"image_embeds": image_embeds} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None + else None + ) # 7.1 Create tensor stating which controlnets to keep controlnet_keep = [] diff --git a/examples/community/pipeline_animatediff_img2video.py b/examples/community/pipeline_animatediff_img2video.py index 826742f9afc8..e77e26592d3e 100644 --- a/examples/community/pipeline_animatediff_img2video.py +++ b/examples/community/pipeline_animatediff_img2video.py @@ -11,9 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# +# Note: +# This pipeline relies on a "hack" discovered by the community that allows +# the generation of videos given an input image with AnimateDiff. It works +# by creating a copy of the image `num_frames` times and progressively adding +# more noise to the image based on the strength and latent interpolation method. import inspect -from dataclasses import dataclass from types import FunctionType from typing import Any, Callable, Dict, List, Optional, Union @@ -25,7 +30,8 @@ from diffusers.loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel from diffusers.models.lora import adjust_lora_scale_text_encoder -from diffusers.models.unet_motion_model import MotionAdapter +from diffusers.models.unets.unet_motion_model import MotionAdapter +from diffusers.pipelines.animatediff.pipeline_output import AnimateDiffPipelineOutput from diffusers.pipelines.pipeline_utils import DiffusionPipeline from diffusers.schedulers import ( DDIMScheduler, @@ -35,7 +41,7 @@ LMSDiscreteScheduler, PNDMScheduler, ) -from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers +from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from diffusers.utils.torch_utils import randn_tensor @@ -48,9 +54,10 @@ >>> from diffusers import MotionAdapter, DiffusionPipeline, DDIMScheduler >>> from diffusers.utils import export_to_gif, load_image + >>> model_id = "SG161222/Realistic_Vision_V5.1_noVAE" >>> adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2") >>> pipe = DiffusionPipeline.from_pretrained("SG161222/Realistic_Vision_V5.1_noVAE", motion_adapter=adapter, custom_pipeline="pipeline_animatediff_img2video").to("cuda") - >>> pipe.scheduler = DDIMScheduler(beta_schedule="linear", steps_offset=1, clip_sample=False, timespace_spacing="linspace") + >>> pipe.scheduler = pipe.scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", beta_schedule="linear", steps_offset=1) >>> image = load_image("snail.png") >>> output = pipe(image=image, prompt="A snail moving on the ground", strength=0.8, latent_interpolation_method="slerp") @@ -225,14 +232,9 @@ def retrieve_timesteps( return timesteps, num_inference_steps -@dataclass -class AnimateDiffImgToVideoPipelineOutput(BaseOutput): - frames: Union[torch.Tensor, np.ndarray] - - class AnimateDiffImgToVideoPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin): r""" - Pipeline for text-to-video generation. + Pipeline for image-to-video generation. 
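+    The input image is replicated `num_frames` times and the copies are progressively noised based on `strength`
+    and the chosen latent interpolation method before denoising, as described in the note at the top of this file.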
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.). @@ -503,6 +505,41 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state return image_embeds, uncond_image_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt + ): + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + image_embeds = [] + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0) + single_negative_image_embeds = torch.stack( + [single_negative_image_embeds] * num_images_per_prompt, dim=0 + ) + + if self.do_classifier_free_guidance: + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) + single_image_embeds = single_image_embeds.to(device) + + image_embeds.append(single_image_embeds) + else: + image_embeds = ip_adapter_image_embeds + return image_embeds + # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents def decode_latents(self, latents): latents = 1 / self.vae.config.scaling_factor * latents @@ -765,6 +802,7 @@ def __call__( prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[PipelineImageInput] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, @@ -818,6 +856,9 @@ def __call__( not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated video. Choose between `torch.FloatTensor`, `PIL.Image` or `np.array`. @@ -842,8 +883,8 @@ def __call__( Examples: Returns: - [`AnimateDiffImgToVideoPipelineOutput`] or `tuple`: - If `return_dict` is `True`, [`AnimateDiffImgToVideoPipelineOutput`] is + [`AnimateDiffPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`AnimateDiffPipelineOutput`] is returned, otherwise a `tuple` is returned where the first element is a list with the generated frames. """ # 0. 
Default height and width to unet @@ -902,12 +943,9 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if ip_adapter_image is not None: - output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True - image_embeds, negative_image_embeds = self.encode_image( - ip_adapter_image, device, num_videos_per_prompt, output_hidden_state + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_videos_per_prompt ) - if do_classifier_free_guidance: - image_embeds = torch.cat([negative_image_embeds, image_embeds]) # 4. Preprocess image image = self.image_processor.preprocess(image, height=height, width=width) @@ -936,7 +974,11 @@ def __call__( extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 8. Add image embeds for IP-Adapter - added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None + added_cond_kwargs = ( + {"image_embeds": image_embeds} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None + else None + ) # 9. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order @@ -970,7 +1012,7 @@ def __call__( callback(i, t, latents) if output_type == "latent": - return AnimateDiffImgToVideoPipelineOutput(frames=latents) + return AnimateDiffPipelineOutput(frames=latents) # 10. Post-processing video_tensor = self.decode_latents(latents) @@ -986,4 +1028,4 @@ def __call__( if not return_dict: return (video,) - return AnimateDiffImgToVideoPipelineOutput(frames=video) + return AnimateDiffPipelineOutput(frames=video) diff --git a/examples/controlnet/README_sdxl.md b/examples/controlnet/README_sdxl.md index 4a7797b9572c..fbbe9daac03f 100644 --- a/examples/controlnet/README_sdxl.md +++ b/examples/controlnet/README_sdxl.md @@ -113,7 +113,7 @@ pipe.enable_xformers_memory_efficient_attention() # memory optimization. pipe.enable_model_cpu_offload() -control_image = load_image("./conditioning_image_1.png") +control_image = load_image("./conditioning_image_1.png").resize((1024, 1024)) prompt = "pale golden rod circle with old lace background" # generate image @@ -128,4 +128,14 @@ image.save("./output.png") ### Specifying a better VAE -SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument namely `--pretrained_vae_model_name_or_path` that lets you specify the location of a better VAE (such as [this one](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)). +SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument namely `--pretrained_vae_model_name_or_path` that lets you specify the location of an alternative VAE (such as [`madebyollin/sdxl-vae-fp16-fix`](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)). + +If you're using this VAE during training, you need to ensure you're using it during inference too. 
You do so by: + +```diff ++ vae = AutoencoderKL.from_pretrained(vae_path_or_repo_id, torch_dtype=torch.float16) +controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16) +pipe = StableDiffusionXLControlNetPipeline.from_pretrained( + base_model_path, controlnet=controlnet, torch_dtype=torch.float16, ++ vae=vae, +) diff --git a/examples/controlnet/train_controlnet.py b/examples/controlnet/train_controlnet.py index dcfc392e00d9..2d8613749877 100644 --- a/examples/controlnet/train_controlnet.py +++ b/examples/controlnet/train_controlnet.py @@ -14,6 +14,8 @@ # See the License for the specific language governing permissions and import argparse +import contextlib +import gc import logging import math import os @@ -74,10 +76,15 @@ def image_grid(imgs, rows, cols): return grid -def log_validation(vae, text_encoder, tokenizer, unet, controlnet, args, accelerator, weight_dtype, step): +def log_validation( + vae, text_encoder, tokenizer, unet, controlnet, args, accelerator, weight_dtype, step, is_final_validation=False +): logger.info("Running validation... ") - controlnet = accelerator.unwrap_model(controlnet) + if not is_final_validation: + controlnet = accelerator.unwrap_model(controlnet) + else: + controlnet = ControlNetModel.from_pretrained(args.output_dir, torch_dtype=weight_dtype) pipeline = StableDiffusionControlNetPipeline.from_pretrained( args.pretrained_model_name_or_path, @@ -118,6 +125,7 @@ def log_validation(vae, text_encoder, tokenizer, unet, controlnet, args, acceler ) image_logs = [] + inference_ctx = contextlib.nullcontext() if is_final_validation else torch.autocast("cuda") for validation_prompt, validation_image in zip(validation_prompts, validation_images): validation_image = Image.open(validation_image).convert("RGB") @@ -125,7 +133,7 @@ def log_validation(vae, text_encoder, tokenizer, unet, controlnet, args, acceler images = [] for _ in range(args.num_validation_images): - with torch.autocast("cuda"): + with inference_ctx: image = pipeline( validation_prompt, validation_image, num_inference_steps=20, generator=generator ).images[0] @@ -136,6 +144,7 @@ def log_validation(vae, text_encoder, tokenizer, unet, controlnet, args, acceler {"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt} ) + tracker_key = "test" if is_final_validation else "validation" for tracker in accelerator.trackers: if tracker.name == "tensorboard": for log in image_logs: @@ -167,10 +176,14 @@ def log_validation(vae, text_encoder, tokenizer, unet, controlnet, args, acceler image = wandb.Image(image, caption=validation_prompt) formatted_images.append(image) - tracker.log({"validation": formatted_images}) + tracker.log({tracker_key: formatted_images}) else: logger.warn(f"image logging not implemented for {tracker.name}") + del pipeline + gc.collect() + torch.cuda.empty_cache() + return image_logs @@ -197,7 +210,7 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None): img_str = "" if image_logs is not None: - img_str = "You can find some example images below.\n" + img_str = "You can find some example images below.\n\n" for i, log in enumerate(image_logs): images = log["images"] validation_prompt = log["validation_prompt"] @@ -1131,6 +1144,22 @@ def load_model_hook(models, input_dir): controlnet = unwrap_model(controlnet) controlnet.save_pretrained(args.output_dir) + # Run a final round of validation. 
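+        # With `is_final_validation=True`, `log_validation` reloads the ControlNet that was just saved to
+        # `args.output_dir`, runs inference without autocast, and logs the samples under the "test" tracker key.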
+ image_logs = None + if args.validation_prompt is not None: + image_logs = log_validation( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + controlnet=None, + args=args, + accelerator=accelerator, + weight_dtype=weight_dtype, + step=global_step, + is_final_validation=True, + ) + if args.push_to_hub: save_model_card( repo_id, diff --git a/examples/controlnet/train_controlnet_sdxl.py b/examples/controlnet/train_controlnet_sdxl.py index 38d21d4094a7..b03857597021 100644 --- a/examples/controlnet/train_controlnet_sdxl.py +++ b/examples/controlnet/train_controlnet_sdxl.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and import argparse +import contextlib import functools import gc import logging @@ -65,20 +66,38 @@ logger = get_logger(__name__) -def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step): +def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step, is_final_validation=False): logger.info("Running validation... ") - controlnet = accelerator.unwrap_model(controlnet) + if not is_final_validation: + controlnet = accelerator.unwrap_model(controlnet) + pipeline = StableDiffusionXLControlNetPipeline.from_pretrained( + args.pretrained_model_name_or_path, + vae=vae, + unet=unet, + controlnet=controlnet, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + else: + controlnet = ControlNetModel.from_pretrained(args.output_dir, torch_dtype=weight_dtype) + if args.pretrained_vae_model_name_or_path is not None: + vae = AutoencoderKL.from_pretrained(args.pretrained_vae_model_name_or_path, torch_dtype=weight_dtype) + else: + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, subfolder="vae", torch_dtype=weight_dtype + ) + + pipeline = StableDiffusionXLControlNetPipeline.from_pretrained( + args.pretrained_model_name_or_path, + vae=vae, + controlnet=controlnet, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) - pipeline = StableDiffusionXLControlNetPipeline.from_pretrained( - args.pretrained_model_name_or_path, - vae=vae, - unet=unet, - controlnet=controlnet, - revision=args.revision, - variant=args.variant, - torch_dtype=weight_dtype, - ) pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config) pipeline = pipeline.to(accelerator.device) pipeline.set_progress_bar_config(disable=True) @@ -106,6 +125,7 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step) ) image_logs = [] + inference_ctx = contextlib.nullcontext() if is_final_validation else torch.autocast("cuda") for validation_prompt, validation_image in zip(validation_prompts, validation_images): validation_image = Image.open(validation_image).convert("RGB") @@ -114,7 +134,7 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step) images = [] for _ in range(args.num_validation_images): - with torch.autocast("cuda"): + with inference_ctx: image = pipeline( prompt=validation_prompt, image=validation_image, num_inference_steps=20, generator=generator ).images[0] @@ -124,6 +144,7 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step) {"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt} ) + tracker_key = "test" if is_final_validation else "validation" for tracker in accelerator.trackers: if tracker.name == "tensorboard": for log in image_logs: @@ -155,7 +176,7 @@ def log_validation(vae, unet, 
controlnet, args, accelerator, weight_dtype, step) image = wandb.Image(image, caption=validation_prompt) formatted_images.append(image) - tracker.log({"validation": formatted_images}) + tracker.log({tracker_key: formatted_images}) else: logger.warn(f"image logging not implemented for {tracker.name}") @@ -189,7 +210,7 @@ def import_model_class_from_model_name_or_path( def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None): img_str = "" if image_logs is not None: - img_str = "You can find some example images below.\n" + img_str = "You can find some example images below.\n\n" for i, log in enumerate(image_logs): images = log["images"] validation_prompt = log["validation_prompt"] @@ -1228,7 +1249,13 @@ def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizer if args.validation_prompt is not None and global_step % args.validation_steps == 0: image_logs = log_validation( - vae, unet, controlnet, args, accelerator, weight_dtype, global_step + vae=vae, + unet=unet, + controlnet=controlnet, + args=args, + accelerator=accelerator, + weight_dtype=weight_dtype, + step=global_step, ) logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} @@ -1244,6 +1271,21 @@ def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizer controlnet = unwrap_model(controlnet) controlnet.save_pretrained(args.output_dir) + # Run a final round of validation. + # Setting `vae`, `unet`, and `controlnet` to None to load automatically from `args.output_dir`. + image_logs = None + if args.validation_prompt is not None: + image_logs = log_validation( + vae=None, + unet=None, + controlnet=None, + args=args, + accelerator=accelerator, + weight_dtype=weight_dtype, + step=global_step, + is_final_validation=True, + ) + if args.push_to_hub: save_model_card( repo_id, diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 146806efb60e..e42311fb8f9e 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -67,6 +67,9 @@ from diffusers.utils.torch_utils import is_compiled_module +if is_wandb_available(): + import wandb + # Will error if the minimal version of diffusers is not installed. Remove at your own risks. check_min_version("0.27.0.dev0") @@ -140,6 +143,61 @@ def save_model_card( model_card.save(os.path.join(repo_folder, "README.md")) +def log_validation( + pipeline, + args, + accelerator, + pipeline_args, + epoch, + is_final_validation=False, +): + logger.info( + f"Running validation... \n Generating {args.num_validation_images} images with prompt:" + f" {args.validation_prompt}." + ) + + # We train on the simplified learning objective. 
If we were previously predicting a variance, we need the scheduler to ignore it + scheduler_args = {} + + if "variance_type" in pipeline.scheduler.config: + variance_type = pipeline.scheduler.config.variance_type + + if variance_type in ["learned", "learned_range"]: + variance_type = "fixed_small" + + scheduler_args["variance_type"] = variance_type + + pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args) + + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + # run inference + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None + + with torch.cuda.amp.autocast(): + images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)] + + for tracker in accelerator.trackers: + phase_name = "test" if is_final_validation else "validation" + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + phase_name: [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images) + ] + } + ) + + del pipeline + torch.cuda.empty_cache() + + return images + + def import_model_class_from_model_name_or_path( pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" ): @@ -862,7 +920,6 @@ def main(args): if args.report_to == "wandb": if not is_wandb_available(): raise ImportError("Make sure to install wandb if you want to use it for logging during training.") - import wandb # Make one log on every process with the configuration for debugging. logging.basicConfig( @@ -1615,10 +1672,6 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): if accelerator.is_main_process: if args.validation_prompt is not None and epoch % args.validation_epochs == 0: - logger.info( - f"Running validation... \n Generating {args.num_validation_images} images with prompt:" - f" {args.validation_prompt}." - ) # create pipeline if not args.train_text_encoder: text_encoder_one = text_encoder_cls_one.from_pretrained( @@ -1644,50 +1697,15 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): torch_dtype=weight_dtype, ) - # We train on the simplified learning objective. 
If we were previously predicting a variance, we need the scheduler to ignore it - scheduler_args = {} - - if "variance_type" in pipeline.scheduler.config: - variance_type = pipeline.scheduler.config.variance_type - - if variance_type in ["learned", "learned_range"]: - variance_type = "fixed_small" - - scheduler_args["variance_type"] = variance_type - - pipeline.scheduler = DPMSolverMultistepScheduler.from_config( - pipeline.scheduler.config, **scheduler_args - ) - - pipeline = pipeline.to(accelerator.device) - pipeline.set_progress_bar_config(disable=True) - - # run inference - generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None pipeline_args = {"prompt": args.validation_prompt} - with torch.cuda.amp.autocast(): - images = [ - pipeline(**pipeline_args, generator=generator).images[0] - for _ in range(args.num_validation_images) - ] - - for tracker in accelerator.trackers: - if tracker.name == "tensorboard": - np_images = np.stack([np.asarray(img) for img in images]) - tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") - if tracker.name == "wandb": - tracker.log( - { - "validation": [ - wandb.Image(image, caption=f"{i}: {args.validation_prompt}") - for i, image in enumerate(images) - ] - } - ) - - del pipeline - torch.cuda.empty_cache() + images = log_validation( + pipeline, + args, + accelerator, + pipeline_args, + epoch, + ) # Save the lora layers accelerator.wait_for_everyone() @@ -1733,45 +1751,21 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): torch_dtype=weight_dtype, ) - # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it - scheduler_args = {} - - if "variance_type" in pipeline.scheduler.config: - variance_type = pipeline.scheduler.config.variance_type - - if variance_type in ["learned", "learned_range"]: - variance_type = "fixed_small" - - scheduler_args["variance_type"] = variance_type - - pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args) - # load attention processors pipeline.load_lora_weights(args.output_dir) # run inference images = [] if args.validation_prompt and args.num_validation_images > 0: - pipeline = pipeline.to(accelerator.device) - generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None - images = [ - pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0] - for _ in range(args.num_validation_images) - ] - - for tracker in accelerator.trackers: - if tracker.name == "tensorboard": - np_images = np.stack([np.asarray(img) for img in images]) - tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC") - if tracker.name == "wandb": - tracker.log( - { - "test": [ - wandb.Image(image, caption=f"{i}: {args.validation_prompt}") - for i, image in enumerate(images) - ] - } - ) + pipeline_args = {"prompt": args.validation_prompt, "num_inference_steps": 25} + images = log_validation( + pipeline, + args, + accelerator, + pipeline_args, + epoch, + final_validation=True, + ) if args.push_to_hub: save_model_card( diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py index 2d77e9c8bfa3..78021b5afed4 100644 --- a/examples/text_to_image/train_text_to_image_sdxl.py +++ b/examples/text_to_image/train_text_to_image_sdxl.py @@ -951,6 +951,9 @@ def collate_fn(examples): unet, optimizer, train_dataloader, lr_scheduler ) 
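+    # Keep the EMA copy of the UNet on the accelerator device; it is updated with
+    # `ema_unet.step(unet.parameters())` after every synchronized optimizer step.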
+ if args.use_ema: + ema_unet.to(accelerator.device) + # We need to recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) if overrode_max_train_steps: @@ -1126,6 +1129,8 @@ def compute_time_ids(original_size, crops_coords_top_left): # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: + if args.use_ema: + ema_unet.step(unet.parameters()) progress_bar.update(1) global_step += 1 accelerator.log({"train_loss": train_loss}, step=global_step) diff --git a/examples/textual_inversion/textual_inversion_sdxl.py b/examples/textual_inversion/textual_inversion_sdxl.py index 90f1b36b08ea..d35a11b70a20 100644 --- a/examples/textual_inversion/textual_inversion_sdxl.py +++ b/examples/textual_inversion/textual_inversion_sdxl.py @@ -546,6 +546,8 @@ def __getitem__(self, i): example["original_size"] = (image.height, image.width) + image = image.resize((self.size, self.size), resample=self.interpolation) + if self.center_crop: y1 = max(0, int(round((image.height - self.size) / 2.0))) x1 = max(0, int(round((image.width - self.size) / 2.0))) @@ -576,7 +578,6 @@ def __getitem__(self, i): img = np.array(image).astype(np.uint8) image = Image.fromarray(img) - image = image.resize((self.size, self.size), resample=self.interpolation) image = self.flip_transform(image) image = np.array(image).astype(np.uint8) diff --git a/scripts/convert_dance_diffusion_to_diffusers.py b/scripts/convert_dance_diffusion_to_diffusers.py index d53d1f792e89..ce69bfe2bfc8 100755 --- a/scripts/convert_dance_diffusion_to_diffusers.py +++ b/scripts/convert_dance_diffusion_to_diffusers.py @@ -4,6 +4,7 @@ import os from copy import deepcopy +import requests import torch from audio_diffusion.models import DiffusionAttnUnet1D from diffusion import sampling @@ -73,9 +74,14 @@ def __init__(self, global_args): def download(model_name): url = MODELS_MAP[model_name]["url"] - os.system(f"wget {url} ./") + r = requests.get(url, stream=True) - return f"./{model_name}.ckpt" + local_filename = f"./{model_name}.ckpt" + with open(local_filename, "wb") as fp: + for chunk in r.iter_content(chunk_size=8192): + fp.write(chunk) + + return local_filename DOWN_NUM_TO_LAYER = { diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py index 0870f3a67a3d..964087e0e06a 100644 --- a/src/diffusers/loaders/lora.py +++ b/src/diffusers/loaders/lora.py @@ -106,6 +106,10 @@ def load_lora_weights( if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") + # if a dict is passed, copy it instead of modifying it inplace + if isinstance(pretrained_model_name_or_path_or_dict, dict): + pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() + # First, ensure that the checkpoint is a compatible one and can be successfully loaded. state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) @@ -1229,6 +1233,10 @@ def load_lora_weights( # it here explicitly to be able to tell that it's coming from an SDXL # pipeline. + # if a dict is passed, copy it instead of modifying it inplace + if isinstance(pretrained_model_name_or_path_or_dict, dict): + pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() + # First, ensure that the checkpoint is a compatible one and can be successfully loaded. 
state_dict, network_alphas = self.lora_state_dict( pretrained_model_name_or_path_or_dict, diff --git a/src/diffusers/models/lora.py b/src/diffusers/models/lora.py index 346b44d1c553..1bbc96c6f5a7 100644 --- a/src/diffusers/models/lora.py +++ b/src/diffusers/models/lora.py @@ -361,16 +361,19 @@ def _unfuse_lora(self): self.w_down = None def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor: + if self.padding_mode != "zeros": + hidden_states = F.pad(hidden_states, self._reversed_padding_repeated_twice, mode=self.padding_mode) + padding = (0, 0) + else: + padding = self.padding + + original_outputs = F.conv2d( + hidden_states, self.weight, self.bias, self.stride, padding, self.dilation, self.groups + ) + if self.lora_layer is None: - # make sure to the functional Conv2D function as otherwise torch.compile's graph will break - # see: https://github.com/huggingface/diffusers/pull/4315 - return F.conv2d( - hidden_states, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups - ) + return original_outputs else: - original_outputs = F.conv2d( - hidden_states, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups - ) return original_outputs + (scale * self.lora_layer(hidden_states)) diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index 9cb0f42c85ef..246a4b8124d8 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -217,6 +217,7 @@ def __init__( use_motion_mid_block: int = True, encoder_hid_dim: Optional[int] = None, encoder_hid_dim_type: Optional[str] = None, + time_cond_proj_dim: Optional[int] = None, ): super().__init__() @@ -252,9 +253,7 @@ def __init__( timestep_input_dim = block_out_channels[0] self.time_embedding = TimestepEmbedding( - timestep_input_dim, - time_embed_dim, - act_fn=act_fn, + timestep_input_dim, time_embed_dim, act_fn=act_fn, cond_proj_dim=time_cond_proj_dim ) if encoder_hid_dim_type is None: @@ -306,6 +305,7 @@ def __init__( num_attention_heads=num_attention_heads[-1], resnet_groups=norm_num_groups, dual_cross_attention=False, + use_linear_projection=use_linear_projection, temporal_num_attention_heads=motion_num_attention_heads, temporal_max_seq_length=motion_max_seq_length, ) @@ -321,6 +321,7 @@ def __init__( num_attention_heads=num_attention_heads[-1], resnet_groups=norm_num_groups, dual_cross_attention=False, + use_linear_projection=use_linear_projection, ) # count how many layers upsample the images diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index c92df24251f4..adb32a782b3e 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -436,7 +436,6 @@ def load_sub_model( variant: str, low_cpu_mem_usage: bool, cached_folder: Union[str, os.PathLike], - revision: str = None, ): """Helper method to load the module `name` from `library_name` and `class_name`""" # retrieve class candidates @@ -504,6 +503,7 @@ def load_sub_model( loading_kwargs["offload_folder"] = offload_folder loading_kwargs["offload_state_dict"] = offload_state_dict loading_kwargs["variant"] = model_variants.pop(name, None) + if from_flax: loading_kwargs["from_flax"] = True @@ -1280,7 +1280,6 @@ def load_module(name, value): variant=variant, low_cpu_mem_usage=low_cpu_mem_usage, cached_folder=cached_folder, - revision=revision, ) logger.info( f"Loaded {name} as {class_name} from `{name}` subfolder of 
{pretrained_model_name_or_path}." diff --git a/tests/lora/test_lora_layers_peft.py b/tests/lora/test_lora_layers_peft.py index 3aeeb239b613..67d28fe19e7e 100644 --- a/tests/lora/test_lora_layers_peft.py +++ b/tests/lora/test_lora_layers_peft.py @@ -26,11 +26,13 @@ from huggingface_hub import hf_hub_download from huggingface_hub.repocard import RepoCard from packaging import version +from safetensors.torch import load_file from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer from diffusers import ( AutoencoderKL, AutoPipelineForImage2Image, + AutoPipelineForText2Image, ControlNetModel, DDIMScheduler, DiffusionPipeline, @@ -1177,6 +1179,24 @@ def test_simple_inference_with_text_unet_lora_unfused_torch_compile(self): # Just makes sure it works.. _ = pipe(**inputs, generator=torch.manual_seed(0)).images + def test_modify_padding_mode(self): + def set_pad_mode(network, mode="circular"): + for _, module in network.named_modules(): + if isinstance(module, torch.nn.Conv2d): + module.padding_mode = mode + + for scheduler_cls in [DDIMScheduler, LCMScheduler]: + components, _, _ = self.get_dummy_components(scheduler_cls) + pipe = self.pipeline_class(**components) + pipe = pipe.to(self.torch_device) + pipe.set_progress_bar_config(disable=None) + _pad_mode = "circular" + set_pad_mode(pipe.vae, _pad_mode) + set_pad_mode(pipe.unet, _pad_mode) + + _, _, inputs = self.get_dummy_inputs() + _ = pipe(**inputs).images + class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase): pipeline_class = StableDiffusionPipeline @@ -1727,6 +1747,40 @@ def test_load_unload_load_kohya_lora(self): self.assertTrue(np.allclose(lora_images, lora_images_again, atol=1e-3)) release_memory(pipe) + def test_not_empty_state_dict(self): + # Makes sure https://github.com/huggingface/diffusers/issues/7054 does not happen again + pipe = AutoPipelineForText2Image.from_pretrained( + "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16 + ).to("cuda") + pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config) + + cached_file = hf_hub_download("hf-internal-testing/lcm-lora-test-sd-v1-5", "test_lora.safetensors") + lcm_lora = load_file(cached_file) + + pipe.load_lora_weights(lcm_lora, adapter_name="lcm") + self.assertTrue(lcm_lora != {}) + release_memory(pipe) + + def test_load_unload_load_state_dict(self): + # Makes sure https://github.com/huggingface/diffusers/issues/7054 does not happen again + pipe = AutoPipelineForText2Image.from_pretrained( + "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16 + ).to("cuda") + pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config) + + cached_file = hf_hub_download("hf-internal-testing/lcm-lora-test-sd-v1-5", "test_lora.safetensors") + lcm_lora = load_file(cached_file) + previous_state_dict = lcm_lora.copy() + + pipe.load_lora_weights(lcm_lora, adapter_name="lcm") + self.assertDictEqual(lcm_lora, previous_state_dict) + + pipe.unload_lora_weights() + pipe.load_lora_weights(lcm_lora, adapter_name="lcm") + self.assertDictEqual(lcm_lora, previous_state_dict) + + release_memory(pipe) + @slow @require_torch_gpu
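# `load_lora_weights` copies an in-memory state dict before consuming it (the
# `pretrained_model_name_or_path_or_dict.copy()` call in `src/diffusers/loaders/lora.py`), so a dict such as
# `lcm_lora` is left intact and can be passed to `load_lora_weights` again after `unload_lora_weights`,
# which is the behavior the two new state-dict tests exercise.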