From 50296739878f3e17b2d25d45ef626318b44440b9 Mon Sep 17 00:00:00 2001 From: Jenyuan-Huang <112627523+DannHuang@users.noreply.github.com> Date: Mon, 29 Apr 2024 04:34:57 +0800 Subject: [PATCH 01/20] Update InstantStyle usage in IP-Adapter documentation (#7806) * enable control ip-adapter per-transformer block on-the-fly --------- Co-authored-by: sayakpaul Co-authored-by: ResearcherXman Co-authored-by: YiYi Xu --- docs/source/en/using-diffusers/ip_adapter.md | 22 ++++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/docs/source/en/using-diffusers/ip_adapter.md b/docs/source/en/using-diffusers/ip_adapter.md index ea5f781c625d..02fb0c34aa79 100644 --- a/docs/source/en/using-diffusers/ip_adapter.md +++ b/docs/source/en/using-diffusers/ip_adapter.md @@ -661,16 +661,16 @@ image ### Style & layout control -[InstantStyle](https://arxiv.org/abs/2404.02733) is a plug-and-play method on top of IP-Adapter, which disentangles style and layout from image prompt to control image generation. This is achieved by only inserting IP-Adapters to some specific part of the model. +[InstantStyle](https://arxiv.org/abs/2404.02733) is a plug-and-play method on top of IP-Adapter, which disentangles style and layout from image prompt to control image generation. This way, you can generate images following only the style or layout from image prompt, with significantly improved diversity. This is achieved by only activating IP-Adapters to specific parts of the model. By default IP-Adapters are inserted to all layers of the model. Use the [`~loaders.IPAdapterMixin.set_ip_adapter_scale`] method with a dictionary to assign scales to IP-Adapter at different layers. ```py -from diffusers import AutoPipelineForImage2Image +from diffusers import AutoPipelineForText2Image from diffusers.utils import load_image import torch -pipeline = AutoPipelineForImage2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16).to("cuda") +pipeline = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16).to("cuda") pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin") scale = { @@ -680,15 +680,15 @@ scale = { pipeline.set_ip_adapter_scale(scale) ``` -This will activate IP-Adapter at the second layer in the model's down-part block 2 and up-part block 0. The former is the layer where IP-Adapter injects layout information and the latter injects style. Inserting IP-Adapter to these two layers you can generate images following the style and layout of image prompt, but with contents more aligned to text prompt. +This will activate IP-Adapter at the second layer in the model's down-part block 2 and up-part block 0. The former is the layer where IP-Adapter injects layout information and the latter injects style. Inserting IP-Adapter to these two layers you can generate images following both the style and layout from image prompt, but with contents more aligned to text prompt. ```py style_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg") -generator = torch.Generator(device="cpu").manual_seed(42) +generator = torch.Generator(device="cpu").manual_seed(26) image = pipeline( prompt="a cat, masterpiece, best quality, high quality", - image=style_image, + ip_adapter_image=style_image, negative_prompt="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry", guidance_scale=5, num_inference_steps=30, @@ -703,7 +703,7 @@ image
IP-Adapter image
- +
generated image
@@ -718,10 +718,10 @@ scale = { } pipeline.set_ip_adapter_scale(scale) -generator = torch.Generator(device="cpu").manual_seed(42) +generator = torch.Generator(device="cpu").manual_seed(26) image = pipeline( prompt="a cat, masterpiece, best quality, high quality", - image=style_image, + ip_adapter_image=style_image, negative_prompt="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry", guidance_scale=5, num_inference_steps=30, @@ -732,11 +732,11 @@ image
- +
IP-Adapter only in style layer
- +
IP-Adapter in all layers
From 235d34cf567e78bf958344d3132bb018a8580295 Mon Sep 17 00:00:00 2001 From: Nilesh Date: Mon, 29 Apr 2024 06:23:29 +0530 Subject: [PATCH 02/20] Check for latents, before calling prepare_latents - sdxlImg2Img (#7582) * Check for latents, before calling prepare_latents - sdxlImg2Img * Added latents check for all the img2img pipeline * Fixed silly mistake while checking latents as None --- .../clip_guided_stable_diffusion_img2img.py | 13 ++++++++--- .../community/latent_consistency_img2img.py | 23 ++++++++++--------- .../stable_diffusion_controlnet_img2img.py | 19 +++++++-------- ...le_diffusion_controlnet_inpaint_img2img.py | 19 +++++++-------- .../controlnet/pipeline_controlnet_img2img.py | 19 +++++++-------- .../pipeline_controlnet_sd_xl_img2img.py | 21 +++++++++-------- .../pipeline_latent_consistency_img2img.py | 7 +++--- .../shap_e/pipeline_shap_e_img2img.py | 18 +++++++-------- .../pipeline_stable_unclip_img2img.py | 21 +++++++++-------- .../pipeline_stable_diffusion_xl_img2img.py | 22 ++++++++++-------- 10 files changed, 99 insertions(+), 83 deletions(-) diff --git a/examples/community/clip_guided_stable_diffusion_img2img.py b/examples/community/clip_guided_stable_diffusion_img2img.py index 434d5253679a..c8e0a9094f22 100644 --- a/examples/community/clip_guided_stable_diffusion_img2img.py +++ b/examples/community/clip_guided_stable_diffusion_img2img.py @@ -359,9 +359,16 @@ def __call__( # Preprocess image image = preprocess(image, width, height) - latents = self.prepare_latents( - image, latent_timestep, batch_size, num_images_per_prompt, text_embeddings.dtype, self.device, generator - ) + if latents is None: + latents = self.prepare_latents( + image, + latent_timestep, + batch_size, + num_images_per_prompt, + text_embeddings.dtype, + self.device, + generator, + ) if clip_guidance_scale > 0: if clip_prompt is not None: diff --git a/examples/community/latent_consistency_img2img.py b/examples/community/latent_consistency_img2img.py index 35cd74166c68..98078a2eef96 100644 --- a/examples/community/latent_consistency_img2img.py +++ b/examples/community/latent_consistency_img2img.py @@ -335,17 +335,18 @@ def __call__( # 5. Prepare latent variable num_channels_latents = self.unet.config.in_channels - latents = self.prepare_latents( - image, - latent_timestep, - batch_size * num_images_per_prompt, - num_channels_latents, - height, - width, - prompt_embeds.dtype, - device, - latents, - ) + if latents is None: + latents = self.prepare_latents( + image, + latent_timestep, + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + latents, + ) bs = batch_size * num_images_per_prompt # 6. Get Guidance Scale Embedding diff --git a/examples/community/stable_diffusion_controlnet_img2img.py b/examples/community/stable_diffusion_controlnet_img2img.py index 5f9083616a84..74674e65f0ef 100644 --- a/examples/community/stable_diffusion_controlnet_img2img.py +++ b/examples/community/stable_diffusion_controlnet_img2img.py @@ -802,15 +802,16 @@ def __call__( latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) # 6. Prepare latent variables - latents = self.prepare_latents( - image, - latent_timestep, - batch_size, - num_images_per_prompt, - prompt_embeds.dtype, - device, - generator, - ) + if latents is None: + latents = self.prepare_latents( + image, + latent_timestep, + batch_size, + num_images_per_prompt, + prompt_embeds.dtype, + device, + generator, + ) # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) diff --git a/examples/community/stable_diffusion_controlnet_inpaint_img2img.py b/examples/community/stable_diffusion_controlnet_inpaint_img2img.py index d056eb112165..14c4e4aa6d4e 100644 --- a/examples/community/stable_diffusion_controlnet_inpaint_img2img.py +++ b/examples/community/stable_diffusion_controlnet_inpaint_img2img.py @@ -907,15 +907,16 @@ def __call__( latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) # 6. Prepare latent variables - latents = self.prepare_latents( - image, - latent_timestep, - batch_size, - num_images_per_prompt, - prompt_embeds.dtype, - device, - generator, - ) + if latents is None: + latents = self.prepare_latents( + image, + latent_timestep, + batch_size, + num_images_per_prompt, + prompt_embeds.dtype, + device, + generator, + ) mask_image_latents = self.prepare_mask_latents( mask_image, diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py index a5a0aaed0f2e..022f30d819d8 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py @@ -1169,15 +1169,16 @@ def __call__( self._num_timesteps = len(timesteps) # 6. Prepare latent variables - latents = self.prepare_latents( - image, - latent_timestep, - batch_size, - num_images_per_prompt, - prompt_embeds.dtype, - device, - generator, - ) + if latents is None: + latents = self.prepare_latents( + image, + latent_timestep, + batch_size, + num_images_per_prompt, + prompt_embeds.dtype, + device, + generator, + ) # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py index d32e7d81649d..d7889a9efbb5 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py @@ -1429,16 +1429,17 @@ def __call__( self._num_timesteps = len(timesteps) # 6. Prepare latent variables - latents = self.prepare_latents( - image, - latent_timestep, - batch_size, - num_images_per_prompt, - prompt_embeds.dtype, - device, - generator, - True, - ) + if latents is None: + latents = self.prepare_latents( + image, + latent_timestep, + batch_size, + num_images_per_prompt, + prompt_embeds.dtype, + device, + generator, + True, + ) # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) diff --git a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py index 8957d7140ef1..fce694d1d0bd 100644 --- a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +++ b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py @@ -872,9 +872,10 @@ def __call__( else self.scheduler.config.original_inference_steps ) latent_timestep = timesteps[:1] - latents = self.prepare_latents( - image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator - ) + if latents is None: + latents = self.prepare_latents( + image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator + ) bs = batch_size * num_images_per_prompt # 6. Get Guidance Scale Embedding diff --git a/src/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py b/src/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py index 02e32633cedb..700ca5db6f07 100644 --- a/src/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +++ b/src/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py @@ -239,15 +239,15 @@ def __call__( num_embeddings = self.prior.config.num_embeddings embedding_dim = self.prior.config.embedding_dim - - latents = self.prepare_latents( - (batch_size, num_embeddings * embedding_dim), - image_embeds.dtype, - device, - generator, - latents, - self.scheduler, - ) + if latents is None: + latents = self.prepare_latents( + (batch_size, num_embeddings * embedding_dim), + image_embeds.dtype, + device, + generator, + latents, + self.scheduler, + ) # YiYi notes: for testing only to match ldm, we can directly create a latents with desired shape: batch_size, num_embeddings, embedding_dim latents = latents.reshape(latents.shape[0], num_embeddings, embedding_dim) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py index fe19b4de3127..134ec39effc5 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py @@ -786,16 +786,17 @@ def __call__( # 6. Prepare latent variables num_channels_latents = self.unet.config.in_channels - latents = self.prepare_latents( - batch_size=batch_size, - num_channels_latents=num_channels_latents, - height=height, - width=width, - dtype=prompt_embeds.dtype, - device=device, - generator=generator, - latents=latents, - ) + if latents is None: + latents = self.prepare_latents( + batch_size=batch_size, + num_channels_latents=num_channels_latents, + height=height, + width=width, + dtype=prompt_embeds.dtype, + device=device, + generator=generator, + latents=latents, + ) # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index b72b19d5c1ef..b98ea279c1a2 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -1247,17 +1247,19 @@ def denoising_value_valid(dnv): latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) add_noise = True if self.denoising_start is None else False + # 6. Prepare latent variables - latents = self.prepare_latents( - image, - latent_timestep, - batch_size, - num_images_per_prompt, - prompt_embeds.dtype, - device, - generator, - add_noise, - ) + if latents is None: + latents = self.prepare_latents( + image, + latent_timestep, + batch_size, + num_images_per_prompt, + prompt_embeds.dtype, + device, + generator, + add_noise, + ) # 7. Prepare extra step kwargs. extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) From b1c5817a896ff59604f5ab2b3334df8c5c71ff5b Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Mon, 29 Apr 2024 13:44:39 +0530 Subject: [PATCH 03/20] Add debugging workflow (#7778) add debug workflow Co-authored-by: Sayak Paul --- .github/workflows/ssh-runner.yml | 55 ++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) create mode 100644 .github/workflows/ssh-runner.yml diff --git a/.github/workflows/ssh-runner.yml b/.github/workflows/ssh-runner.yml new file mode 100644 index 000000000000..befebfbc9b96 --- /dev/null +++ b/.github/workflows/ssh-runner.yml @@ -0,0 +1,55 @@ +name: SSH into runners + +on: + workflow_dispatch: + inputs: + runner_type: + description: 'Type of runner to test (a10 or t4)' + required: true + docker_image: + description: 'Name of the Docker image' + required: true + +env: + IS_GITHUB_CI: "1" + HF_HUB_READ_TOKEN: ${{ secrets.HF_HUB_READ_TOKEN }} + HF_HOME: /mnt/cache + DIFFUSERS_IS_CI: yes + OMP_NUM_THREADS: 8 + MKL_NUM_THREADS: 8 + RUN_SLOW: yes + +jobs: + ssh_runner: + name: "SSH" + runs-on: [single-gpu, nvidia-gpu, "${{ github.event.inputs.runner_type }}", ci] + container: + image: ${{ github.event.inputs.docker_image }} + options: --gpus all --privileged --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/ + + steps: + - name: Update clone + working-directory: /diffusers + run: | + git fetch && git checkout ${{ github.sha }} + - name: Cleanup + working-directory: /diffusers + run: | + rm -rf tests/__pycache__ + rm -rf tests/models/__pycache__ + rm -rf reports + - name: Show installed libraries and their versions + working-directory: /diffusers + run: pip freeze + + - name: NVIDIA-SMI + run: | + nvidia-smi + + - name: Tailscale # In order to be able to SSH when a test fails + uses: huggingface/tailscale-action@v1 + with: + authkey: ${{ secrets.TAILSCALE_SSH_AUTHKEY }} + slackChannel: ${{ secrets.SLACK_CIFEEDBACK_CHANNEL }} + slackToken: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }} + waitForSSH: true From a38dd795120e1884e3396d41bf44e44fd9b1eba0 Mon Sep 17 00:00:00 2001 From: Yushu Date: Mon, 29 Apr 2024 03:54:16 -0700 Subject: [PATCH 04/20] [Pipeline] Fix error of SVD pipeline when num_videos_per_prompt > 1 (#7786) swap the order for do_classifier_free_guidance concat with repeat Co-authored-by: Sayak Paul Co-authored-by: Dhruv Nair --- .../pipeline_stable_video_diffusion.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py index 070183b92409..da6832cebd4d 100644 --- a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +++ b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py @@ -199,6 +199,9 @@ def _encode_vae_image( image = image.to(device=device) image_latents = self.vae.encode(image).latent_dist.mode() + # duplicate image_latents for each generation per prompt, using mps friendly method + image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1) + if do_classifier_free_guidance: negative_image_latents = torch.zeros_like(image_latents) @@ -207,9 +210,6 @@ def _encode_vae_image( # to avoid doing two forward passes image_latents = torch.cat([negative_image_latents, image_latents]) - # duplicate image_latents for each generation per prompt, using mps friendly method - image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1) - return image_latents def _get_add_time_ids( From eb96ff0d5952f6d64b09bc51a2115de1898e9210 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Mon, 29 Apr 2024 17:36:50 +0530 Subject: [PATCH 05/20] Safetensor loading in AnimateDiff conversion scripts (#7764) * update * update --- scripts/convert_animatediff_motion_lora_to_diffusers.py | 7 +++++-- scripts/convert_animatediff_motion_module_to_diffusers.py | 7 ++++++- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/scripts/convert_animatediff_motion_lora_to_diffusers.py b/scripts/convert_animatediff_motion_lora_to_diffusers.py index 509a7345793c..c680fdc68462 100644 --- a/scripts/convert_animatediff_motion_lora_to_diffusers.py +++ b/scripts/convert_animatediff_motion_lora_to_diffusers.py @@ -1,7 +1,7 @@ import argparse import torch -from safetensors.torch import save_file +from safetensors.torch import load_file, save_file def convert_motion_module(original_state_dict): @@ -34,7 +34,10 @@ def get_args(): if __name__ == "__main__": args = get_args() - state_dict = torch.load(args.ckpt_path, map_location="cpu") + if args.ckpt_path.endswith(".safetensors"): + state_dict = load_file(args.ckpt_path) + else: + state_dict = torch.load(args.ckpt_path, map_location="cpu") if "state_dict" in state_dict.keys(): state_dict = state_dict["state_dict"] diff --git a/scripts/convert_animatediff_motion_module_to_diffusers.py b/scripts/convert_animatediff_motion_module_to_diffusers.py index ceb967acd3d6..e8fb007243fd 100644 --- a/scripts/convert_animatediff_motion_module_to_diffusers.py +++ b/scripts/convert_animatediff_motion_module_to_diffusers.py @@ -1,6 +1,7 @@ import argparse import torch +from safetensors.torch import load_file from diffusers import MotionAdapter @@ -38,7 +39,11 @@ def get_args(): if __name__ == "__main__": args = get_args() - state_dict = torch.load(args.ckpt_path, map_location="cpu") + if args.ckpt_path.endswith(".safetensors"): + state_dict = load_file(args.ckpt_path) + else: + state_dict = torch.load(args.ckpt_path, map_location="cpu") + if "state_dict" in state_dict.keys(): state_dict = state_dict["state_dict"] From 8af793b2d467d0a28f9fed6e07aedd7dc2b9a0ba Mon Sep 17 00:00:00 2001 From: jschoormans Date: Mon, 29 Apr 2024 21:00:53 +0200 Subject: [PATCH 06/20] Adding TextualInversionLoaderMixin for the controlnet_inpaint_sd_xl pipeline (#7288) * added TextualInversionMixIn to controlnet_inpaint_sd_xl pipeline --------- Co-authored-by: YiYi Xu --- .../controlnet/pipeline_controlnet_inpaint_sd_xl.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py index 18c4370b8025..b9c4e3c0032c 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py @@ -151,7 +151,12 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): class StableDiffusionXLControlNetInpaintPipeline( - DiffusionPipeline, StableDiffusionMixin, StableDiffusionXLLoraLoaderMixin, FromSingleFileMixin, IPAdapterMixin + DiffusionPipeline, + StableDiffusionMixin, + StableDiffusionXLLoraLoaderMixin, + FromSingleFileMixin, + IPAdapterMixin, + TextualInversionLoaderMixin, ): r""" Pipeline for text-to-image generation using Stable Diffusion XL. @@ -160,6 +165,7 @@ class StableDiffusionXLControlNetInpaintPipeline( library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files From 83ae24ce2d9a080e850c630a1c7050f60e63e3e3 Mon Sep 17 00:00:00 2001 From: RuiningLi <88520323+RuiningLi@users.noreply.github.com> Date: Mon, 29 Apr 2024 21:32:13 +0100 Subject: [PATCH 07/20] Added get_velocity function to EulerDiscreteScheduler. (#7733) * Added get_velocity function to EulerDiscreteScheduler. * Fix white space on blank lines * Added copied from statement * back to the original. --------- Co-authored-by: Ruining Li Co-authored-by: Sayak Paul --- .../schedulers/scheduling_euler_discrete.py | 39 +++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index 1e3252c0bd39..be5bbc235878 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -576,5 +576,44 @@ def add_noise( noisy_samples = original_samples + noise * sigma return noisy_samples + def get_velocity( + self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.FloatTensor + ) -> torch.FloatTensor: + if ( + isinstance(timesteps, int) + or isinstance(timesteps, torch.IntTensor) + or isinstance(timesteps, torch.LongTensor) + ): + raise ValueError( + ( + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + " `EulerDiscreteScheduler.get_velocity()` is not supported. Make sure to pass" + " one of the `scheduler.timesteps` as a timestep." + ), + ) + + if sample.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32) + timesteps = timesteps.to(sample.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(sample.device) + timesteps = timesteps.to(sample.device) + + step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + alphas_cumprod = self.alphas_cumprod.to(sample) + sqrt_alpha_prod = alphas_cumprod[step_indices] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(sample.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[step_indices]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample + return velocity + def __len__(self): return self.config.num_train_timesteps From f53352f750725a4bf4a44220db196c0f26f3ff81 Mon Sep 17 00:00:00 2001 From: Clint Adams <223406+clinty@users.noreply.github.com> Date: Mon, 29 Apr 2024 17:39:59 -0400 Subject: [PATCH 08/20] Set main_input_name in StableDiffusionSafetyChecker to "clip_input" (#7500) FlaxStableDiffusionSafetyChecker sets main_input_name to "clip_input". This makes StableDiffusionSafetyChecker consistent. Co-authored-by: Sayak Paul Co-authored-by: YiYi Xu --- src/diffusers/pipelines/stable_diffusion/safety_checker.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/pipelines/stable_diffusion/safety_checker.py b/src/diffusers/pipelines/stable_diffusion/safety_checker.py index 6cc4d26f29b4..3e6dec3e0bff 100644 --- a/src/diffusers/pipelines/stable_diffusion/safety_checker.py +++ b/src/diffusers/pipelines/stable_diffusion/safety_checker.py @@ -31,6 +31,7 @@ def cosine_distance(image_embeds, text_embeds): class StableDiffusionSafetyChecker(PreTrainedModel): config_class = CLIPConfig + main_input_name = "clip_input" _no_split_modules = ["CLIPEncoderLayer"] From 31d9f9ea77d7bda61484ef9a29d8453f88c6e28d Mon Sep 17 00:00:00 2001 From: Aritra Roy Gosthipaty Date: Tue, 30 Apr 2024 07:54:38 +0530 Subject: [PATCH 09/20] [Tests] reduce the model size in the ddim fast test (#7803) chore: reducing model size for ddim fast pipeline Co-authored-by: Sayak Paul --- tests/pipelines/ddim/test_ddim.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/tests/pipelines/ddim/test_ddim.py b/tests/pipelines/ddim/test_ddim.py index 0f0654397a34..2078a592ceca 100644 --- a/tests/pipelines/ddim/test_ddim.py +++ b/tests/pipelines/ddim/test_ddim.py @@ -42,9 +42,10 @@ class DDIMPipelineFastTests(PipelineTesterMixin, unittest.TestCase): def get_dummy_components(self): torch.manual_seed(0) unet = UNet2DModel( - block_out_channels=(32, 64), - layers_per_block=2, - sample_size=32, + block_out_channels=(4, 8), + layers_per_block=1, + norm_num_groups=4, + sample_size=8, in_channels=3, out_channels=3, down_block_types=("DownBlock2D", "AttnDownBlock2D"), @@ -79,10 +80,8 @@ def test_inference(self): image = pipe(**inputs).images image_slice = image[0, -3:, -3:, -1] - self.assertEqual(image.shape, (1, 32, 32, 3)) - expected_slice = np.array( - [1.000e00, 5.717e-01, 4.717e-01, 1.000e00, 0.000e00, 1.000e00, 3.000e-04, 0.000e00, 9.000e-04] - ) + self.assertEqual(image.shape, (1, 8, 8, 3)) + expected_slice = np.array([0.0, 9.979e-01, 0.0, 9.999e-01, 9.986e-01, 9.991e-01, 7.106e-04, 0.0, 0.0]) max_diff = np.abs(image_slice.flatten() - expected_slice).max() self.assertLessEqual(max_diff, 1e-3) From 21f023ec1acefbe3efa470451838dab4c133e098 Mon Sep 17 00:00:00 2001 From: Aritra Roy Gosthipaty Date: Tue, 30 Apr 2024 08:11:03 +0530 Subject: [PATCH 10/20] [Tests] reduce the model size in the ddpm fast test (#7797) * chore: reducing unet size for faster tests * review suggestions --------- Co-authored-by: Sayak Paul --- tests/pipelines/ddpm/test_ddpm.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/tests/pipelines/ddpm/test_ddpm.py b/tests/pipelines/ddpm/test_ddpm.py index c0cce3a2f237..f6d0821da4c2 100644 --- a/tests/pipelines/ddpm/test_ddpm.py +++ b/tests/pipelines/ddpm/test_ddpm.py @@ -30,9 +30,10 @@ class DDPMPipelineFastTests(unittest.TestCase): def dummy_uncond_unet(self): torch.manual_seed(0) model = UNet2DModel( - block_out_channels=(32, 64), - layers_per_block=2, - sample_size=32, + block_out_channels=(4, 8), + layers_per_block=1, + norm_num_groups=4, + sample_size=8, in_channels=3, out_channels=3, down_block_types=("DownBlock2D", "AttnDownBlock2D"), @@ -58,10 +59,8 @@ def test_fast_inference(self): image_slice = image[0, -3:, -3:, -1] image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] - assert image.shape == (1, 32, 32, 3) - expected_slice = np.array( - [9.956e-01, 5.785e-01, 4.675e-01, 9.930e-01, 0.0, 1.000, 1.199e-03, 2.648e-04, 5.101e-04] - ) + assert image.shape == (1, 8, 8, 3) + expected_slice = np.array([0.0, 0.9996672, 0.00329116, 1.0, 0.9995991, 1.0, 0.0060907, 0.00115037, 0.0]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 @@ -83,7 +82,7 @@ def test_inference_predict_sample(self): image_slice = image[0, -3:, -3:, -1] image_eps_slice = image_eps[0, -3:, -3:, -1] - assert image.shape == (1, 32, 32, 3) + assert image.shape == (1, 8, 8, 3) tolerance = 1e-2 if torch_device != "mps" else 3e-2 assert np.abs(image_slice.flatten() - image_eps_slice.flatten()).max() < tolerance From b02e2113ff4625100a4412abd1ae0392ee415364 Mon Sep 17 00:00:00 2001 From: Aritra Roy Gosthipaty Date: Tue, 30 Apr 2024 08:11:26 +0530 Subject: [PATCH 11/20] [Tests] reduce the model size in the amused fast test (#7804) * chore: reducing model sizes * chore: shrinks further * chore: shrinks further * chore: shrinking model for img2img pipeline * chore: reducing size of model for inpaint pipeline --------- Co-authored-by: Sayak Paul --- tests/pipelines/amused/test_amused.py | 36 +++++++++---------- tests/pipelines/amused/test_amused_img2img.py | 36 +++++++++---------- tests/pipelines/amused/test_amused_inpaint.py | 36 +++++++++---------- 3 files changed, 54 insertions(+), 54 deletions(-) diff --git a/tests/pipelines/amused/test_amused.py b/tests/pipelines/amused/test_amused.py index f03751e2f830..9a9e2551d642 100644 --- a/tests/pipelines/amused/test_amused.py +++ b/tests/pipelines/amused/test_amused.py @@ -38,17 +38,17 @@ class AmusedPipelineFastTests(PipelineTesterMixin, unittest.TestCase): def get_dummy_components(self): torch.manual_seed(0) transformer = UVit2DModel( - hidden_size=32, + hidden_size=8, use_bias=False, hidden_dropout=0.0, - cond_embed_dim=32, + cond_embed_dim=8, micro_cond_encode_dim=2, micro_cond_embed_dim=10, - encoder_hidden_size=32, + encoder_hidden_size=8, vocab_size=32, - codebook_size=32, - in_channels=32, - block_out_channels=32, + codebook_size=8, + in_channels=8, + block_out_channels=8, num_res_blocks=1, downsample=True, upsample=True, @@ -56,7 +56,7 @@ def get_dummy_components(self): num_hidden_layers=1, num_attention_heads=1, attention_dropout=0.0, - intermediate_size=32, + intermediate_size=8, layer_norm_eps=1e-06, ln_elementwise_affine=True, ) @@ -64,17 +64,17 @@ def get_dummy_components(self): torch.manual_seed(0) vqvae = VQModel( act_fn="silu", - block_out_channels=[32], + block_out_channels=[8], down_block_types=[ "DownEncoderBlock2D", ], in_channels=3, - latent_channels=32, - layers_per_block=2, - norm_num_groups=32, - num_vq_embeddings=32, + latent_channels=8, + layers_per_block=1, + norm_num_groups=8, + num_vq_embeddings=8, out_channels=3, - sample_size=32, + sample_size=8, up_block_types=[ "UpDecoderBlock2D", ], @@ -85,14 +85,14 @@ def get_dummy_components(self): text_encoder_config = CLIPTextConfig( bos_token_id=0, eos_token_id=2, - hidden_size=32, - intermediate_size=64, + hidden_size=8, + intermediate_size=8, layer_norm_eps=1e-05, - num_attention_heads=8, - num_hidden_layers=3, + num_attention_heads=1, + num_hidden_layers=1, pad_token_id=1, vocab_size=1000, - projection_dim=32, + projection_dim=8, ) text_encoder = CLIPTextModelWithProjection(text_encoder_config) tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") diff --git a/tests/pipelines/amused/test_amused_img2img.py b/tests/pipelines/amused/test_amused_img2img.py index efbca1f437a4..24bc34d330e9 100644 --- a/tests/pipelines/amused/test_amused_img2img.py +++ b/tests/pipelines/amused/test_amused_img2img.py @@ -42,17 +42,17 @@ class AmusedImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase): def get_dummy_components(self): torch.manual_seed(0) transformer = UVit2DModel( - hidden_size=32, + hidden_size=8, use_bias=False, hidden_dropout=0.0, - cond_embed_dim=32, + cond_embed_dim=8, micro_cond_encode_dim=2, micro_cond_embed_dim=10, - encoder_hidden_size=32, + encoder_hidden_size=8, vocab_size=32, - codebook_size=32, - in_channels=32, - block_out_channels=32, + codebook_size=8, + in_channels=8, + block_out_channels=8, num_res_blocks=1, downsample=True, upsample=True, @@ -60,7 +60,7 @@ def get_dummy_components(self): num_hidden_layers=1, num_attention_heads=1, attention_dropout=0.0, - intermediate_size=32, + intermediate_size=8, layer_norm_eps=1e-06, ln_elementwise_affine=True, ) @@ -68,17 +68,17 @@ def get_dummy_components(self): torch.manual_seed(0) vqvae = VQModel( act_fn="silu", - block_out_channels=[32], + block_out_channels=[8], down_block_types=[ "DownEncoderBlock2D", ], in_channels=3, - latent_channels=32, - layers_per_block=2, - norm_num_groups=32, - num_vq_embeddings=32, + latent_channels=8, + layers_per_block=1, + norm_num_groups=8, + num_vq_embeddings=32, # reducing this to 16 or 8 -> RuntimeError: "cdist_cuda" not implemented for 'Half' out_channels=3, - sample_size=32, + sample_size=8, up_block_types=[ "UpDecoderBlock2D", ], @@ -89,14 +89,14 @@ def get_dummy_components(self): text_encoder_config = CLIPTextConfig( bos_token_id=0, eos_token_id=2, - hidden_size=32, - intermediate_size=64, + hidden_size=8, + intermediate_size=8, layer_norm_eps=1e-05, - num_attention_heads=8, - num_hidden_layers=3, + num_attention_heads=1, + num_hidden_layers=1, pad_token_id=1, vocab_size=1000, - projection_dim=32, + projection_dim=8, ) text_encoder = CLIPTextModelWithProjection(text_encoder_config) tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") diff --git a/tests/pipelines/amused/test_amused_inpaint.py b/tests/pipelines/amused/test_amused_inpaint.py index d397f8d81297..d0c1ed09c706 100644 --- a/tests/pipelines/amused/test_amused_inpaint.py +++ b/tests/pipelines/amused/test_amused_inpaint.py @@ -42,17 +42,17 @@ class AmusedInpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase): def get_dummy_components(self): torch.manual_seed(0) transformer = UVit2DModel( - hidden_size=32, + hidden_size=8, use_bias=False, hidden_dropout=0.0, - cond_embed_dim=32, + cond_embed_dim=8, micro_cond_encode_dim=2, micro_cond_embed_dim=10, - encoder_hidden_size=32, + encoder_hidden_size=8, vocab_size=32, - codebook_size=32, - in_channels=32, - block_out_channels=32, + codebook_size=32, # codebook size needs to be consistent with num_vq_embeddings for inpaint tests + in_channels=8, + block_out_channels=8, num_res_blocks=1, downsample=True, upsample=True, @@ -60,7 +60,7 @@ def get_dummy_components(self): num_hidden_layers=1, num_attention_heads=1, attention_dropout=0.0, - intermediate_size=32, + intermediate_size=8, layer_norm_eps=1e-06, ln_elementwise_affine=True, ) @@ -68,17 +68,17 @@ def get_dummy_components(self): torch.manual_seed(0) vqvae = VQModel( act_fn="silu", - block_out_channels=[32], + block_out_channels=[8], down_block_types=[ "DownEncoderBlock2D", ], in_channels=3, - latent_channels=32, - layers_per_block=2, - norm_num_groups=32, - num_vq_embeddings=32, + latent_channels=8, + layers_per_block=1, + norm_num_groups=8, + num_vq_embeddings=32, # reducing this to 16 or 8 -> RuntimeError: "cdist_cuda" not implemented for 'Half' out_channels=3, - sample_size=32, + sample_size=8, up_block_types=[ "UpDecoderBlock2D", ], @@ -89,14 +89,14 @@ def get_dummy_components(self): text_encoder_config = CLIPTextConfig( bos_token_id=0, eos_token_id=2, - hidden_size=32, - intermediate_size=64, + hidden_size=8, + intermediate_size=8, layer_norm_eps=1e-05, - num_attention_heads=8, - num_hidden_layers=3, + num_attention_heads=1, + num_hidden_layers=1, pad_token_id=1, vocab_size=1000, - projection_dim=32, + projection_dim=8, ) text_encoder = CLIPTextModelWithProjection(text_encoder_config) tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") From 3fd31eef518b73ee592f82435f3d370a716ead4f Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 30 Apr 2024 08:46:51 +0530 Subject: [PATCH 12/20] [Core] introduce _no_split_modules to `ModelMixin` (#6396) * introduce _no_split_modules. * unnecessary spaces. * remove unnecessary kwargs and style * fix: accelerate imports. * change to _determine_device_map * add the blocks that have residual connections. * add: CrossAttnUpBlock2D * add: testin * style * line-spaces * quality * add disk offload test without safetensors. * checking disk offloading percentages. * change model split * add: utility for checking multi-gpu requirement. * model parallelism test * splits. * splits. * splits * splits. * splits. * splits. * offload folder to test_disk_offload_with_safetensors * add _no_split_modules * fix-copies --- .../models/autoencoders/autoencoder_kl.py | 1 + src/diffusers/models/modeling_utils.py | 92 ++++++++++++- .../models/transformers/transformer_2d.py | 1 + .../models/unets/unet_2d_condition.py | 1 + .../versatile_diffusion/modeling_text_unet.py | 1 + tests/models/test_modeling_common.py | 128 ++++++++++++++++++ .../unets/test_models_unet_2d_condition.py | 2 + 7 files changed, 221 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl.py b/src/diffusers/models/autoencoders/autoencoder_kl.py index b286453de424..0b9b9d4d47e5 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl.py @@ -65,6 +65,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin): """ _supports_gradient_checkpointing = True + _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D"] @register_to_config def __init__( diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index c1fdff8ab356..8d9f2d9e71fc 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -57,7 +57,8 @@ if is_accelerate_available(): import accelerate - from accelerate.utils import set_module_tensor_to_device + from accelerate import infer_auto_device_map + from accelerate.utils import get_balanced_memory, get_max_memory, set_module_tensor_to_device from accelerate.utils.versions import is_torch_version @@ -99,6 +100,29 @@ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]: return first_tuple[1].dtype +# Adapted from `transformers` (see modeling_utils.py) +def _determine_device_map(model: "ModelMixin", device_map, max_memory, torch_dtype): + if isinstance(device_map, str): + no_split_modules = model._get_no_split_modules(device_map) + device_map_kwargs = {"no_split_module_classes": no_split_modules} + + if device_map != "sequential": + max_memory = get_balanced_memory( + model, + dtype=torch_dtype, + low_zero=(device_map == "balanced_low_0"), + max_memory=max_memory, + **device_map_kwargs, + ) + else: + max_memory = get_max_memory(max_memory) + + device_map_kwargs["max_memory"] = max_memory + device_map = infer_auto_device_map(model, dtype=torch_dtype, **device_map_kwargs) + + return device_map + + def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None): """ Reads a checkpoint file, returning properly formatted errors if they arise. @@ -201,6 +225,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"] _supports_gradient_checkpointing = False _keys_to_ignore_on_load_unexpected = None + _no_split_modules = None def __init__(self): super().__init__() @@ -560,6 +585,36 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P " dispatching. Please make sure to set `low_cpu_mem_usage=True`." ) + # change device_map into a map if we passed an int, a str or a torch.device + if isinstance(device_map, torch.device): + device_map = {"": device_map} + elif isinstance(device_map, str) and device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]: + try: + device_map = {"": torch.device(device_map)} + except RuntimeError: + raise ValueError( + "When passing device_map as a string, the value needs to be a device name (e.g. cpu, cuda:0) or " + f"'auto', 'balanced', 'balanced_low_0', 'sequential' but found {device_map}." + ) + elif isinstance(device_map, int): + if device_map < 0: + raise ValueError( + "You can't pass device_map as a negative int. If you want to put the model on the cpu, pass device_map = 'cpu' " + ) + else: + device_map = {"": device_map} + + if device_map is not None: + if low_cpu_mem_usage is None: + low_cpu_mem_usage = True + elif not low_cpu_mem_usage: + raise ValueError("Passing along a `device_map` requires `low_cpu_mem_usage=True`") + + if low_cpu_mem_usage: + if device_map is not None and not is_torch_version(">=", "1.10"): + # The max memory utils require PyTorch >= 1.10 to have torch.cuda.mem_get_info. + raise ValueError("`low_cpu_mem_usage` and `device_map` require PyTorch >= 1.10.") + # Load config if we don't provide a configuration config_path = pretrained_model_name_or_path @@ -582,10 +637,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P token=token, revision=revision, subfolder=subfolder, - device_map=device_map, - max_memory=max_memory, - offload_folder=offload_folder, - offload_state_dict=offload_state_dict, user_agent=user_agent, **kwargs, ) @@ -690,6 +741,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P else: # else let accelerate handle loading and dispatching. # Load weights and dispatch according to the device_map # by default the device_map is None and the weights are loaded on the CPU + device_map = _determine_device_map(model, device_map, max_memory, torch_dtype) try: accelerate.load_checkpoint_and_dispatch( model, @@ -881,6 +933,36 @@ def _find_mismatched_keys( return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs + # Adapted from `transformers` modeling_utils.py + def _get_no_split_modules(self, device_map: str): + """ + Get the modules of the model that should not be spit when using device_map. We iterate through the modules to + get the underlying `_no_split_modules`. + + Args: + device_map (`str`): + The device map value. Options are ["auto", "balanced", "balanced_low_0", "sequential"] + + Returns: + `List[str]`: List of modules that should not be split + """ + _no_split_modules = set() + modules_to_check = [self] + while len(modules_to_check) > 0: + module = modules_to_check.pop(-1) + # if the module does not appear in _no_split_modules, we also check the children + if module.__class__.__name__ not in _no_split_modules: + if isinstance(module, ModelMixin): + if module._no_split_modules is None: + raise ValueError( + f"{module.__class__.__name__} does not support `device_map='{device_map}'`. To implement support, the model " + "class needs to implement the `_no_split_modules` attribute." + ) + else: + _no_split_modules = _no_split_modules | set(module._no_split_modules) + modules_to_check += list(module.children()) + return list(_no_split_modules) + @property def device(self) -> torch.device: """ diff --git a/src/diffusers/models/transformers/transformer_2d.py b/src/diffusers/models/transformers/transformer_2d.py index 768fceb71ae5..6a2695b9e436 100644 --- a/src/diffusers/models/transformers/transformer_2d.py +++ b/src/diffusers/models/transformers/transformer_2d.py @@ -72,6 +72,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin): """ _supports_gradient_checkpointing = True + _no_split_modules = ["BasicTransformerBlock"] @register_to_config def __init__( diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py index 34327e1049c5..697730b359ff 100644 --- a/src/diffusers/models/unets/unet_2d_condition.py +++ b/src/diffusers/models/unets/unet_2d_condition.py @@ -161,6 +161,7 @@ class conditioning with `class_embed_type` equal to `None`. """ _supports_gradient_checkpointing = True + _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"] @register_to_config def __init__( diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py index 3c3bd526692d..c84caa1fad88 100644 --- a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py @@ -363,6 +363,7 @@ class conditioning with `class_embed_type` equal to `None`. """ _supports_gradient_checkpointing = True + _no_split_modules = ["BasicTransformerBlock", "ResnetBlockFlat", "CrossAttnUpBlockFlat"] @register_to_config def __init__( diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index f919ba10fbb7..d8a93d40c8bf 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -24,6 +24,7 @@ import numpy as np import requests_mock import torch +from accelerate.utils import compute_module_sizes from huggingface_hub import ModelCard, delete_repo from huggingface_hub.utils import is_jinja_available from requests.exceptions import HTTPError @@ -39,6 +40,7 @@ require_torch_2, require_torch_accelerator_with_training, require_torch_gpu, + require_torch_multi_gpu, run_test_in_subprocess, torch_device, ) @@ -200,6 +202,21 @@ class ModelTesterMixin: main_input_name = None # overwrite in model specific tester class base_precision = 1e-3 forward_requires_fresh_args = False + model_split_percents = [0.5, 0.7, 0.9] + + def check_device_map_is_respected(self, model, device_map): + for param_name, param in model.named_parameters(): + # Find device in device_map + while len(param_name) > 0 and param_name not in device_map: + param_name = ".".join(param_name.split(".")[:-1]) + if param_name not in device_map: + raise ValueError("device map is incomplete, it does not contain any device for `param_name`.") + + param_device = device_map[param_name] + if param_device in ["cpu", "disk"]: + self.assertEqual(param.device, torch.device("meta")) + else: + self.assertEqual(param.device, torch.device(param_device)) def test_from_save_pretrained(self, expected_max_diff=5e-5): if self.forward_requires_fresh_args: @@ -670,6 +687,117 @@ def test_deprecated_kwargs(self): " from `_deprecated_kwargs = []`" ) + @require_torch_gpu + def test_cpu_offload(self): + config, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**config).eval() + model = model.to(torch_device) + + torch.manual_seed(0) + base_output = model(**inputs_dict) + + model_size = compute_module_sizes(model)[""] + # We test several splits of sizes to make sure it works. + max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]] + with tempfile.TemporaryDirectory() as tmp_dir: + model.cpu().save_pretrained(tmp_dir) + + for max_size in max_gpu_sizes: + max_memory = {0: max_size, "cpu": model_size * 2} + new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory) + # Making sure part of the model will actually end up offloaded + self.assertSetEqual(set(new_model.hf_device_map.values()), {0, "cpu"}) + + self.check_device_map_is_respected(new_model, new_model.hf_device_map) + torch.manual_seed(0) + new_output = new_model(**inputs_dict) + + self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5)) + + @require_torch_gpu + def test_disk_offload_without_safetensors(self): + config, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**config).eval() + model = model.to(torch_device) + + torch.manual_seed(0) + base_output = model(**inputs_dict) + + model_size = compute_module_sizes(model)[""] + with tempfile.TemporaryDirectory() as tmp_dir: + model.cpu().save_pretrained(tmp_dir, safe_serialization=False) + + with self.assertRaises(ValueError): + max_size = int(self.model_split_percents[1] * model_size) + max_memory = {0: max_size, "cpu": max_size} + # This errors out because it's missing an offload folder + new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory) + + max_size = int(self.model_split_percents[1] * model_size) + max_memory = {0: max_size, "cpu": max_size} + new_model = self.model_class.from_pretrained( + tmp_dir, device_map="auto", max_memory=max_memory, offload_folder=tmp_dir + ) + + self.check_device_map_is_respected(new_model, new_model.hf_device_map) + torch.manual_seed(0) + new_output = new_model(**inputs_dict) + + self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5)) + + @require_torch_gpu + def test_disk_offload_with_safetensors(self): + config, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**config).eval() + model = model.to(torch_device) + + torch.manual_seed(0) + base_output = model(**inputs_dict) + + model_size = compute_module_sizes(model)[""] + with tempfile.TemporaryDirectory() as tmp_dir: + model.cpu().save_pretrained(tmp_dir) + + max_size = int(self.model_split_percents[1] * model_size) + max_memory = {0: max_size, "cpu": max_size} + new_model = self.model_class.from_pretrained( + tmp_dir, device_map="auto", offload_folder=tmp_dir, max_memory=max_memory + ) + + self.check_device_map_is_respected(new_model, new_model.hf_device_map) + torch.manual_seed(0) + new_output = new_model(**inputs_dict) + + self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5)) + + @require_torch_multi_gpu + def test_model_parallelism(self): + config, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**config).eval() + model = model.to(torch_device) + + torch.manual_seed(0) + base_output = model(**inputs_dict) + + model_size = compute_module_sizes(model)[""] + # We test several splits of sizes to make sure it works. + max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]] + with tempfile.TemporaryDirectory() as tmp_dir: + model.cpu().save_pretrained(tmp_dir) + + for max_size in max_gpu_sizes: + max_memory = {0: max_size, 1: model_size * 2, "cpu": model_size * 2} + new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory) + # Making sure part of the model will actually end up offloaded + self.assertSetEqual(set(new_model.hf_device_map.values()), {0, 1}) + + self.check_device_map_is_respected(new_model, new_model.hf_device_map) + + torch.manual_seed(0) + new_output = new_model(**inputs_dict) + + self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5)) + @is_staging_test class ModelPushToHubTester(unittest.TestCase): diff --git a/tests/models/unets/test_models_unet_2d_condition.py b/tests/models/unets/test_models_unet_2d_condition.py index 1b8a998cfd66..33aa6a10377b 100644 --- a/tests/models/unets/test_models_unet_2d_condition.py +++ b/tests/models/unets/test_models_unet_2d_condition.py @@ -300,6 +300,8 @@ def create_custom_diffusion_layers(model, mock_weights: bool = True): class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): model_class = UNet2DConditionModel main_input_name = "sample" + # We override the items here because the unet under consideration is small. + model_split_percents = [0.5, 0.3, 0.4] @property def dummy_input(self): From 26a7851e1e0b18da746d6ae80bb105050f7187e0 Mon Sep 17 00:00:00 2001 From: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com> Date: Tue, 30 Apr 2024 07:16:30 +0300 Subject: [PATCH 13/20] Add B-Lora training option to the advanced dreambooth lora script (#7741) * add blora * add blora * add blora * add blora * little changes * little changes * remove redundancies * fixes * add B LoRA to readme * style * inference * defaults + path to loras+ generation * minor changes * style * minor changes * minor changes * blora arg * added --lora_unet_blocks * style * Update examples/advanced_diffusion_training/README.md Co-authored-by: Sayak Paul * add commit hash to B-LoRA repo cloneing * change inference, remove cloning * change inference, remove cloning add section about configureable unet blocks * change inference, remove cloning add section about configureable unet blocks * Apply suggestions from code review --------- Co-authored-by: Sayak Paul --- .../advanced_diffusion_training/README.md | 143 +++++++++++++++++- .../train_dreambooth_lora_sdxl_advanced.py | 107 +++++++++++-- 2 files changed, 236 insertions(+), 14 deletions(-) diff --git a/examples/advanced_diffusion_training/README.md b/examples/advanced_diffusion_training/README.md index fda73f9ce7a5..a13ae719cfdc 100644 --- a/examples/advanced_diffusion_training/README.md +++ b/examples/advanced_diffusion_training/README.md @@ -234,7 +234,7 @@ In ComfyUI we will load a LoRA and a textual embedding at the same time. SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument namely `--pretrained_vae_model_name_or_path` that lets you specify the location of a better VAE (such as [this one](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)). ### DoRA training -The advanced script now supports DoRA training too! +The advanced script supports DoRA training too! > Proposed in [DoRA: Weight-Decomposed Low-Rank Adaptation](https://arxiv.org/abs/2402.09353), **DoRA** is very similar to LoRA, except it decomposes the pre-trained weight into two components, **magnitude** and **direction** and employs LoRA for _directional_ updates to efficiently minimize the number of trainable parameters. The authors found that by using DoRA, both the learning capacity and training stability of LoRA are enhanced without any additional overhead during inference. @@ -304,6 +304,147 @@ accelerate launch train_dreambooth_lora_sdxl_advanced.py \ > [!CAUTION] > Min-SNR gamma is not supported with the EDM-style training yet. When training with the PlaygroundAI model, it's recommended to not pass any "variant". +### B-LoRA training +The advanced script now supports B-LoRA training too! +> Proposed in [Implicit Style-Content Separation using B-LoRA](https://arxiv.org/abs/2403.14572), +B-LoRA is a method that leverages LoRA to implicitly separate the style and content components of a **single** image. +It was shown that learning the LoRA weights of two specific blocks (referred to as B-LoRAs) +achieves style-content separation that cannot be achieved by training each B-LoRA independently. +Once trained, the two B-LoRAs can be used as independent components to allow various image stylization tasks + +**Usage** +Enable B-LoRA training by adding this flag +```bash +--use_blora +``` +You can train a B-LoRA with as little as 1 image, and 1000 steps. Try this default configuration as a start: +```bash +!accelerate launch train_dreambooth_b-lora_sdxl.py \ + --pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0" \ + --instance_data_dir="linoyts/B-LoRA_teddy_bear" \ + --output_dir="B-LoRA_teddy_bear" \ + --instance_prompt="a [v18]" \ + --resolution=1024 \ + --rank=64 \ + --train_batch_size=1 \ + --learning_rate=5e-5 \ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --max_train_steps=1000 \ + --checkpointing_steps=2000 \ + --seed="0" \ + --gradient_checkpointing \ + --mixed_precision="fp16" +``` +**Inference** +The inference is a bit different: +1. we need load *specific* unet layers (as opposed to a regular LoRA/DoRA) +2. the trained layers we load, changes based on our objective (e.g. style/content) + +```python +import torch +from diffusers import StableDiffusionXLPipeline, AutoencoderKL + +# taken & modified from B-LoRA repo - https://github.com/yardenfren1996/B-LoRA/blob/main/blora_utils.py +def is_belong_to_blocks(key, blocks): + try: + for g in blocks: + if g in key: + return True + return False + except Exception as e: + raise type(e)(f'failed to is_belong_to_block, due to: {e}') + +def lora_lora_unet_blocks(lora_path, alpha, target_blocks): + state_dict, _ = pipeline.lora_state_dict(lora_path) + filtered_state_dict = {k: v * alpha for k, v in state_dict.items() if is_belong_to_blocks(k, target_blocks)} + return filtered_state_dict + +vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16) +pipeline = StableDiffusionXLPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", + vae=vae, + torch_dtype=torch.float16, +).to("cuda") + +# pick a blora for content/style (you can also set one to None) +content_B_lora_path = "lora-library/B-LoRA-teddybear" +style_B_lora_path= "lora-library/B-LoRA-pen_sketch" + + +content_B_LoRA = lora_lora_unet_blocks(content_B_lora_path,alpha=1,target_blocks=["unet.up_blocks.0.attentions.0"]) +style_B_LoRA = lora_lora_unet_blocks(style_B_lora_path,alpha=1.1,target_blocks=["unet.up_blocks.0.attentions.1"]) +combined_lora = {**content_B_LoRA, **style_B_LoRA} + +# Load both loras +pipeline.load_lora_into_unet(combined_lora, None, pipeline.unet) + +#generate +prompt = "a [v18] in [v30] style" +pipeline(prompt, num_images_per_prompt=4).images +``` +### LoRA training of Targeted U-net Blocks +The advanced script now supports custom choice of U-net blocks to train during Dreambooth LoRA tuning. +> [!NOTE] +> This feature is still experimental + +> Recently, works like B-LoRA showed the potential advantages of learning the LoRA weights of specific U-net blocks, not only in speed & memory, +> but also in reducing the amount of needed data, improving style manipulation and overcoming overfitting issues. +> In light of this, we're introducing a new feature to the advanced script to allow for configurable U-net learned blocks. + +**Usage** +Configure LoRA learned U-net blocks adding a `lora_unet_blocks` flag, with a comma seperated string specifying the targeted blocks. +e.g: +```bash +--lora_unet_blocks="unet.up_blocks.0.attentions.0,unet.up_blocks.0.attentions.1" +``` + +> [!NOTE] +> if you specify both `--use_blora` and `--lora_unet_blocks`, values given in --lora_unet_blocks will be ignored. +> When enabling --use_blora, targeted U-net blocks are automatically set to be "unet.up_blocks.0.attentions.0,unet.up_blocks.0.attentions.1" as discussed in the paper. +> If you wish to experiment with different blocks, specify `--lora_unet_blocks` only. + +**Inference** +Inference is the same as for B-LoRAs, except the input targeted blocks should be modified based on your training configuration. +```python +import torch +from diffusers import StableDiffusionXLPipeline, AutoencoderKL + +# taken & modified from B-LoRA repo - https://github.com/yardenfren1996/B-LoRA/blob/main/blora_utils.py +def is_belong_to_blocks(key, blocks): + try: + for g in blocks: + if g in key: + return True + return False + except Exception as e: + raise type(e)(f'failed to is_belong_to_block, due to: {e}') + +def lora_lora_unet_blocks(lora_path, alpha, target_blocks): + state_dict, _ = pipeline.lora_state_dict(lora_path) + filtered_state_dict = {k: v * alpha for k, v in state_dict.items() if is_belong_to_blocks(k, target_blocks)} + return filtered_state_dict + +vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16) +pipeline = StableDiffusionXLPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", + vae=vae, + torch_dtype=torch.float16, +).to("cuda") + +lora_path = "lora-library/B-LoRA-pen_sketch" + +state_dict = lora_lora_unet_blocks(content_B_lora_path,alpha=1,target_blocks=["unet.up_blocks.0.attentions.0"]) + +# Load traine dlora layers into the unet +pipeline.load_lora_into_unet(state_dict, None, pipeline.unet) + +#generate +prompt = "a dog in [v30] style" +pipeline(prompt, num_images_per_prompt=4).images +``` + + ### Tips and Tricks Check out [these recommended practices](https://huggingface.co/blog/sdxl_lora_advanced_script#additional-good-practices) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py index 21a84b77245a..0699ac17077d 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py @@ -15,7 +15,6 @@ import argparse import gc -import hashlib import itertools import json import logging @@ -40,6 +39,7 @@ from accelerate.logging import get_logger from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed from huggingface_hub import create_repo, hf_hub_download, upload_folder +from huggingface_hub.utils import insecure_hashlib from packaging import version from peft import LoraConfig, set_peft_model_state_dict from peft.utils import get_peft_model_state_dict @@ -696,6 +696,23 @@ def parse_args(input_args=None): "Note: to use DoRA you need to install peft from main, `pip install git+https://github.com/huggingface/peft.git`" ), ) + parser.add_argument( + "--lora_unet_blocks", + type=str, + default=None, + help=( + "the U-net blocks to tune during training. please specify them in a comma separated string, e.g. `unet.up_blocks.0.attentions.0,unet.up_blocks.0.attentions.1` etc." + "NOTE: By default (if not specified) - regular LoRA training is performed. " + "if --use_blora is enabled, this arg will be ignored, since in B-LoRA training, targeted U-net blocks are `unet.up_blocks.0.attentions.0` and `unet.up_blocks.0.attentions.1`" + ), + ) + parser.add_argument( + "--use_blora", + action="store_true", + help=( + "Whether to train a B-LoRA as proposed in- Implicit Style-Content Separation using B-LoRA https://arxiv.org/abs/2403.14572. " + ), + ) parser.add_argument( "--cache_latents", action="store_true", @@ -720,6 +737,11 @@ def parse_args(input_args=None): "For full LoRA text encoder training check --train_text_encoder, for textual " "inversion training check `--train_text_encoder_ti`" ) + if args.use_blora and args.lora_unet_blocks: + warnings.warn( + "You specified both `--use_blora` and `--lora_unet_blocks`, for B-LoRA training, target unet blocks are: `unet.up_blocks.0.attentions.0` and `unet.up_blocks.0.attentions.1`. " + "If you wish to target different U-net blocks, don't enable `--use_blora`" + ) env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) if env_local_rank != -1 and env_local_rank != args.local_rank: @@ -740,6 +762,40 @@ def parse_args(input_args=None): return args +# Taken (and slightly modified) from B-LoRA repo https://github.com/yardenfren1996/B-LoRA/blob/main/blora_utils.py +def is_belong_to_blocks(key, blocks): + try: + for g in blocks: + if g in key: + return True + return False + except Exception as e: + raise type(e)(f"failed to is_belong_to_block, due to: {e}") + + +def get_unet_lora_target_modules(unet, use_blora, target_blocks=None): + if use_blora: + content_b_lora_blocks = "unet.up_blocks.0.attentions.0" + style_b_lora_blocks = "unet.up_blocks.0.attentions.1" + target_blocks = [content_b_lora_blocks, style_b_lora_blocks] + try: + blocks = [(".").join(blk.split(".")[1:]) for blk in target_blocks] + + attns = [ + attn_processor_name.rsplit(".", 1)[0] + for attn_processor_name, _ in unet.attn_processors.items() + if is_belong_to_blocks(attn_processor_name, blocks) + ] + + target_modules = [f"{attn}.{mat}" for mat in ["to_k", "to_q", "to_v", "to_out.0"] for attn in attns] + return target_modules + except Exception as e: + raise type(e)( + f"failed to get_target_modules, due to: {e}. " + f"Please check the modules specified in --lora_unet_blocks are correct" + ) + + # Taken from https://github.com/replicate/cog-sdxl/blob/main/dataset_and_utils.py class TokenEmbeddingsHandler: def __init__(self, text_encoders, tokenizers): @@ -946,16 +1002,20 @@ def __init__( transforms.Normalize([0.5], [0.5]), ] ) + # if using B-LoRA for single image. do not use transformations + single_image = len(self.instance_images) < 2 for image in self.instance_images: - image = exif_transpose(image) + if not single_image: + image = exif_transpose(image) if not image.mode == "RGB": image = image.convert("RGB") self.original_sizes.append((image.height, image.width)) image = train_resize(image) - if args.random_flip and random.random() < 0.5: + + if not single_image and args.random_flip and random.random() < 0.5: # flip image = train_flip(image) - if args.center_crop: + if args.center_crop or single_image: y1 = max(0, int(round((image.height - args.resolution) / 2.0))) x1 = max(0, int(round((image.width - args.resolution) / 2.0))) image = train_crop(image) @@ -1216,7 +1276,7 @@ def main(args): images = pipeline(example["prompt"]).images for i, image in enumerate(images): - hash_image = hashlib.sha1(image.tobytes()).hexdigest() + hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest() image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" image.save(image_filename) @@ -1374,12 +1434,24 @@ def main(args): text_encoder_two.gradient_checkpointing_enable() # now we will add new LoRA weights to the attention layers + + if args.use_blora: + # if using B-LoRA, the targeted blocks to train are automatically set + target_modules = get_unet_lora_target_modules(unet, use_blora=True) + elif args.lora_unet_blocks: + # if training specific unet blocks not in the B-LoRA scheme + target_blocks_list = "".join(args.lora_unet_blocks.split()).split(",") + logger.info(f"list of unet blocks to train: {target_blocks_list}") + target_modules = get_unet_lora_target_modules(unet, use_blora=False, target_blocks=target_blocks_list) + else: + target_modules = ["to_k", "to_q", "to_v", "to_out.0"] + unet_lora_config = LoraConfig( r=args.rank, - lora_alpha=args.rank, use_dora=args.use_dora, + lora_alpha=args.rank, init_lora_weights="gaussian", - target_modules=["to_k", "to_q", "to_v", "to_out.0"], + target_modules=target_modules, ) unet.add_adapter(unet_lora_config) @@ -1388,8 +1460,8 @@ def main(args): if args.train_text_encoder: text_lora_config = LoraConfig( r=args.rank, - lora_alpha=args.rank, use_dora=args.use_dora, + lora_alpha=args.rank, init_lora_weights="gaussian", target_modules=["q_proj", "k_proj", "v_proj", "out_proj"], ) @@ -1505,6 +1577,7 @@ def load_model_hook(models, input_dir): models = [unet_] if args.train_text_encoder: models.extend([text_encoder_one_, text_encoder_two_]) + # only upcast trainable parameters (LoRA) into fp32 cast_training_params(models) accelerator.register_save_state_pre_hook(save_model_hook) @@ -1525,6 +1598,8 @@ def load_model_hook(models, input_dir): models = [unet] if args.train_text_encoder: models.extend([text_encoder_one, text_encoder_two]) + + # only upcast trainable parameters (LoRA) into fp32 cast_training_params(models, dtype=torch.float32) unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters())) @@ -1780,7 +1855,12 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): # We need to initialize the trackers we use, and also store our configuration. # The trackers initializes automatically on the main process. if accelerator.is_main_process: - accelerator.init_trackers("dreambooth-lora-sd-xl", config=vars(args)) + tracker_name = ( + "dreambooth-lora-sd-xl" + if "playground" not in args.pretrained_model_name_or_path + else "dreambooth-lora-playground" + ) + accelerator.init_trackers(tracker_name, config=vars(args)) # Train! total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps @@ -1833,7 +1913,6 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): ) def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): - # TODO: revisit other sampling algorithms sigmas = noise_scheduler.sigmas.to(device=accelerator.device, dtype=dtype) schedule_timesteps = noise_scheduler.timesteps.to(accelerator.device) timesteps = timesteps.to(accelerator.device) @@ -1852,6 +1931,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # flag used for textual inversion pivoted = False for epoch in range(first_epoch, args.num_train_epochs): + unet.train() # if performing any kind of optimization of text_encoder params if args.train_text_encoder or args.train_text_encoder_ti: if epoch == num_train_epochs_text_encoder: @@ -1869,7 +1949,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): text_encoder_one.text_model.embeddings.requires_grad_(True) text_encoder_two.text_model.embeddings.requires_grad_(True) - unet.train() for step, batch in enumerate(train_dataloader): if pivoted: # stopping optimization of text_encoder params @@ -1970,7 +2049,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): timesteps, prompt_embeds_input, added_cond_kwargs=unet_added_conditions, - ).sample + return_dict=False, + )[0] else: unet_added_conditions = {"time_ids": add_time_ids} prompt_embeds, pooled_prompt_embeds = encode_prompt( @@ -1988,7 +2068,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): timesteps, prompt_embeds_input, added_cond_kwargs=unet_added_conditions, - ).sample + return_dict=False, + )[0] weighting = None if args.do_edm_style_training: From 725ead2f5ec6f3f4beac66a5bddcee17647b9599 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Tue, 30 Apr 2024 20:14:18 +0530 Subject: [PATCH 14/20] SSH Runner Workflow Update (#7822) * add debug workflow * update --- .github/workflows/ssh-runner.yml | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/.github/workflows/ssh-runner.yml b/.github/workflows/ssh-runner.yml index befebfbc9b96..e5bbdd64f549 100644 --- a/.github/workflows/ssh-runner.yml +++ b/.github/workflows/ssh-runner.yml @@ -28,19 +28,10 @@ jobs: options: --gpus all --privileged --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/ steps: - - name: Update clone - working-directory: /diffusers - run: | - git fetch && git checkout ${{ github.sha }} - - name: Cleanup - working-directory: /diffusers - run: | - rm -rf tests/__pycache__ - rm -rf tests/models/__pycache__ - rm -rf reports - - name: Show installed libraries and their versions - working-directory: /diffusers - run: pip freeze + - name: Checkout diffusers + uses: actions/checkout@v3 + with: + fetch-depth: 2 - name: NVIDIA-SMI run: | From b8ccb462596d336ce892e329ba69fa12394e9964 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= <46008593+standardAI@users.noreply.github.com> Date: Tue, 30 Apr 2024 20:53:27 +0300 Subject: [PATCH 15/20] Fix CPU offload in docstring (#7827) Fix cpu offload --- .../controlnet/pipeline_controlnet_sd_xl_img2img.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py index d7889a9efbb5..dfd3cc239b36 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py @@ -89,8 +89,8 @@ ... variant="fp16", ... use_safetensors=True, ... torch_dtype=torch.float16, - ... ).to("cuda") - >>> vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16).to("cuda") + ... ) + >>> vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16) >>> pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained( ... "stabilityai/stable-diffusion-xl-base-1.0", ... controlnet=controlnet, @@ -98,7 +98,7 @@ ... variant="fp16", ... use_safetensors=True, ... torch_dtype=torch.float16, - ... ).to("cuda") + ... ) >>> pipe.enable_model_cpu_offload() From 0d083702637ca61e7dd8533f6a3aa7558fce6d3b Mon Sep 17 00:00:00 2001 From: Steven Liu <59462357+stevhliu@users.noreply.github.com> Date: Tue, 30 Apr 2024 14:10:14 -0700 Subject: [PATCH 16/20] [docs] Community pipelines (#7819) * community pipelines * feedback * consolidate --- docs/source/en/_toctree.yml | 4 - docs/source/en/conceptual/contribution.md | 87 ++++++--- .../en/using-diffusers/contribute_pipeline.md | 184 ------------------ .../custom_pipeline_examples.md | 119 ----------- .../custom_pipeline_overview.md | 101 +++++++++- 5 files changed, 165 insertions(+), 330 deletions(-) delete mode 100644 docs/source/en/using-diffusers/contribute_pipeline.md delete mode 100644 docs/source/en/using-diffusers/custom_pipeline_examples.md diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 357afb2ea261..f2755798b792 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -87,10 +87,6 @@ title: Shap-E - local: using-diffusers/diffedit title: DiffEdit - - local: using-diffusers/custom_pipeline_examples - title: Community pipelines - - local: using-diffusers/contribute_pipeline - title: Contribute a community pipeline - local: using-diffusers/inference_with_lcm_lora title: Latent Consistency Model-LoRA - local: using-diffusers/inference_with_lcm diff --git a/docs/source/en/conceptual/contribution.md b/docs/source/en/conceptual/contribution.md index 24ac52ba19c9..cc2e0ae07b2c 100644 --- a/docs/source/en/conceptual/contribution.md +++ b/docs/source/en/conceptual/contribution.md @@ -198,38 +198,81 @@ Anything displayed on [the official Diffusers doc page](https://huggingface.co/d Please have a look at [this page](https://github.com/huggingface/diffusers/tree/main/docs) on how to verify changes made to the documentation locally. - ### 6. Contribute a community pipeline -[Pipelines](https://huggingface.co/docs/diffusers/api/pipelines/overview) are usually the first point of contact between the Diffusers library and the user. -Pipelines are examples of how to use Diffusers [models](https://huggingface.co/docs/diffusers/api/models/overview) and [schedulers](https://huggingface.co/docs/diffusers/api/schedulers/overview). -We support two types of pipelines: +> [!TIP] +> Read the [Community pipelines](../using-diffusers/custom_pipeline_overview#community-pipelines) guide to learn more about the difference between a GitHub and Hugging Face Hub community pipeline. If you're interested in why we have community pipelines, take a look at GitHub Issue [#841](https://github.com/huggingface/diffusers/issues/841) (basically, we can't maintain all the possible ways diffusion models can be used for inference but we also don't want to prevent the community from building them). + +Contributing a community pipeline is a great way to share your creativity and work with the community. It lets you build on top of the [`DiffusionPipeline`] so that anyone can load and use it by setting the `custom_pipeline` parameter. This section will walk you through how to create a simple pipeline where the UNet only does a single forward pass and calls the scheduler once (a "one-step" pipeline). + +1. Create a one_step_unet.py file for your community pipeline. This file can contain whatever package you want to use as long as it's installed by the user. Make sure you only have one pipeline class that inherits from [`DiffusionPipeline`] to load model weights and the scheduler configuration from the Hub. Add a UNet and scheduler to the `__init__` function. + + You should also add the `register_modules` function to ensure your pipeline and its components can be saved with [`~DiffusionPipeline.save_pretrained`]. + +```py +from diffusers import DiffusionPipeline +import torch + +class UnetSchedulerOneForwardPipeline(DiffusionPipeline): + def __init__(self, unet, scheduler): + super().__init__() + + self.register_modules(unet=unet, scheduler=scheduler) +``` + +1. In the forward pass (which we recommend defining as `__call__`), you can add any feature you'd like. For the "one-step" pipeline, create a random image and call the UNet and scheduler once by setting `timestep=1`. + +```py + from diffusers import DiffusionPipeline + import torch + + class UnetSchedulerOneForwardPipeline(DiffusionPipeline): + def __init__(self, unet, scheduler): + super().__init__() + + self.register_modules(unet=unet, scheduler=scheduler) -- Official Pipelines -- Community Pipelines + def __call__(self): + image = torch.randn( + (1, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size), + ) + timestep = 1 + + model_output = self.unet(image, timestep).sample + scheduler_output = self.scheduler.step(model_output, timestep, image).prev_sample + + return scheduler_output +``` -Both official and community pipelines follow the same design and consist of the same type of components. +Now you can run the pipeline by passing a UNet and scheduler to it or load pretrained weights if the pipeline structure is identical. + +```py +from diffusers import DDPMScheduler, UNet2DModel + +scheduler = DDPMScheduler() +unet = UNet2DModel() + +pipeline = UnetSchedulerOneForwardPipeline(unet=unet, scheduler=scheduler) +output = pipeline() +# load pretrained weights +pipeline = UnetSchedulerOneForwardPipeline.from_pretrained("google/ddpm-cifar10-32", use_safetensors=True) +output = pipeline() +``` -Official pipelines are tested and maintained by the core maintainers of Diffusers. Their code -resides in [src/diffusers/pipelines](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines). -In contrast, community pipelines are contributed and maintained purely by the **community** and are **not** tested. -They reside in [examples/community](https://github.com/huggingface/diffusers/tree/main/examples/community) and while they can be accessed via the [PyPI diffusers package](https://pypi.org/project/diffusers/), their code is not part of the PyPI distribution. +You can either share your pipeline as a GitHub community pipeline or Hub community pipeline. -The reason for the distinction is that the core maintainers of the Diffusers library cannot maintain and test all -possible ways diffusion models can be used for inference, but some of them may be of interest to the community. -Officially released diffusion pipelines, -such as Stable Diffusion are added to the core src/diffusers/pipelines package which ensures -high quality of maintenance, no backward-breaking code changes, and testing. -More bleeding edge pipelines should be added as community pipelines. If usage for a community pipeline is high, the pipeline can be moved to the official pipelines upon request from the community. This is one of the ways we strive to be a community-driven library. + + -To add a community pipeline, one should add a .py file to [examples/community](https://github.com/huggingface/diffusers/tree/main/examples/community) and adapt the [examples/community/README.md](https://github.com/huggingface/diffusers/tree/main/examples/community/README.md) to include an example of the new pipeline. +Share your GitHub pipeline by opening a pull request on the Diffusers [repository](https://github.com/huggingface/diffusers) and add the one_step_unet.py file to the [examples/community](https://github.com/huggingface/diffusers/tree/main/examples/community) subfolder. -An example can be seen [here](https://github.com/huggingface/diffusers/pull/2400). + + -Community pipeline PRs are only checked at a superficial level and ideally they should be maintained by their original authors. +Share your Hub pipeline by creating a model repository on the Hub and uploading the one_step_unet.py file to it. -Contributing a community pipeline is a great way to understand how Diffusers models and schedulers work. Having contributed a community pipeline is usually the first stepping stone to contributing an official pipeline to the -core package. + + ### 7. Contribute to training examples diff --git a/docs/source/en/using-diffusers/contribute_pipeline.md b/docs/source/en/using-diffusers/contribute_pipeline.md deleted file mode 100644 index e9cf1ed1ce02..000000000000 --- a/docs/source/en/using-diffusers/contribute_pipeline.md +++ /dev/null @@ -1,184 +0,0 @@ - - -# Contribute a community pipeline - - - -๐Ÿ’ก Take a look at GitHub Issue [#841](https://github.com/huggingface/diffusers/issues/841) for more context about why we're adding community pipelines to help everyone easily share their work without being slowed down. - - - -Community pipelines allow you to add any additional features you'd like on top of the [`DiffusionPipeline`]. The main benefit of building on top of the `DiffusionPipeline` is anyone can load and use your pipeline by only adding one more argument, making it super easy for the community to access. - -This guide will show you how to create a community pipeline and explain how they work. To keep things simple, you'll create a "one-step" pipeline where the `UNet` does a single forward pass and calls the scheduler once. - -## Initialize the pipeline - -You should start by creating a `one_step_unet.py` file for your community pipeline. In this file, create a pipeline class that inherits from the [`DiffusionPipeline`] to be able to load model weights and the scheduler configuration from the Hub. The one-step pipeline needs a `UNet` and a scheduler, so you'll need to add these as arguments to the `__init__` function: - -```python -from diffusers import DiffusionPipeline -import torch - -class UnetSchedulerOneForwardPipeline(DiffusionPipeline): - def __init__(self, unet, scheduler): - super().__init__() -``` - -To ensure your pipeline and its components (`unet` and `scheduler`) can be saved with [`~DiffusionPipeline.save_pretrained`], add them to the `register_modules` function: - -```diff - from diffusers import DiffusionPipeline - import torch - - class UnetSchedulerOneForwardPipeline(DiffusionPipeline): - def __init__(self, unet, scheduler): - super().__init__() - -+ self.register_modules(unet=unet, scheduler=scheduler) -``` - -Cool, the `__init__` step is done and you can move to the forward pass now! ๐Ÿ”ฅ - -## Define the forward pass - -In the forward pass, which we recommend defining as `__call__`, you have complete creative freedom to add whatever feature you'd like. For our amazing one-step pipeline, create a random image and only call the `unet` and `scheduler` once by setting `timestep=1`: - -```diff - from diffusers import DiffusionPipeline - import torch - - class UnetSchedulerOneForwardPipeline(DiffusionPipeline): - def __init__(self, unet, scheduler): - super().__init__() - - self.register_modules(unet=unet, scheduler=scheduler) - -+ def __call__(self): -+ image = torch.randn( -+ (1, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size), -+ ) -+ timestep = 1 - -+ model_output = self.unet(image, timestep).sample -+ scheduler_output = self.scheduler.step(model_output, timestep, image).prev_sample - -+ return scheduler_output -``` - -That's it! ๐Ÿš€ You can now run this pipeline by passing a `unet` and `scheduler` to it: - -```python -from diffusers import DDPMScheduler, UNet2DModel - -scheduler = DDPMScheduler() -unet = UNet2DModel() - -pipeline = UnetSchedulerOneForwardPipeline(unet=unet, scheduler=scheduler) - -output = pipeline() -``` - -But what's even better is you can load pre-existing weights into the pipeline if the pipeline structure is identical. For example, you can load the [`google/ddpm-cifar10-32`](https://huggingface.co/google/ddpm-cifar10-32) weights into the one-step pipeline: - -```python -pipeline = UnetSchedulerOneForwardPipeline.from_pretrained("google/ddpm-cifar10-32", use_safetensors=True) - -output = pipeline() -``` - -## Share your pipeline - -Open a Pull Request on the ๐Ÿงจ Diffusers [repository](https://github.com/huggingface/diffusers) to add your awesome pipeline in `one_step_unet.py` to the [examples/community](https://github.com/huggingface/diffusers/tree/main/examples/community) subfolder. - -Once it is merged, anyone with `diffusers >= 0.4.0` installed can use this pipeline magically ๐Ÿช„ by specifying it in the `custom_pipeline` argument: - -```python -from diffusers import DiffusionPipeline - -pipe = DiffusionPipeline.from_pretrained( - "google/ddpm-cifar10-32", custom_pipeline="one_step_unet", use_safetensors=True -) -pipe() -``` - -Another way to share your community pipeline is to upload the `one_step_unet.py` file directly to your preferred [model repository](https://huggingface.co/docs/hub/models-uploading) on the Hub. Instead of specifying the `one_step_unet.py` file, pass the model repository id to the `custom_pipeline` argument: - -```python -from diffusers import DiffusionPipeline - -pipeline = DiffusionPipeline.from_pretrained( - "google/ddpm-cifar10-32", custom_pipeline="stevhliu/one_step_unet", use_safetensors=True -) -``` - -Take a look at the following table to compare the two sharing workflows to help you decide the best option for you: - -| | GitHub community pipeline | HF Hub community pipeline | -|----------------|------------------------------------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------| -| usage | same | same | -| review process | open a Pull Request on GitHub and undergo a review process from the Diffusers team before merging; may be slower | upload directly to a Hub repository without any review; this is the fastest workflow | -| visibility | included in the official Diffusers repository and documentation | included on your HF Hub profile and relies on your own usage/promotion to gain visibility | - - - -๐Ÿ’ก You can use whatever package you want in your community pipeline file - as long as the user has it installed, everything will work fine. Make sure you have one and only one pipeline class that inherits from `DiffusionPipeline` because this is automatically detected. - - - -## How do community pipelines work? - -A community pipeline is a class that inherits from [`DiffusionPipeline`] which means: - -- It can be loaded with the [`custom_pipeline`] argument. -- The model weights and scheduler configuration are loaded from [`pretrained_model_name_or_path`]. -- The code that implements a feature in the community pipeline is defined in a `pipeline.py` file. - -Sometimes you can't load all the pipeline components weights from an official repository. In this case, the other components should be passed directly to the pipeline: - -```python -from diffusers import DiffusionPipeline -from transformers import CLIPImageProcessor, CLIPModel - -model_id = "CompVis/stable-diffusion-v1-4" -clip_model_id = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K" - -feature_extractor = CLIPImageProcessor.from_pretrained(clip_model_id) -clip_model = CLIPModel.from_pretrained(clip_model_id, torch_dtype=torch.float16) - -pipeline = DiffusionPipeline.from_pretrained( - model_id, - custom_pipeline="clip_guided_stable_diffusion", - clip_model=clip_model, - feature_extractor=feature_extractor, - scheduler=scheduler, - torch_dtype=torch.float16, - use_safetensors=True, -) -``` - -The magic behind community pipelines is contained in the following code. It allows the community pipeline to be loaded from GitHub or the Hub, and it'll be available to all ๐Ÿงจ Diffusers packages. - -```python -# 2. Load the pipeline class, if using custom module then load it from the Hub -# if we load from explicit class, let's use it -if custom_pipeline is not None: - pipeline_class = get_class_from_dynamic_module( - custom_pipeline, module_file=CUSTOM_PIPELINE_FILE_NAME, cache_dir=custom_pipeline - ) -elif cls != DiffusionPipeline: - pipeline_class = cls -else: - diffusers_module = importlib.import_module(cls.__module__.split(".")[0]) - pipeline_class = getattr(diffusers_module, config_dict["_class_name"]) -``` diff --git a/docs/source/en/using-diffusers/custom_pipeline_examples.md b/docs/source/en/using-diffusers/custom_pipeline_examples.md deleted file mode 100644 index 203302ed3ead..000000000000 --- a/docs/source/en/using-diffusers/custom_pipeline_examples.md +++ /dev/null @@ -1,119 +0,0 @@ - - -# Community pipelines - -[[open-in-colab]] - - - -For more context about the design choices behind community pipelines, please have a look at [this issue](https://github.com/huggingface/diffusers/issues/841). - - - -Community pipelines allow you to get creative and build your own unique pipelines to share with the community. You can find all community pipelines in the [diffusers/examples/community](https://github.com/huggingface/diffusers/tree/main/examples/community) folder along with inference and training examples for how to use them. This guide showcases some of the community pipelines and hopefully it'll inspire you to create your own (feel free to open a PR with your own pipeline and we will merge it!). - -To load a community pipeline, use the `custom_pipeline` argument in [`DiffusionPipeline`] to specify one of the files in [diffusers/examples/community](https://github.com/huggingface/diffusers/tree/main/examples/community): - -```py -from diffusers import DiffusionPipeline - -pipe = DiffusionPipeline.from_pretrained( - "CompVis/stable-diffusion-v1-4", custom_pipeline="filename_in_the_community_folder", use_safetensors=True -) -``` - -If a community pipeline doesn't work as expected, please open a GitHub issue and mention the author. - -You can learn more about community pipelines in the how to [load community pipelines](custom_pipeline_overview) and how to [contribute a community pipeline](contribute_pipeline) guides. - -## Multilingual Stable Diffusion - -The multilingual Stable Diffusion pipeline uses a pretrained [XLM-RoBERTa](https://huggingface.co/papluca/xlm-roberta-base-language-detection) to identify a language and the [mBART-large-50](https://huggingface.co/facebook/mbart-large-50-many-to-one-mmt) model to handle the translation. This allows you to generate images from text in 20 languages. - -```py -import torch -from diffusers import DiffusionPipeline -from diffusers.utils import make_image_grid -from transformers import ( - pipeline, - MBart50TokenizerFast, - MBartForConditionalGeneration, -) - -device = "cuda" if torch.cuda.is_available() else "cpu" -device_dict = {"cuda": 0, "cpu": -1} - -# add language detection pipeline -language_detection_model_ckpt = "papluca/xlm-roberta-base-language-detection" -language_detection_pipeline = pipeline("text-classification", - model=language_detection_model_ckpt, - device=device_dict[device]) - -# add model for language translation -translation_tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-one-mmt") -translation_model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-one-mmt").to(device) - -diffuser_pipeline = DiffusionPipeline.from_pretrained( - "CompVis/stable-diffusion-v1-4", - custom_pipeline="multilingual_stable_diffusion", - detection_pipeline=language_detection_pipeline, - translation_model=translation_model, - translation_tokenizer=translation_tokenizer, - torch_dtype=torch.float16, -) - -diffuser_pipeline.enable_attention_slicing() -diffuser_pipeline = diffuser_pipeline.to(device) - -prompt = ["a photograph of an astronaut riding a horse", - "Una casa en la playa", - "Ein Hund, der Orange isst", - "Un restaurant parisien"] - -images = diffuser_pipeline(prompt).images -make_image_grid(images, rows=2, cols=2) -``` - -
- -
- -## MagicMix - -[MagicMix](https://huggingface.co/papers/2210.16056) is a pipeline that can mix an image and text prompt to generate a new image that preserves the image structure. The `mix_factor` determines how much influence the prompt has on the layout generation, `kmin` controls the number of steps during the content generation process, and `kmax` determines how much information is kept in the layout of the original image. - -```py -from diffusers import DiffusionPipeline, DDIMScheduler -from diffusers.utils import load_image, make_image_grid - -pipeline = DiffusionPipeline.from_pretrained( - "CompVis/stable-diffusion-v1-4", - custom_pipeline="magic_mix", - scheduler=DDIMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler"), -).to('cuda') - -img = load_image("https://user-images.githubusercontent.com/59410571/209578593-141467c7-d831-4792-8b9a-b17dc5e47816.jpg") -mix_img = pipeline(img, prompt="bed", kmin=0.3, kmax=0.5, mix_factor=0.5) -make_image_grid([img, mix_img], rows=1, cols=2) -``` - -
-
- -
original image
-
-
- -
image and text prompt mix
-
-
diff --git a/docs/source/en/using-diffusers/custom_pipeline_overview.md b/docs/source/en/using-diffusers/custom_pipeline_overview.md index 0b6bb53f10d6..ef26e546e4d4 100644 --- a/docs/source/en/using-diffusers/custom_pipeline_overview.md +++ b/docs/source/en/using-diffusers/custom_pipeline_overview.md @@ -16,11 +16,19 @@ specific language governing permissions and limitations under the License. ## Community pipelines +> [!TIP] Take a look at GitHub Issue [#841](https://github.com/huggingface/diffusers/issues/841) for more context about why we're adding community pipelines to help everyone easily share their work without being slowed down. + Community pipelines are any [`DiffusionPipeline`] class that are different from the original paper implementation (for example, the [`StableDiffusionControlNetPipeline`] corresponds to the [Text-to-Image Generation with ControlNet Conditioning](https://arxiv.org/abs/2302.05543) paper). They provide additional functionality or extend the original implementation of a pipeline. There are many cool community pipelines like [Marigold Depth Estimation](https://github.com/huggingface/diffusers/tree/main/examples/community#marigold-depth-estimation) or [InstantID](https://github.com/huggingface/diffusers/tree/main/examples/community#instantid-pipeline), and you can find all the official community pipelines [here](https://github.com/huggingface/diffusers/tree/main/examples/community). -There are two types of community pipelines, those stored on the Hugging Face Hub and those stored on Diffusers GitHub repository. Hub pipelines are completely customizable (scheduler, models, pipeline code, etc.) while Diffusers GitHub pipelines are only limited to custom pipeline code. Refer to this [table](./contribute_pipeline#share-your-pipeline) for a more detailed comparison of Hub vs GitHub community pipelines. +There are two types of community pipelines, those stored on the Hugging Face Hub and those stored on Diffusers GitHub repository. Hub pipelines are completely customizable (scheduler, models, pipeline code, etc.) while Diffusers GitHub pipelines are only limited to custom pipeline code. + +| | GitHub community pipeline | HF Hub community pipeline | +|----------------|------------------------------------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------| +| usage | same | same | +| review process | open a Pull Request on GitHub and undergo a review process from the Diffusers team before merging; may be slower | upload directly to a Hub repository without any review; this is the fastest workflow | +| visibility | included in the official Diffusers repository and documentation | included on your HF Hub profile and relies on your own usage/promotion to gain visibility | @@ -161,6 +169,97 @@ out_lpw +## Example community pipelines + +Community pipelines are a really fun and creative way to extend the capabilities of the original pipeline with new and unique features. You can find all community pipelines in the [diffusers/examples/community](https://github.com/huggingface/diffusers/tree/main/examples/community) folder with inference and training examples for how to use them. + +This section showcases a couple of the community pipelines and hopefully it'll inspire you to create your own (feel free to open a PR for your community pipeline and ping us for a review)! + +> [!TIP] +> The [`~DiffusionPipeline.from_pipe`] method is particularly useful for loading community pipelines because many of them don't have pretrained weights and add a feature on top of an existing pipeline like Stable Diffusion or Stable Diffusion XL. You can learn more about the [`~DiffusionPipeline.from_pipe`] method in the [Load with from_pipe](custom_pipeline_overview#load-with-from_pipe) section. + + + + +[Marigold](https://marigoldmonodepth.github.io/) is a depth estimation diffusion pipeline that uses the rich existing and inherent visual knowledge in diffusion models. It takes an input image and denoises and decodes it into a depth map. Marigold performs well even on images it hasn't seen before. + +```py +import torch +from PIL import Image +from diffusers import DiffusionPipeline +from diffusers.utils import load_image + +pipeline = DiffusionPipeline.from_pretrained( + "prs-eth/marigold-lcm-v1-0", + custom_pipeline="marigold_depth_estimation", + torch_dtype=torch.float16, + variant="fp16", +) + +pipeline.to("cuda") +image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/community-marigold.png") +output = pipeline( + image, + denoising_steps=4, + ensemble_size=5, + processing_res=768, + match_input_res=True, + batch_size=0, + seed=33, + color_map="Spectral", + show_progress_bar=True, +) +depth_colored: Image.Image = output.depth_colored +depth_colored.save("./depth_colored.png") +``` + +
+
+ +
original image
+
+
+ +
colorized depth image
+
+
+ +
+ + +[HD-Painter](https://hf.co/papers/2312.14091) is a high-resolution inpainting pipeline. It introduces a *Prompt-Aware Introverted Attention (PAIntA)* layer to better align a prompt with the area to be inpainted, and *Reweighting Attention Score Guidance (RASG)* to keep the latents more prompt-aligned and within their trained domain to generate realistc images. + +```py +import torch +from diffusers import DiffusionPipeline, DDIMScheduler +from diffusers.utils import load_image + +pipeline = DiffusionPipeline.from_pretrained( + "Lykon/dreamshaper-8-inpainting", + custom_pipeline="hd_painter" +) +pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) +init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/hd-painter.jpg") +mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/hd-painter-mask.png") +prompt = "football" +image = pipeline(prompt, init_image, mask_image, use_rasg=True, use_painta=True, generator=torch.manual_seed(0)).images[0] +image +``` + +
+
+ +
original image
+
+
+ +
generated image
+
+
+ +
+
+ ## Community components Community components allow users to build pipelines that may have customized components that are not a part of Diffusers. If your pipeline has custom components that Diffusers doesn't already support, you need to provide their implementations as Python modules. These customized components could be a VAE, UNet, and scheduler. In most cases, the text encoder is imported from the Transformers library. The pipeline code itself can also be customized. From c1edb03c372c12f92b1b5580b718e8cf2196016c Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Wed, 1 May 2024 17:36:54 +0530 Subject: [PATCH 17/20] Fix for pipeline slow test fetcher (#7824) * update * update --- .github/workflows/nightly_tests.yml | 44 ++++++++++++++--------------- .github/workflows/push_tests.yml | 2 +- 2 files changed, 23 insertions(+), 23 deletions(-) diff --git a/.github/workflows/nightly_tests.yml b/.github/workflows/nightly_tests.yml index 2f73c66de829..d911dab4a306 100644 --- a/.github/workflows/nightly_tests.yml +++ b/.github/workflows/nightly_tests.yml @@ -19,7 +19,7 @@ env: jobs: setup_torch_cuda_pipeline_matrix: name: Setup Torch Pipelines Matrix - runs-on: ubuntu-latest + runs-on: diffusers/diffusers-pytorch-cpu outputs: pipeline_test_matrix: ${{ steps.fetch_pipeline_matrix.outputs.pipeline_test_matrix }} steps: @@ -67,19 +67,19 @@ jobs: fetch-depth: 2 - name: NVIDIA-SMI run: nvidia-smi - + - name: Install dependencies run: | python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH" python -m uv pip install -e [quality,test] python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate.git python -m uv pip install pytest-reportlog - + - name: Environment run: | python utils/print_env.py - - - name: Nightly PyTorch CUDA checkpoint (pipelines) tests + + - name: Nightly PyTorch CUDA checkpoint (pipelines) tests env: HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} # https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms @@ -88,9 +88,9 @@ jobs: python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \ -s -v -k "not Flax and not Onnx" \ --make-reports=tests_pipeline_${{ matrix.module }}_cuda \ - --report-log=tests_pipeline_${{ matrix.module }}_cuda.log \ + --report-log=tests_pipeline_${{ matrix.module }}_cuda.log \ tests/pipelines/${{ matrix.module }} - + - name: Failure short reports if: ${{ failure() }} run: | @@ -103,7 +103,7 @@ jobs: with: name: pipeline_${{ matrix.module }}_test_reports path: reports - + - name: Generate Report and Notify Channel if: always() run: | @@ -139,7 +139,7 @@ jobs: run: python utils/print_env.py - name: Run nightly PyTorch CUDA tests for non-pipeline modules - if: ${{ matrix.module != 'examples'}} + if: ${{ matrix.module != 'examples'}} env: HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} # https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms @@ -148,7 +148,7 @@ jobs: python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \ -s -v -k "not Flax and not Onnx" \ --make-reports=tests_torch_${{ matrix.module }}_cuda \ - --report-log=tests_torch_${{ matrix.module }}_cuda.log \ + --report-log=tests_torch_${{ matrix.module }}_cuda.log \ tests/${{ matrix.module }} - name: Run nightly example tests with Torch @@ -161,13 +161,13 @@ jobs: python -m uv pip install peft@git+https://github.com/huggingface/peft.git python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \ -s -v --make-reports=examples_torch_cuda \ - --report-log=examples_torch_cuda.log \ + --report-log=examples_torch_cuda.log \ examples/ - name: Failure short reports if: ${{ failure() }} run: | - cat reports/tests_torch_${{ matrix.module }}_cuda_stats.txt + cat reports/tests_torch_${{ matrix.module }}_cuda_stats.txt cat reports/tests_torch_${{ matrix.module }}_cuda_failures_short.txt - name: Test suite reports artifacts @@ -218,13 +218,13 @@ jobs: python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \ -s -v -k "not Flax and not Onnx" \ --make-reports=tests_torch_lora_cuda \ - --report-log=tests_torch_lora_cuda.log \ + --report-log=tests_torch_lora_cuda.log \ tests/lora - + - name: Failure short reports if: ${{ failure() }} run: | - cat reports/tests_torch_lora_cuda_stats.txt + cat reports/tests_torch_lora_cuda_stats.txt cat reports/tests_torch_lora_cuda_failures_short.txt - name: Test suite reports artifacts @@ -239,12 +239,12 @@ jobs: run: | pip install slack_sdk tabulate python scripts/log_reports.py >> $GITHUB_STEP_SUMMARY - + run_flax_tpu_tests: name: Nightly Flax TPU Tests runs-on: docker-tpu if: github.event_name == 'schedule' - + container: image: diffusers/diffusers-flax-tpu options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/ --privileged @@ -274,7 +274,7 @@ jobs: python -m pytest -n 0 \ -s -v -k "Flax" \ --make-reports=tests_flax_tpu \ - --report-log=tests_flax_tpu.log \ + --report-log=tests_flax_tpu.log \ tests/ - name: Failure short reports @@ -302,7 +302,7 @@ jobs: container: image: diffusers/diffusers-onnxruntime-cuda options: --gpus 0 --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/ - + steps: - name: Checkout diffusers uses: actions/checkout@v3 @@ -321,7 +321,7 @@ jobs: - name: Environment run: python utils/print_env.py - + - name: Run nightly ONNXRuntime CUDA tests env: HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} @@ -329,7 +329,7 @@ jobs: python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \ -s -v -k "Onnx" \ --make-reports=tests_onnx_cuda \ - --report-log=tests_onnx_cuda.log \ + --report-log=tests_onnx_cuda.log \ tests/ - name: Failure short reports @@ -344,7 +344,7 @@ jobs: with: name: ${{ matrix.config.report }}_test_reports path: reports - + - name: Generate Report and Notify Channel if: always() run: | diff --git a/.github/workflows/push_tests.yml b/.github/workflows/push_tests.yml index a6cb123a7035..d5e1c4739497 100644 --- a/.github/workflows/push_tests.yml +++ b/.github/workflows/push_tests.yml @@ -21,7 +21,7 @@ env: jobs: setup_torch_cuda_pipeline_matrix: name: Setup Torch Pipelines CUDA Slow Tests Matrix - runs-on: ubuntu-latest + runs-on: diffusers/diffusers-pytorch-cpu outputs: pipeline_test_matrix: ${{ steps.fetch_pipeline_matrix.outputs.pipeline_test_matrix }} steps: From 8909ab4b192500bcc3c17d839ae101cc669e9d8e Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 1 May 2024 18:45:47 +0530 Subject: [PATCH 18/20] [Tests] fix: device map tests for models (#7825) * fix: device module tests * remove patch file * Empty-Commit --- tests/models/test_modeling_common.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index d8a93d40c8bf..d9e70c6dd784 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -691,6 +691,9 @@ def test_deprecated_kwargs(self): def test_cpu_offload(self): config, inputs_dict = self.prepare_init_args_and_inputs_for_common() model = self.model_class(**config).eval() + if model._no_split_modules is None: + return + model = model.to(torch_device) torch.manual_seed(0) @@ -718,6 +721,9 @@ def test_cpu_offload(self): def test_disk_offload_without_safetensors(self): config, inputs_dict = self.prepare_init_args_and_inputs_for_common() model = self.model_class(**config).eval() + if model._no_split_modules is None: + return + model = model.to(torch_device) torch.manual_seed(0) @@ -728,12 +734,12 @@ def test_disk_offload_without_safetensors(self): model.cpu().save_pretrained(tmp_dir, safe_serialization=False) with self.assertRaises(ValueError): - max_size = int(self.model_split_percents[1] * model_size) + max_size = int(self.model_split_percents[0] * model_size) max_memory = {0: max_size, "cpu": max_size} # This errors out because it's missing an offload folder new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory) - max_size = int(self.model_split_percents[1] * model_size) + max_size = int(self.model_split_percents[0] * model_size) max_memory = {0: max_size, "cpu": max_size} new_model = self.model_class.from_pretrained( tmp_dir, device_map="auto", max_memory=max_memory, offload_folder=tmp_dir @@ -749,6 +755,9 @@ def test_disk_offload_without_safetensors(self): def test_disk_offload_with_safetensors(self): config, inputs_dict = self.prepare_init_args_and_inputs_for_common() model = self.model_class(**config).eval() + if model._no_split_modules is None: + return + model = model.to(torch_device) torch.manual_seed(0) @@ -758,7 +767,7 @@ def test_disk_offload_with_safetensors(self): with tempfile.TemporaryDirectory() as tmp_dir: model.cpu().save_pretrained(tmp_dir) - max_size = int(self.model_split_percents[1] * model_size) + max_size = int(self.model_split_percents[0] * model_size) max_memory = {0: max_size, "cpu": max_size} new_model = self.model_class.from_pretrained( tmp_dir, device_map="auto", offload_folder=tmp_dir, max_memory=max_memory @@ -774,6 +783,9 @@ def test_disk_offload_with_safetensors(self): def test_model_parallelism(self): config, inputs_dict = self.prepare_init_args_and_inputs_for_common() model = self.model_class(**config).eval() + if model._no_split_modules is None: + return + model = model.to(torch_device) torch.manual_seed(0) From 21a7ff12a75ecf43a85898838d1990cda853ffaf Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Wed, 1 May 2024 06:25:57 -1000 Subject: [PATCH 19/20] update the logic of `is_sequential_cpu_offload` (#7788) * up * add comment to the tests + fix dit --------- Co-authored-by: Sayak Paul --- .../community/pipeline_demofusion_sdxl.py | 6 +- src/diffusers/loaders/lora.py | 6 +- src/diffusers/loaders/textual_inversion.py | 6 +- src/diffusers/loaders/unet.py | 6 +- src/diffusers/pipelines/dit/pipeline_dit.py | 3 + src/diffusers/pipelines/pipeline_utils.py | 9 +- tests/pipelines/pixart_alpha/test_pixart.py | 4 - tests/pipelines/pixart_sigma/test_pixart.py | 4 - tests/pipelines/test_pipelines_common.py | 103 ++++++++++++++++-- 9 files changed, 123 insertions(+), 24 deletions(-) diff --git a/examples/community/pipeline_demofusion_sdxl.py b/examples/community/pipeline_demofusion_sdxl.py index 93e1463638f0..f46d635dae2b 100644 --- a/examples/community/pipeline_demofusion_sdxl.py +++ b/examples/community/pipeline_demofusion_sdxl.py @@ -1304,7 +1304,11 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di if isinstance(component, torch.nn.Module): if hasattr(component, "_hf_hook"): is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload) - is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook) + is_sequential_cpu_offload = ( + isinstance(getattr(component, "_hf_hook"), AlignDevicesHook) + or hasattr(component._hf_hook, "hooks") + and isinstance(component._hf_hook.hooks[0], AlignDevicesHook) + ) logger.info( "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again." ) diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py index 8703cdee4011..d69db5a83af1 100644 --- a/src/diffusers/loaders/lora.py +++ b/src/diffusers/loaders/lora.py @@ -369,7 +369,11 @@ def _optionally_disable_offloading(cls, _pipeline): if not is_model_cpu_offload: is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload) if not is_sequential_cpu_offload: - is_sequential_cpu_offload = isinstance(component._hf_hook, AlignDevicesHook) + is_sequential_cpu_offload = ( + isinstance(component._hf_hook, AlignDevicesHook) + or hasattr(component._hf_hook, "hooks") + and isinstance(component._hf_hook.hooks[0], AlignDevicesHook) + ) logger.info( "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again." diff --git a/src/diffusers/loaders/textual_inversion.py b/src/diffusers/loaders/textual_inversion.py index c1c224975cb8..05ed64f5dcad 100644 --- a/src/diffusers/loaders/textual_inversion.py +++ b/src/diffusers/loaders/textual_inversion.py @@ -423,7 +423,11 @@ def load_textual_inversion( if isinstance(component, nn.Module): if hasattr(component, "_hf_hook"): is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload) - is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook) + is_sequential_cpu_offload = ( + isinstance(getattr(component, "_hf_hook"), AlignDevicesHook) + or hasattr(component._hf_hook, "hooks") + and isinstance(component._hf_hook.hooks[0], AlignDevicesHook) + ) logger.info( "Accelerate hooks detected. Since you have called `load_textual_inversion()`, the previous hooks will be first removed. Then the textual inversion parameters will be loaded and the hooks will be applied again." ) diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index 294db44ee61d..3e74411865a3 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -359,7 +359,11 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict for _, component in _pipeline.components.items(): if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"): is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload) - is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook) + is_sequential_cpu_offload = ( + isinstance(getattr(component, "_hf_hook"), AlignDevicesHook) + or hasattr(component._hf_hook, "hooks") + and isinstance(component._hf_hook.hooks[0], AlignDevicesHook) + ) logger.info( "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again." diff --git a/src/diffusers/pipelines/dit/pipeline_dit.py b/src/diffusers/pipelines/dit/pipeline_dit.py index 289ea496028d..a3ea90874a12 100644 --- a/src/diffusers/pipelines/dit/pipeline_dit.py +++ b/src/diffusers/pipelines/dit/pipeline_dit.py @@ -227,6 +227,9 @@ def __call__( if output_type == "pil": samples = self.numpy_to_pil(samples) + # Offload all models + self.maybe_free_model_hooks() + if not return_dict: return (samples,) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 68433332546b..59e38c910d4a 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -376,7 +376,11 @@ def module_is_sequentially_offloaded(module): if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"): return False - return hasattr(module, "_hf_hook") and isinstance(module._hf_hook, accelerate.hooks.AlignDevicesHook) + return hasattr(module, "_hf_hook") and ( + isinstance(module._hf_hook, accelerate.hooks.AlignDevicesHook) + or hasattr(module._hf_hook, "hooks") + and isinstance(module._hf_hook.hooks[0], accelerate.hooks.AlignDevicesHook) + ) def module_is_offloaded(module): if not is_accelerate_available() or is_accelerate_version("<", "0.17.0.dev0"): @@ -1005,8 +1009,7 @@ def remove_all_hooks(self): """ for _, model in self.components.items(): if isinstance(model, torch.nn.Module) and hasattr(model, "_hf_hook"): - is_sequential_cpu_offload = isinstance(getattr(model, "_hf_hook"), accelerate.hooks.AlignDevicesHook) - accelerate.hooks.remove_hook_from_module(model, recurse=is_sequential_cpu_offload) + accelerate.hooks.remove_hook_from_module(model, recurse=True) self._all_hooks = [] def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"): diff --git a/tests/pipelines/pixart_alpha/test_pixart.py b/tests/pipelines/pixart_alpha/test_pixart.py index d981b55260c7..dd358af08395 100644 --- a/tests/pipelines/pixart_alpha/test_pixart.py +++ b/tests/pipelines/pixart_alpha/test_pixart.py @@ -324,10 +324,6 @@ def test_raises_warning_for_mask_feature(self): def test_inference_batch_single_identical(self): self._test_inference_batch_single_identical(expected_max_diff=1e-3) - # PixArt transformer model does not work with sequential offload so skip it for now - def test_sequential_offload_forward_pass_twice(self): - pass - @slow @require_torch_gpu diff --git a/tests/pipelines/pixart_sigma/test_pixart.py b/tests/pipelines/pixart_sigma/test_pixart.py index 7b1d5e389f32..c0df15ae661d 100644 --- a/tests/pipelines/pixart_sigma/test_pixart.py +++ b/tests/pipelines/pixart_sigma/test_pixart.py @@ -308,10 +308,6 @@ def test_inference_with_multiple_images_per_prompt(self): def test_inference_batch_single_identical(self): self._test_inference_batch_single_identical(expected_max_diff=1e-3) - # PixArt transformer model does not work with sequential offload so skip it for now - def test_sequential_offload_forward_pass_twice(self): - pass - @slow @require_torch_gpu diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 0c0a765f662d..032fbb81ea31 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -1360,6 +1360,8 @@ def _test_attention_slicing_forward_pass( reason="CPU offload is only available with CUDA and `accelerate v0.14.0` or higher", ) def test_sequential_cpu_offload_forward_pass(self, expected_max_diff=1e-4): + import accelerate + components = self.get_dummy_components() pipe = self.pipeline_class(**components) for component in pipe.components.values(): @@ -1373,6 +1375,7 @@ def test_sequential_cpu_offload_forward_pass(self, expected_max_diff=1e-4): output_without_offload = pipe(**inputs)[0] pipe.enable_sequential_cpu_offload() + assert pipe._execution_device.type == pipe._offload_device.type inputs = self.get_dummy_inputs(generator_device) output_with_offload = pipe(**inputs)[0] @@ -1380,11 +1383,48 @@ def test_sequential_cpu_offload_forward_pass(self, expected_max_diff=1e-4): max_diff = np.abs(to_np(output_with_offload) - to_np(output_without_offload)).max() self.assertLess(max_diff, expected_max_diff, "CPU offloading should not affect the inference results") + # make sure all `torch.nn.Module` components (except those in `self._exclude_from_cpu_offload`) are offloaded correctly + offloaded_modules = { + k: v + for k, v in pipe.components.items() + if isinstance(v, torch.nn.Module) and k not in pipe._exclude_from_cpu_offload + } + # 1. all offloaded modules should be saved to cpu and moved to meta device + self.assertTrue( + all(v.device.type == "meta" for v in offloaded_modules.values()), + f"Not offloaded: {[k for k, v in offloaded_modules.items() if v.device.type != 'meta']}", + ) + # 2. all offloaded modules should have hook installed + self.assertTrue( + all(hasattr(v, "_hf_hook") for k, v in offloaded_modules.items()), + f"No hook attached: {[k for k, v in offloaded_modules.items() if not hasattr(v, '_hf_hook')]}", + ) + # 3. all offloaded modules should have correct hooks installed, should be either one of these two + # - `AlignDevicesHook` + # - a SequentialHook` that contains `AlignDevicesHook` + offloaded_modules_with_incorrect_hooks = {} + for k, v in offloaded_modules.items(): + if hasattr(v, "_hf_hook"): + if isinstance(v._hf_hook, accelerate.hooks.SequentialHook): + # if it is a `SequentialHook`, we loop through its `hooks` attribute to check if it only contains `AlignDevicesHook` + for hook in v._hf_hook.hooks: + if not isinstance(hook, accelerate.hooks.AlignDevicesHook): + offloaded_modules_with_incorrect_hooks[k] = type(v._hf_hook.hooks[0]) + elif not isinstance(v._hf_hook, accelerate.hooks.AlignDevicesHook): + offloaded_modules_with_incorrect_hooks[k] = type(v._hf_hook) + + self.assertTrue( + len(offloaded_modules_with_incorrect_hooks) == 0, + f"Not installed correct hook: {offloaded_modules_with_incorrect_hooks}", + ) + @unittest.skipIf( torch_device != "cuda" or not is_accelerate_available() or is_accelerate_version("<", "0.17.0"), reason="CPU offload is only available with CUDA and `accelerate v0.17.0` or higher", ) def test_model_cpu_offload_forward_pass(self, expected_max_diff=2e-4): + import accelerate + generator_device = "cpu" components = self.get_dummy_components() pipe = self.pipeline_class(**components) @@ -1400,19 +1440,39 @@ def test_model_cpu_offload_forward_pass(self, expected_max_diff=2e-4): output_without_offload = pipe(**inputs)[0] pipe.enable_model_cpu_offload() + assert pipe._execution_device.type == pipe._offload_device.type + inputs = self.get_dummy_inputs(generator_device) output_with_offload = pipe(**inputs)[0] max_diff = np.abs(to_np(output_with_offload) - to_np(output_without_offload)).max() self.assertLess(max_diff, expected_max_diff, "CPU offloading should not affect the inference results") - offloaded_modules = [ - v + + # make sure all `torch.nn.Module` components (except those in `self._exclude_from_cpu_offload`) are offloaded correctly + offloaded_modules = { + k: v for k, v in pipe.components.items() if isinstance(v, torch.nn.Module) and k not in pipe._exclude_from_cpu_offload - ] - ( - self.assertTrue(all(v.device.type == "cpu" for v in offloaded_modules)), - f"Not offloaded: {[v for v in offloaded_modules if v.device.type != 'cpu']}", + } + # 1. check if all offloaded modules are saved to cpu + self.assertTrue( + all(v.device.type == "cpu" for v in offloaded_modules.values()), + f"Not offloaded: {[k for k, v in offloaded_modules.items() if v.device.type != 'cpu']}", + ) + # 2. check if all offloaded modules have hooks installed + self.assertTrue( + all(hasattr(v, "_hf_hook") for k, v in offloaded_modules.items()), + f"No hook attached: {[k for k, v in offloaded_modules.items() if not hasattr(v, '_hf_hook')]}", + ) + # 3. check if all offloaded modules have correct type of hooks installed, should be `CpuOffload` + offloaded_modules_with_incorrect_hooks = {} + for k, v in offloaded_modules.items(): + if hasattr(v, "_hf_hook") and not isinstance(v._hf_hook, accelerate.hooks.CpuOffload): + offloaded_modules_with_incorrect_hooks[k] = type(v._hf_hook) + + self.assertTrue( + len(offloaded_modules_with_incorrect_hooks) == 0, + f"Not installed correct hook: {offloaded_modules_with_incorrect_hooks}", ) @unittest.skipIf( @@ -1444,16 +1504,24 @@ def test_cpu_offload_forward_pass_twice(self, expected_max_diff=2e-4): self.assertLess( max_diff, expected_max_diff, "running CPU offloading 2nd time should not affect the inference results" ) + + # make sure all `torch.nn.Module` components (except those in `self._exclude_from_cpu_offload`) are offloaded correctly offloaded_modules = { k: v for k, v in pipe.components.items() if isinstance(v, torch.nn.Module) and k not in pipe._exclude_from_cpu_offload } + # 1. check if all offloaded modules are saved to cpu self.assertTrue( all(v.device.type == "cpu" for v in offloaded_modules.values()), f"Not offloaded: {[k for k, v in offloaded_modules.items() if v.device.type != 'cpu']}", ) - + # 2. check if all offloaded modules have hooks installed + self.assertTrue( + all(hasattr(v, "_hf_hook") for k, v in offloaded_modules.items()), + f"No hook attached: {[k for k, v in offloaded_modules.items() if not hasattr(v, '_hf_hook')]}", + ) + # 3. check if all offloaded modules have correct type of hooks installed, should be `CpuOffload` offloaded_modules_with_incorrect_hooks = {} for k, v in offloaded_modules.items(): if hasattr(v, "_hf_hook") and not isinstance(v._hf_hook, accelerate.hooks.CpuOffload): @@ -1493,19 +1561,36 @@ def test_sequential_offload_forward_pass_twice(self, expected_max_diff=2e-4): self.assertLess( max_diff, expected_max_diff, "running sequential offloading second time should have the inference results" ) + + # make sure all `torch.nn.Module` components (except those in `self._exclude_from_cpu_offload`) are offloaded correctly offloaded_modules = { k: v for k, v in pipe.components.items() if isinstance(v, torch.nn.Module) and k not in pipe._exclude_from_cpu_offload } + # 1. check if all offloaded modules are moved to meta device self.assertTrue( all(v.device.type == "meta" for v in offloaded_modules.values()), f"Not offloaded: {[k for k, v in offloaded_modules.items() if v.device.type != 'meta']}", ) + # 2. check if all offloaded modules have hook installed + self.assertTrue( + all(hasattr(v, "_hf_hook") for k, v in offloaded_modules.items()), + f"No hook attached: {[k for k, v in offloaded_modules.items() if not hasattr(v, '_hf_hook')]}", + ) + # 3. check if all offloaded modules have correct hooks installed, should be either one of these two + # - `AlignDevicesHook` + # - a SequentialHook` that contains `AlignDevicesHook` offloaded_modules_with_incorrect_hooks = {} for k, v in offloaded_modules.items(): - if hasattr(v, "_hf_hook") and not isinstance(v._hf_hook, accelerate.hooks.AlignDevicesHook): - offloaded_modules_with_incorrect_hooks[k] = type(v._hf_hook) + if hasattr(v, "_hf_hook"): + if isinstance(v._hf_hook, accelerate.hooks.SequentialHook): + # if it is a `SequentialHook`, we loop through its `hooks` attribute to check if it only contains `AlignDevicesHook` + for hook in v._hf_hook.hooks: + if not isinstance(hook, accelerate.hooks.AlignDevicesHook): + offloaded_modules_with_incorrect_hooks[k] = type(v._hf_hook.hooks[0]) + elif not isinstance(v._hf_hook, accelerate.hooks.AlignDevicesHook): + offloaded_modules_with_incorrect_hooks[k] = type(v._hf_hook) self.assertTrue( len(offloaded_modules_with_incorrect_hooks) == 0, From 5915c2985db162278e09196160d796166c89ad12 Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Wed, 1 May 2024 06:27:43 -1000 Subject: [PATCH 20/20] [ip-adapter] fix ip-adapter for StableDiffusionInstructPix2PixPipeline (#7820) update prepare_ip_adapter_ for pix2pix --- ...eline_stable_diffusion_instruct_pix2pix.py | 95 +++++++++++++++++-- ...ne_stable_diffusion_xl_instruct_pix2pix.py | 1 - 2 files changed, 87 insertions(+), 9 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py index de2767e23952..0bf5a92a4fcc 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py @@ -172,6 +172,7 @@ def __call__( prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, @@ -296,6 +297,8 @@ def __call__( negative_prompt, prompt_embeds, negative_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, callback_on_step_end_tensor_inputs, ) self._guidance_scale = guidance_scale @@ -303,14 +306,6 @@ def __call__( device = self._execution_device - if ip_adapter_image is not None: - output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True - image_embeds, negative_image_embeds = self.encode_image( - ip_adapter_image, device, num_images_per_prompt, output_hidden_state - ) - if self.do_classifier_free_guidance: - image_embeds = torch.cat([image_embeds, negative_image_embeds, negative_image_embeds]) - if image is None: raise ValueError("`image` input cannot be undefined.") @@ -335,6 +330,14 @@ def __call__( negative_prompt_embeds=negative_prompt_embeds, ) + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) # 3. Preprocess image image = self.image_processor.preprocess(image) @@ -635,6 +638,65 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state return image_embeds, uncond_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + image_embeds = [] + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0) + single_negative_image_embeds = torch.stack( + [single_negative_image_embeds] * num_images_per_prompt, dim=0 + ) + + if do_classifier_free_guidance: + single_image_embeds = torch.cat( + [single_image_embeds, single_negative_image_embeds, single_negative_image_embeds] + ) + single_image_embeds = single_image_embeds.to(device) + + image_embeds.append(single_image_embeds) + else: + repeat_dims = [1] + image_embeds = [] + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + ( + single_image_embeds, + single_negative_image_embeds, + single_negative_image_embeds, + ) = single_image_embeds.chunk(3) + single_image_embeds = single_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) + ) + single_negative_image_embeds = single_negative_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:])) + ) + single_image_embeds = torch.cat( + [single_image_embeds, single_negative_image_embeds, single_negative_image_embeds] + ) + else: + single_image_embeds = single_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) + ) + image_embeds.append(single_image_embeds) + + return image_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: @@ -687,6 +749,8 @@ def check_inputs( negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, callback_on_step_end_tensor_inputs=None, ): if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): @@ -728,6 +792,21 @@ def check_inputs( f" {negative_prompt_embeds.shape}." ) + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): shape = ( diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py index a2242bb099c5..5e7be370be01 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py @@ -436,7 +436,6 @@ def prepare_extra_step_kwargs(self, generator, eta): extra_step_kwargs["generator"] = generator return extra_step_kwargs - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_instruct_pix2pix.StableDiffusionInstructPix2PixPipeline.check_inputs def check_inputs( self, prompt,