From c022e52923c897d0c16ebcaa9feecfbf1dfbec66 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Fri, 22 Dec 2023 13:35:21 +0530 Subject: [PATCH 1/6] Remove ONNX inpaint legacy (#6269) update Co-authored-by: Sayak Paul --- ...st_onnx_stable_diffusion_inpaint_legacy.py | 97 ------------------- 1 file changed, 97 deletions(-) delete mode 100644 tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_inpaint_legacy.py diff --git a/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_inpaint_legacy.py b/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_inpaint_legacy.py deleted file mode 100644 index 235aa32f7338..000000000000 --- a/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_inpaint_legacy.py +++ /dev/null @@ -1,97 +0,0 @@ -# coding=utf-8 -# Copyright 2023 HuggingFace Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest - -import numpy as np - -from diffusers import OnnxStableDiffusionInpaintPipelineLegacy -from diffusers.utils.testing_utils import ( - is_onnx_available, - load_image, - load_numpy, - nightly, - require_onnxruntime, - require_torch_gpu, -) - - -if is_onnx_available(): - import onnxruntime as ort - - -@nightly -@require_onnxruntime -@require_torch_gpu -class StableDiffusionOnnxInpaintLegacyPipelineIntegrationTests(unittest.TestCase): - @property - def gpu_provider(self): - return ( - "CUDAExecutionProvider", - { - "gpu_mem_limit": "15000000000", # 15GB - "arena_extend_strategy": "kSameAsRequested", - }, - ) - - @property - def gpu_options(self): - options = ort.SessionOptions() - options.enable_mem_pattern = False - return options - - def test_inference(self): - init_image = load_image( - "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" - "/in_paint/overture-creations-5sI6fQgYIuo.png" - ) - mask_image = load_image( - "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" - "/in_paint/overture-creations-5sI6fQgYIuo_mask.png" - ) - expected_image = load_numpy( - "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" - "/in_paint/red_cat_sitting_on_a_park_bench_onnx.npy" - ) - - # using the PNDM scheduler by default - pipe = OnnxStableDiffusionInpaintPipelineLegacy.from_pretrained( - "CompVis/stable-diffusion-v1-4", - revision="onnx", - safety_checker=None, - feature_extractor=None, - provider=self.gpu_provider, - sess_options=self.gpu_options, - ) - pipe.set_progress_bar_config(disable=None) - - prompt = "A red cat sitting on a park bench" - - generator = np.random.RandomState(0) - output = pipe( - prompt=prompt, - image=init_image, - mask_image=mask_image, - strength=0.75, - guidance_scale=7.5, - num_inference_steps=15, - generator=generator, - output_type="np", - ) - - image = output.images[0] - - assert image.shape == (512, 512, 3) - assert np.abs(expected_image - image).max() < 1e-2 From 59d1caa2385fae1761232e22ed6bcf4f9a492bf7 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Fri, 22 Dec 2023 13:35:52 +0530 Subject: [PATCH 2/6] Remove 
peft tests from old lora backend tests (#6273) update --- tests/lora/test_lora_layers_peft.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/lora/test_lora_layers_peft.py b/tests/lora/test_lora_layers_peft.py index 1d8c6977440c..f6cd2a714ae2 100644 --- a/tests/lora/test_lora_layers_peft.py +++ b/tests/lora/test_lora_layers_peft.py @@ -1397,7 +1397,7 @@ class StableDiffusionXLLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase): @slow @require_torch_gpu -class LoraIntegrationTests(unittest.TestCase): +class LoraIntegrationTests(PeftLoraLoaderMixinTests, unittest.TestCase): def tearDown(self): import gc @@ -1650,7 +1650,7 @@ def test_load_unload_load_kohya_lora(self): @slow @require_torch_gpu -class LoraSDXLIntegrationTests(unittest.TestCase): +class LoraSDXLIntegrationTests(PeftLoraLoaderMixinTests, unittest.TestCase): def tearDown(self): import gc From 7fe47596af40ea900318e6a2a773a00ff3f3a115 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Fri, 22 Dec 2023 09:37:30 +0100 Subject: [PATCH 3/6] Allow diffusers to load with Flax, w/o PyTorch (#6272) --- src/diffusers/utils/torch_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py index 00bc75f41be3..d0d02fb92e72 100644 --- a/src/diffusers/utils/torch_utils.py +++ b/src/diffusers/utils/torch_utils.py @@ -89,7 +89,7 @@ def is_compiled_module(module) -> bool: return isinstance(module, torch._dynamo.eval_frame.OptimizedModule) -def fourier_filter(x_in: torch.Tensor, threshold: int, scale: int) -> torch.Tensor: +def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int) -> "torch.Tensor": """Fourier filter as introduced in FreeU (https://arxiv.org/abs/2309.11497). This version of the method comes from here: @@ -121,8 +121,8 @@ def fourier_filter(x_in: torch.Tensor, threshold: int, scale: int) -> torch.Tens def apply_freeu( - resolution_idx: int, hidden_states: torch.Tensor, res_hidden_states: torch.Tensor, **freeu_kwargs -) -> Tuple[torch.Tensor, torch.Tensor]: + resolution_idx: int, hidden_states: "torch.Tensor", res_hidden_states: "torch.Tensor", **freeu_kwargs +) -> Tuple["torch.Tensor", "torch.Tensor"]: """Applies the FreeU mechanism as introduced in https: //arxiv.org/abs/2309.11497. Adapted from the official code repository: https://github.com/ChenyangSi/FreeU. 
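Editor's note on patch 3: it makes `src/diffusers/utils/torch_utils.py` importable when only Flax is installed by quoting the `torch.Tensor` annotations. As string forward references they are never evaluated at import time, so merely loading the module no longer requires PyTorch. The sketch below illustrates the same pattern in isolation; the `apply_scale` helper and the lazy in-function `import torch` are illustrative assumptions, not how the diffusers module itself guards its imports.

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:  # seen by static type checkers only, never executed at runtime
    import torch


def apply_scale(x_in: "torch.Tensor", scale: float) -> "torch.Tensor":
    # The quoted annotations stay plain strings, so importing this module works
    # even when PyTorch is absent; torch is only needed once the function runs.
    import torch

    return x_in * torch.tensor(scale, dtype=x_in.dtype, device=x_in.device)


if __name__ == "__main__":
    print("imported without evaluating the torch annotations")
```

In the actual patch only the annotations change; the runtime behaviour of `fourier_filter` and `apply_freeu` is untouched.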
From 3369bc810a09a52521bbf8cc1ec77df3a8c682a8 Mon Sep 17 00:00:00 2001 From: Bingxin Ke <45253439+markkua@users.noreply.github.com> Date: Fri, 22 Dec 2023 11:11:46 +0100 Subject: [PATCH 4/6] [Community Pipeline] Add Marigold Monocular Depth Estimation (#6249) * [Community Pipeline] Add Marigold Monocular Depth Estimation - add single-file pipeline - update README * fix format - add one blank line * format script with ruff * use direct image link in example code --------- Co-authored-by: Sayak Paul --- examples/community/README.md | 48 ++ .../community/marigold_depth_estimation.py | 602 ++++++++++++++++++ 2 files changed, 650 insertions(+) create mode 100644 examples/community/marigold_depth_estimation.py diff --git a/examples/community/README.md b/examples/community/README.md index 7af6d1d7eb02..c3aa1ecf3d64 100755 --- a/examples/community/README.md +++ b/examples/community/README.md @@ -8,6 +8,7 @@ If a community doesn't work as expected, please open an issue and ping the autho | Example | Description | Code Example | Colab | Author | |:--------------------------------------------------------------------------------------------------------------------------------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------------------------------------------------------------------------------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------:| +| Marigold Monocular Depth Estimation | A universal monocular depth estimator, utilizing Stable Diffusion, delivering sharp predictions in the wild. (See the [project page](https://marigoldmonodepth.github.io) and [full codebase](https://github.com/prs-eth/marigold) for more details.) | [Marigold Depth Estimation](#marigold-depth-estimation) | [![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-yellow)](https://huggingface.co/spaces/toshas/marigold) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/12G8reD13DdpMie5ZQlaFNo2WCGeNUH-u?usp=sharing) | [Bingxin Ke](https://github.com/markkua) and [Anton Obukhov](https://github.com/toshas) | | LLM-grounded Diffusion (LMD+) | LMD greatly improves the prompt following ability of text-to-image generation models by introducing an LLM as a front-end prompt parser and layout planner. 
[Project page.](https://llm-grounded-diffusion.github.io/) [See our full codebase (also with diffusers).](https://github.com/TonyLianLong/LLM-groundedDiffusion) | [LLM-grounded Diffusion (LMD+)](#llm-grounded-diffusion) | [Huggingface Demo](https://huggingface.co/spaces/longlian/llm-grounded-diffusion) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1SXzMSeAB-LJYISb2yrUOdypLz4OYWUKj) | [Long (Tony) Lian](https://tonylian.com/) | | CLIP Guided Stable Diffusion | Doing CLIP guidance for text to image generation with Stable Diffusion | [CLIP Guided Stable Diffusion](#clip-guided-stable-diffusion) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/CLIP_Guided_Stable_diffusion_with_diffusers.ipynb) | [Suraj Patil](https://github.com/patil-suraj/) | | One Step U-Net (Dummy) | Example showcasing of how to use Community Pipelines (see https://github.com/huggingface/diffusers/issues/841) | [One Step U-Net](#one-step-unet) | - | [Patrick von Platen](https://github.com/patrickvonplaten/) | @@ -61,6 +62,53 @@ pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", custo ## Example usages +### Marigold Depth Estimation + +Marigold is a universal monocular depth estimator that delivers accurate and sharp predictions in the wild. Based on Stable Diffusion, it is trained exclusively with synthetic depth data and excels in zero-shot adaptation to real-world imagery. This pipeline is an official implementation of the inference process. More details can be found on our [project page](https://marigoldmonodepth.github.io) and [full codebase](https://github.com/prs-eth/marigold) (also implemented with diffusers). + +![Marigold Teaser](https://marigoldmonodepth.github.io/images/teaser_collage_compressed.jpg) + +This depth estimation pipeline processes a single input image through multiple diffusion denoising stages to estimate depth maps. These maps are subsequently merged to produce the final output. Below is an example code snippet, including optional arguments: + +```python +import numpy as np +from PIL import Image +from diffusers import DiffusionPipeline +from diffusers.utils import load_image + +pipe = DiffusionPipeline.from_pretrained( + "Bingxin/Marigold", + custom_pipeline="marigold_depth_estimation" + # torch_dtype=torch.float16, # (optional) Run with half-precision (16-bit float). +) + +pipe.to("cuda") + +img_path_or_url = "https://share.phys.ethz.ch/~pf/bingkedata/marigold/pipeline_example.jpg" +image: Image.Image = load_image(img_path_or_url) + +pipeline_output = pipe( + image, # Input image. + # denoising_steps=10, # (optional) Number of denoising steps of each inference pass. Default: 10. + # ensemble_size=10, # (optional) Number of inference passes in the ensemble. Default: 10. + # processing_res=768, # (optional) Maximum resolution of processing. If set to 0: will not resize at all. Defaults to 768. + # match_input_res=True, # (optional) Resize depth prediction to match input resolution. + # batch_size=0, # (optional) Inference batch size, no bigger than `num_ensemble`. If set to 0, the script will automatically decide the proper batch size. Defaults to 0. + # color_map="Spectral", # (optional) Colormap used to colorize the depth map. Defaults to "Spectral". + # show_progress_bar=True, # (optional) If true, will show progress bars of the inference progress. 
+) + +depth: np.ndarray = pipeline_output.depth_np # Predicted depth map +depth_colored: Image.Image = pipeline_output.depth_colored # Colorized prediction + +# Save as uint16 PNG +depth_uint16 = (depth * 65535.0).astype(np.uint16) +Image.fromarray(depth_uint16).save("./depth_map.png", mode="I;16") + +# Save colorized depth map +depth_colored.save("./depth_colored.png") +``` + ### LLM-grounded Diffusion LMD and LMD+ greatly improves the prompt understanding ability of text-to-image generation models by introducing an LLM as a front-end prompt parser and layout planner. It improves spatial reasoning, the understanding of negation, attribute binding, generative numeracy, etc. in a unified manner without explicitly aiming for each. LMD is completely training-free (i.e., uses SD model off-the-shelf). LMD+ takes in additional adapters for better control. This is a reproduction of LMD+ model used in our work. [Project page.](https://llm-grounded-diffusion.github.io/) [See our full codebase (also with diffusers).](https://github.com/TonyLianLong/LLM-groundedDiffusion) diff --git a/examples/community/marigold_depth_estimation.py b/examples/community/marigold_depth_estimation.py new file mode 100644 index 000000000000..31da842112fb --- /dev/null +++ b/examples/community/marigold_depth_estimation.py @@ -0,0 +1,602 @@ +# Copyright 2023 Bingxin Ke, ETH Zurich and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- +# If you find this code useful, we kindly ask you to cite our paper in your work. +# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation +# More information about the method can be found at https://marigoldmonodepth.github.io +# -------------------------------------------------------------------------- + + +import math +from typing import Dict, Union + +import matplotlib +import numpy as np +import torch +from PIL import Image +from scipy.optimize import minimize +from torch.utils.data import DataLoader, TensorDataset +from tqdm.auto import tqdm +from transformers import CLIPTextModel, CLIPTokenizer + +from diffusers import ( + AutoencoderKL, + DDIMScheduler, + DiffusionPipeline, + UNet2DConditionModel, +) +from diffusers.utils import BaseOutput, check_min_version + + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.20.1.dev0") + + +class MarigoldDepthOutput(BaseOutput): + """ + Output class for Marigold monocular depth prediction pipeline. + + Args: + depth_np (`np.ndarray`): + Predicted depth map, with depth values in the range of [0, 1]. + depth_colored (`PIL.Image.Image`): + Colorized depth map, with the shape of [3, H, W] and values in [0, 1]. + uncertainty (`None` or `np.ndarray`): + Uncalibrated uncertainty(MAD, median absolute deviation) coming from ensembling. 
+ """ + + depth_np: np.ndarray + depth_colored: Image.Image + uncertainty: Union[None, np.ndarray] + + +class MarigoldPipeline(DiffusionPipeline): + """ + Pipeline for monocular depth estimation using Marigold: https://marigoldmonodepth.github.io. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + unet (`UNet2DConditionModel`): + Conditional U-Net to denoise the depth latent, conditioned on image latent. + vae (`AutoencoderKL`): + Variational Auto-Encoder (VAE) Model to encode and decode images and depth maps + to and from latent representations. + scheduler (`DDIMScheduler`): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. + text_encoder (`CLIPTextModel`): + Text-encoder, for empty text embedding. + tokenizer (`CLIPTokenizer`): + CLIP tokenizer. + """ + + rgb_latent_scale_factor = 0.18215 + depth_latent_scale_factor = 0.18215 + + def __init__( + self, + unet: UNet2DConditionModel, + vae: AutoencoderKL, + scheduler: DDIMScheduler, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + ): + super().__init__() + + self.register_modules( + unet=unet, + vae=vae, + scheduler=scheduler, + text_encoder=text_encoder, + tokenizer=tokenizer, + ) + + self.empty_text_embed = None + + @torch.no_grad() + def __call__( + self, + input_image: Image, + denoising_steps: int = 10, + ensemble_size: int = 10, + processing_res: int = 768, + match_input_res: bool = True, + batch_size: int = 0, + color_map: str = "Spectral", + show_progress_bar: bool = True, + ensemble_kwargs: Dict = None, + ) -> MarigoldDepthOutput: + """ + Function invoked when calling the pipeline. + + Args: + input_image (`Image`): + Input RGB (or gray-scale) image. + processing_res (`int`, *optional*, defaults to `768`): + Maximum resolution of processing. + If set to 0: will not resize at all. + match_input_res (`bool`, *optional*, defaults to `True`): + Resize depth prediction to match input resolution. + Only valid if `limit_input_res` is not None. + denoising_steps (`int`, *optional*, defaults to `10`): + Number of diffusion denoising steps (DDIM) during inference. + ensemble_size (`int`, *optional*, defaults to `10`): + Number of predictions to be ensembled. + batch_size (`int`, *optional*, defaults to `0`): + Inference batch size, no bigger than `num_ensemble`. + If set to 0, the script will automatically decide the proper batch size. + show_progress_bar (`bool`, *optional*, defaults to `True`): + Display a progress bar of diffusion denoising. + color_map (`str`, *optional*, defaults to `"Spectral"`): + Colormap used to colorize the depth map. + ensemble_kwargs (`dict`, *optional*, defaults to `None`): + Arguments for detailed ensembling settings. + Returns: + `MarigoldDepthOutput`: Output class for Marigold monocular depth prediction pipeline, including: + - **depth_np** (`np.ndarray`) Predicted depth map, with depth values in the range of [0, 1] + - **depth_colored** (`PIL.Image.Image`) Colorized depth map, with the shape of [3, H, W] and values in [0, 1] + - **uncertainty** (`None` or `np.ndarray`) Uncalibrated uncertainty(MAD, median absolute deviation) + coming from ensembling. 
None if `ensemble_size = 1` + """ + + device = self.device + input_size = input_image.size + + if not match_input_res: + assert processing_res is not None, "Value error: `resize_output_back` is only valid with " + assert processing_res >= 0 + assert denoising_steps >= 1 + assert ensemble_size >= 1 + + # ----------------- Image Preprocess ----------------- + # Resize image + if processing_res > 0: + input_image = self.resize_max_res(input_image, max_edge_resolution=processing_res) + # Convert the image to RGB, to 1.remove the alpha channel 2.convert B&W to 3-channel + input_image = input_image.convert("RGB") + image = np.asarray(input_image) + + # Normalize rgb values + rgb = np.transpose(image, (2, 0, 1)) # [H, W, rgb] -> [rgb, H, W] + rgb_norm = rgb / 255.0 + rgb_norm = torch.from_numpy(rgb_norm).to(self.dtype) + rgb_norm = rgb_norm.to(device) + assert rgb_norm.min() >= 0.0 and rgb_norm.max() <= 1.0 + + # ----------------- Predicting depth ----------------- + # Batch repeated input image + duplicated_rgb = torch.stack([rgb_norm] * ensemble_size) + single_rgb_dataset = TensorDataset(duplicated_rgb) + if batch_size > 0: + _bs = batch_size + else: + _bs = self._find_batch_size( + ensemble_size=ensemble_size, + input_res=max(rgb_norm.shape[1:]), + dtype=self.dtype, + ) + + single_rgb_loader = DataLoader(single_rgb_dataset, batch_size=_bs, shuffle=False) + + # Predict depth maps (batched) + depth_pred_ls = [] + if show_progress_bar: + iterable = tqdm(single_rgb_loader, desc=" " * 2 + "Inference batches", leave=False) + else: + iterable = single_rgb_loader + for batch in iterable: + (batched_img,) = batch + depth_pred_raw = self.single_infer( + rgb_in=batched_img, + num_inference_steps=denoising_steps, + show_pbar=show_progress_bar, + ) + depth_pred_ls.append(depth_pred_raw.detach().clone()) + depth_preds = torch.concat(depth_pred_ls, axis=0).squeeze() + torch.cuda.empty_cache() # clear vram cache for ensembling + + # ----------------- Test-time ensembling ----------------- + if ensemble_size > 1: + depth_pred, pred_uncert = self.ensemble_depths(depth_preds, **(ensemble_kwargs or {})) + else: + depth_pred = depth_preds + pred_uncert = None + + # ----------------- Post processing ----------------- + # Scale prediction to [0, 1] + min_d = torch.min(depth_pred) + max_d = torch.max(depth_pred) + depth_pred = (depth_pred - min_d) / (max_d - min_d) + + # Convert to numpy + depth_pred = depth_pred.cpu().numpy().astype(np.float32) + + # Resize back to original resolution + if match_input_res: + pred_img = Image.fromarray(depth_pred) + pred_img = pred_img.resize(input_size) + depth_pred = np.asarray(pred_img) + + # Clip output range + depth_pred = depth_pred.clip(0, 1) + + # Colorize + depth_colored = self.colorize_depth_maps( + depth_pred, 0, 1, cmap=color_map + ).squeeze() # [3, H, W], value in (0, 1) + depth_colored = (depth_colored * 255).astype(np.uint8) + depth_colored_hwc = self.chw2hwc(depth_colored) + depth_colored_img = Image.fromarray(depth_colored_hwc) + return MarigoldDepthOutput( + depth_np=depth_pred, + depth_colored=depth_colored_img, + uncertainty=pred_uncert, + ) + + def _encode_empty_text(self): + """ + Encode text embedding for empty prompt. 
+ """ + prompt = "" + text_inputs = self.tokenizer( + prompt, + padding="do_not_pad", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids.to(self.text_encoder.device) + self.empty_text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype) + + @torch.no_grad() + def single_infer(self, rgb_in: torch.Tensor, num_inference_steps: int, show_pbar: bool) -> torch.Tensor: + """ + Perform an individual depth prediction without ensembling. + + Args: + rgb_in (`torch.Tensor`): + Input RGB image. + num_inference_steps (`int`): + Number of diffusion denoisign steps (DDIM) during inference. + show_pbar (`bool`): + Display a progress bar of diffusion denoising. + Returns: + `torch.Tensor`: Predicted depth map. + """ + device = rgb_in.device + + # Set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps # [T] + + # Encode image + rgb_latent = self._encode_rgb(rgb_in) + + # Initial depth map (noise) + depth_latent = torch.randn(rgb_latent.shape, device=device, dtype=self.dtype) # [B, 4, h, w] + + # Batched empty text embedding + if self.empty_text_embed is None: + self._encode_empty_text() + batch_empty_text_embed = self.empty_text_embed.repeat((rgb_latent.shape[0], 1, 1)) # [B, 2, 1024] + + # Denoising loop + if show_pbar: + iterable = tqdm( + enumerate(timesteps), + total=len(timesteps), + leave=False, + desc=" " * 4 + "Diffusion denoising", + ) + else: + iterable = enumerate(timesteps) + + for i, t in iterable: + unet_input = torch.cat([rgb_latent, depth_latent], dim=1) # this order is important + + # predict the noise residual + noise_pred = self.unet(unet_input, t, encoder_hidden_states=batch_empty_text_embed).sample # [B, 4, h, w] + + # compute the previous noisy sample x_t -> x_t-1 + depth_latent = self.scheduler.step(noise_pred, t, depth_latent).prev_sample + torch.cuda.empty_cache() + depth = self._decode_depth(depth_latent) + + # clip prediction + depth = torch.clip(depth, -1.0, 1.0) + # shift to [0, 1] + depth = (depth + 1.0) / 2.0 + + return depth + + def _encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor: + """ + Encode RGB image into latent. + + Args: + rgb_in (`torch.Tensor`): + Input RGB image to be encoded. + + Returns: + `torch.Tensor`: Image latent. + """ + # encode + h = self.vae.encoder(rgb_in) + moments = self.vae.quant_conv(h) + mean, logvar = torch.chunk(moments, 2, dim=1) + # scale latent + rgb_latent = mean * self.rgb_latent_scale_factor + return rgb_latent + + def _decode_depth(self, depth_latent: torch.Tensor) -> torch.Tensor: + """ + Decode depth latent into depth map. + + Args: + depth_latent (`torch.Tensor`): + Depth latent to be decoded. + + Returns: + `torch.Tensor`: Decoded depth map. + """ + # scale latent + depth_latent = depth_latent / self.depth_latent_scale_factor + # decode + z = self.vae.post_quant_conv(depth_latent) + stacked = self.vae.decoder(z) + # mean of output channels + depth_mean = stacked.mean(dim=1, keepdim=True) + return depth_mean + + @staticmethod + def resize_max_res(img: Image.Image, max_edge_resolution: int) -> Image.Image: + """ + Resize image to limit maximum edge length while keeping aspect ratio. + + Args: + img (`Image.Image`): + Image to be resized. + max_edge_resolution (`int`): + Maximum edge length (pixel). + + Returns: + `Image.Image`: Resized image. 
+ """ + original_width, original_height = img.size + downscale_factor = min(max_edge_resolution / original_width, max_edge_resolution / original_height) + + new_width = int(original_width * downscale_factor) + new_height = int(original_height * downscale_factor) + + resized_img = img.resize((new_width, new_height)) + return resized_img + + @staticmethod + def colorize_depth_maps(depth_map, min_depth, max_depth, cmap="Spectral", valid_mask=None): + """ + Colorize depth maps. + """ + assert len(depth_map.shape) >= 2, "Invalid dimension" + + if isinstance(depth_map, torch.Tensor): + depth = depth_map.detach().clone().squeeze().numpy() + elif isinstance(depth_map, np.ndarray): + depth = depth_map.copy().squeeze() + # reshape to [ (B,) H, W ] + if depth.ndim < 3: + depth = depth[np.newaxis, :, :] + + # colorize + cm = matplotlib.colormaps[cmap] + depth = ((depth - min_depth) / (max_depth - min_depth)).clip(0, 1) + img_colored_np = cm(depth, bytes=False)[:, :, :, 0:3] # value from 0 to 1 + img_colored_np = np.rollaxis(img_colored_np, 3, 1) + + if valid_mask is not None: + if isinstance(depth_map, torch.Tensor): + valid_mask = valid_mask.detach().numpy() + valid_mask = valid_mask.squeeze() # [H, W] or [B, H, W] + if valid_mask.ndim < 3: + valid_mask = valid_mask[np.newaxis, np.newaxis, :, :] + else: + valid_mask = valid_mask[:, np.newaxis, :, :] + valid_mask = np.repeat(valid_mask, 3, axis=1) + img_colored_np[~valid_mask] = 0 + + if isinstance(depth_map, torch.Tensor): + img_colored = torch.from_numpy(img_colored_np).float() + elif isinstance(depth_map, np.ndarray): + img_colored = img_colored_np + + return img_colored + + @staticmethod + def chw2hwc(chw): + assert 3 == len(chw.shape) + if isinstance(chw, torch.Tensor): + hwc = torch.permute(chw, (1, 2, 0)) + elif isinstance(chw, np.ndarray): + hwc = np.moveaxis(chw, 0, -1) + return hwc + + @staticmethod + def _find_batch_size(ensemble_size: int, input_res: int, dtype: torch.dtype) -> int: + """ + Automatically search for suitable operating batch size. + + Args: + ensemble_size (`int`): + Number of predictions to be ensembled. + input_res (`int`): + Operating resolution of the input image. + + Returns: + `int`: Operating batch size. + """ + # Search table for suggested max. 
inference batch size + bs_search_table = [ + # tested on A100-PCIE-80GB + {"res": 768, "total_vram": 79, "bs": 35, "dtype": torch.float32}, + {"res": 1024, "total_vram": 79, "bs": 20, "dtype": torch.float32}, + # tested on A100-PCIE-40GB + {"res": 768, "total_vram": 39, "bs": 15, "dtype": torch.float32}, + {"res": 1024, "total_vram": 39, "bs": 8, "dtype": torch.float32}, + {"res": 768, "total_vram": 39, "bs": 30, "dtype": torch.float16}, + {"res": 1024, "total_vram": 39, "bs": 15, "dtype": torch.float16}, + # tested on RTX3090, RTX4090 + {"res": 512, "total_vram": 23, "bs": 20, "dtype": torch.float32}, + {"res": 768, "total_vram": 23, "bs": 7, "dtype": torch.float32}, + {"res": 1024, "total_vram": 23, "bs": 3, "dtype": torch.float32}, + {"res": 512, "total_vram": 23, "bs": 40, "dtype": torch.float16}, + {"res": 768, "total_vram": 23, "bs": 18, "dtype": torch.float16}, + {"res": 1024, "total_vram": 23, "bs": 10, "dtype": torch.float16}, + # tested on GTX1080Ti + {"res": 512, "total_vram": 10, "bs": 5, "dtype": torch.float32}, + {"res": 768, "total_vram": 10, "bs": 2, "dtype": torch.float32}, + {"res": 512, "total_vram": 10, "bs": 10, "dtype": torch.float16}, + {"res": 768, "total_vram": 10, "bs": 5, "dtype": torch.float16}, + {"res": 1024, "total_vram": 10, "bs": 3, "dtype": torch.float16}, + ] + + if not torch.cuda.is_available(): + return 1 + + total_vram = torch.cuda.mem_get_info()[1] / 1024.0**3 + filtered_bs_search_table = [s for s in bs_search_table if s["dtype"] == dtype] + for settings in sorted( + filtered_bs_search_table, + key=lambda k: (k["res"], -k["total_vram"]), + ): + if input_res <= settings["res"] and total_vram >= settings["total_vram"]: + bs = settings["bs"] + if bs > ensemble_size: + bs = ensemble_size + elif bs > math.ceil(ensemble_size / 2) and bs < ensemble_size: + bs = math.ceil(ensemble_size / 2) + return bs + + return 1 + + @staticmethod + def ensemble_depths( + input_images: torch.Tensor, + regularizer_strength: float = 0.02, + max_iter: int = 2, + tol: float = 1e-3, + reduction: str = "median", + max_res: int = None, + ): + """ + To ensemble multiple affine-invariant depth images (up to scale and shift), + by aligning estimating the scale and shift + """ + + def inter_distances(tensors: torch.Tensor): + """ + To calculate the distance between each two depth maps. 
+ """ + distances = [] + for i, j in torch.combinations(torch.arange(tensors.shape[0])): + arr1 = tensors[i : i + 1] + arr2 = tensors[j : j + 1] + distances.append(arr1 - arr2) + dist = torch.concatenate(distances, dim=0) + return dist + + device = input_images.device + dtype = input_images.dtype + np_dtype = np.float32 + + original_input = input_images.clone() + n_img = input_images.shape[0] + ori_shape = input_images.shape + + if max_res is not None: + scale_factor = torch.min(max_res / torch.tensor(ori_shape[-2:])) + if scale_factor < 1: + downscaler = torch.nn.Upsample(scale_factor=scale_factor, mode="nearest") + input_images = downscaler(torch.from_numpy(input_images)).numpy() + + # init guess + _min = np.min(input_images.reshape((n_img, -1)).cpu().numpy(), axis=1) + _max = np.max(input_images.reshape((n_img, -1)).cpu().numpy(), axis=1) + s_init = 1.0 / (_max - _min).reshape((-1, 1, 1)) + t_init = (-1 * s_init.flatten() * _min.flatten()).reshape((-1, 1, 1)) + x = np.concatenate([s_init, t_init]).reshape(-1).astype(np_dtype) + + input_images = input_images.to(device) + + # objective function + def closure(x): + l = len(x) + s = x[: int(l / 2)] + t = x[int(l / 2) :] + s = torch.from_numpy(s).to(dtype=dtype).to(device) + t = torch.from_numpy(t).to(dtype=dtype).to(device) + + transformed_arrays = input_images * s.view((-1, 1, 1)) + t.view((-1, 1, 1)) + dists = inter_distances(transformed_arrays) + sqrt_dist = torch.sqrt(torch.mean(dists**2)) + + if "mean" == reduction: + pred = torch.mean(transformed_arrays, dim=0) + elif "median" == reduction: + pred = torch.median(transformed_arrays, dim=0).values + else: + raise ValueError + + near_err = torch.sqrt((0 - torch.min(pred)) ** 2) + far_err = torch.sqrt((1 - torch.max(pred)) ** 2) + + err = sqrt_dist + (near_err + far_err) * regularizer_strength + err = err.detach().cpu().numpy().astype(np_dtype) + return err + + res = minimize( + closure, + x, + method="BFGS", + tol=tol, + options={"maxiter": max_iter, "disp": False}, + ) + x = res.x + l = len(x) + s = x[: int(l / 2)] + t = x[int(l / 2) :] + + # Prediction + s = torch.from_numpy(s).to(dtype=dtype).to(device) + t = torch.from_numpy(t).to(dtype=dtype).to(device) + transformed_arrays = original_input * s.view(-1, 1, 1) + t.view(-1, 1, 1) + if "mean" == reduction: + aligned_images = torch.mean(transformed_arrays, dim=0) + std = torch.std(transformed_arrays, dim=0) + uncertainty = std + elif "median" == reduction: + aligned_images = torch.median(transformed_arrays, dim=0).values + # MAD (median absolute deviation) as uncertainty indicator + abs_dev = torch.abs(transformed_arrays - aligned_images) + mad = torch.median(abs_dev, dim=0).values + uncertainty = mad + else: + raise ValueError(f"Unknown reduction method: {reduction}") + + # Scale and shift to [0, 1] + _min = torch.min(aligned_images) + _max = torch.max(aligned_images) + aligned_images = (aligned_images - _min) / (_max - _min) + uncertainty /= _max - _min + + return aligned_images, uncertainty From df76a39e1bc1de5bec647ce56a7fe4d8d1b6a643 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?apolin=C3=A1rio?= Date: Fri, 22 Dec 2023 06:42:04 -0600 Subject: [PATCH 5/6] Fix Prodigy optimizer in SDXL Dreambooth script (#6290) * Fix ProdigyOPT in SDXL Dreambooth script * style * style --- .../dreambooth/train_dreambooth_lora_sdxl.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 9992292e30aa..8a3ac294fef2 100644 --- 
a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -1144,10 +1144,26 @@ def load_model_hook(models, input_dir): optimizer_class = prodigyopt.Prodigy + if args.learning_rate <= 0.1: + logger.warn( + "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0" + ) + if args.train_text_encoder and args.text_encoder_lr: + logger.warn( + f"Learning rates were provided both for the unet and the text encoder- e.g. text_encoder_lr:" + f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. " + f"When using prodigy only learning_rate is used as the initial learning rate." + ) + # changes the learning rate of text_encoder_parameters_one and text_encoder_parameters_two to be + # --learning_rate + params_to_optimize[1]["lr"] = args.learning_rate + params_to_optimize[2]["lr"] = args.learning_rate + optimizer = optimizer_class( params_to_optimize, lr=args.learning_rate, betas=(args.adam_beta1, args.adam_beta2), + beta3=args.prodigy_beta3, weight_decay=args.adam_weight_decay, eps=args.adam_epsilon, decouple=args.prodigy_decouple, From 90b9479903dcf3b053dc2461d4d6266eed0c27ea Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Sun, 24 Dec 2023 09:59:41 +0530 Subject: [PATCH 6/6] [LoRA PEFT] fix LoRA loading so that correct alphas are parsed (#6225) * initialize alpha too. * add: test * remove config parsing * store rank * debug * remove faulty test --- examples/dreambooth/train_dreambooth_lora.py | 6 +++++- examples/dreambooth/train_dreambooth_lora_sdxl.py | 10 ++++++++-- examples/text_to_image/train_text_to_image_lora.py | 5 ++++- .../text_to_image/train_text_to_image_lora_sdxl.py | 10 ++++++++-- tests/lora/test_lora_layers_peft.py | 8 ++++++-- 5 files changed, 31 insertions(+), 8 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index 55ef2bbeb8eb..67132d6d88df 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -827,6 +827,7 @@ def main(args): # now we will add new LoRA weights to the attention layers unet_lora_config = LoraConfig( r=args.rank, + lora_alpha=args.rank, init_lora_weights="gaussian", target_modules=["to_k", "to_q", "to_v", "to_out.0", "add_k_proj", "add_v_proj"], ) @@ -835,7 +836,10 @@ def main(args): # The text encoder comes from 🤗 transformers, we will also attach adapters to it. 
if args.train_text_encoder: text_lora_config = LoraConfig( - r=args.rank, init_lora_weights="gaussian", target_modules=["q_proj", "k_proj", "v_proj", "out_proj"] + r=args.rank, + lora_alpha=args.rank, + init_lora_weights="gaussian", + target_modules=["q_proj", "k_proj", "v_proj", "out_proj"], ) text_encoder.add_adapter(text_lora_config) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 8a3ac294fef2..0f41ad47d1ac 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -978,7 +978,10 @@ def main(args): # now we will add new LoRA weights to the attention layers unet_lora_config = LoraConfig( - r=args.rank, init_lora_weights="gaussian", target_modules=["to_k", "to_q", "to_v", "to_out.0"] + r=args.rank, + lora_alpha=args.rank, + init_lora_weights="gaussian", + target_modules=["to_k", "to_q", "to_v", "to_out.0"], ) unet.add_adapter(unet_lora_config) @@ -986,7 +989,10 @@ def main(args): # So, instead, we monkey-patch the forward calls of its attention-blocks. if args.train_text_encoder: text_lora_config = LoraConfig( - r=args.rank, init_lora_weights="gaussian", target_modules=["q_proj", "k_proj", "v_proj", "out_proj"] + r=args.rank, + lora_alpha=args.rank, + init_lora_weights="gaussian", + target_modules=["q_proj", "k_proj", "v_proj", "out_proj"], ) text_encoder_one.add_adapter(text_lora_config) text_encoder_two.add_adapter(text_lora_config) diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index c8efbddd0b44..d6d0dee0883c 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -452,7 +452,10 @@ def main(): param.requires_grad_(False) unet_lora_config = LoraConfig( - r=args.rank, init_lora_weights="gaussian", target_modules=["to_k", "to_q", "to_v", "to_out.0"] + r=args.rank, + lora_alpha=args.rank, + init_lora_weights="gaussian", + target_modules=["to_k", "to_q", "to_v", "to_out.0"], ) # Move unet, vae and text_encoder to device and cast to weight_dtype diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py index be17c13c2885..d95fcbbba033 100644 --- a/examples/text_to_image/train_text_to_image_lora_sdxl.py +++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py @@ -609,7 +609,10 @@ def main(args): # now we will add new LoRA weights to the attention layers # Set correct lora layers unet_lora_config = LoraConfig( - r=args.rank, init_lora_weights="gaussian", target_modules=["to_k", "to_q", "to_v", "to_out.0"] + r=args.rank, + lora_alpha=args.rank, + init_lora_weights="gaussian", + target_modules=["to_k", "to_q", "to_v", "to_out.0"], ) unet.add_adapter(unet_lora_config) @@ -618,7 +621,10 @@ def main(args): if args.train_text_encoder: # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16 text_lora_config = LoraConfig( - r=args.rank, init_lora_weights="gaussian", target_modules=["q_proj", "k_proj", "v_proj", "out_proj"] + r=args.rank, + lora_alpha=args.rank, + init_lora_weights="gaussian", + target_modules=["q_proj", "k_proj", "v_proj", "out_proj"], ) text_encoder_one.add_adapter(text_lora_config) text_encoder_two.add_adapter(text_lora_config) diff --git a/tests/lora/test_lora_layers_peft.py b/tests/lora/test_lora_layers_peft.py index f6cd2a714ae2..30125f64f6ac 100644 --- a/tests/lora/test_lora_layers_peft.py +++ 
b/tests/lora/test_lora_layers_peft.py @@ -111,6 +111,7 @@ class PeftLoraLoaderMixinTests: def get_dummy_components(self, scheduler_cls=None): scheduler_cls = self.scheduler_cls if scheduler_cls is None else LCMScheduler + rank = 4 torch.manual_seed(0) unet = UNet2DConditionModel(**self.unet_kwargs) @@ -125,11 +126,14 @@ def get_dummy_components(self, scheduler_cls=None): tokenizer_2 = CLIPTokenizer.from_pretrained("peft-internal-testing/tiny-clip-text-2") text_lora_config = LoraConfig( - r=4, lora_alpha=4, target_modules=["q_proj", "k_proj", "v_proj", "out_proj"], init_lora_weights=False + r=rank, + lora_alpha=rank, + target_modules=["q_proj", "k_proj", "v_proj", "out_proj"], + init_lora_weights=False, ) unet_lora_config = LoraConfig( - r=4, lora_alpha=4, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False + r=rank, lora_alpha=rank, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False ) unet_lora_attn_procs, unet_lora_layers = create_unet_lora_layers(unet)
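Editor's note on patch 6: it passes `lora_alpha` explicitly (pinned to the rank) everywhere a `LoraConfig` is created, in the training scripts as well as in the PEFT LoRA tests. PEFT scales each LoRA update by `lora_alpha / r`, so leaving the alpha at its default while varying `--rank` silently changes the effective adapter strength and the alphas later parsed on load. Below is a minimal standalone sketch of the configuration pattern, mirroring the arguments in the diff rather than quoting the scripts verbatim.

```python
from peft import LoraConfig

rank = 4  # the training scripts take this from --rank (args.rank)

unet_lora_config = LoraConfig(
    r=rank,
    lora_alpha=rank,  # keeps the LoRA scaling factor lora_alpha / r at exactly 1.0
    init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)

# Pinning alpha to the rank keeps the training-time scaling and the alphas stored
# alongside the weights consistent, whatever rank the user picks.
print(unet_lora_config.lora_alpha / unet_lora_config.r)  # -> 1.0
```

The text-encoder configs in the same patch follow the identical pattern for the `q_proj`, `k_proj`, `v_proj`, and `out_proj` projections.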