From c46711e895bcec849fdcce69a7bc6864f023b6d1 Mon Sep 17 00:00:00 2001 From: Monohydroxides <75928535+Monohydroxides@users.noreply.github.com> Date: Thu, 14 Dec 2023 23:17:20 +0800 Subject: [PATCH] [Community] Add SDE Drag pipeline (#6105) * Add community pipeline: sde_drag.py * Update README.md * Update README.md Update example code and visual example * Update sde_drag.py Update code example. --- examples/community/README.md | 40 +++ examples/community/sde_drag.py | 594 +++++++++++++++++++++++++++++++++ 2 files changed, 634 insertions(+) create mode 100644 examples/community/sde_drag.py diff --git a/examples/community/README.md b/examples/community/README.md index e121e68bc9ce..c8865adf78f7 100755 --- a/examples/community/README.md +++ b/examples/community/README.md @@ -48,6 +48,7 @@ prompt-to-prompt | change parts of a prompt and retain image structure (see [pap | Latent Consistency Pipeline | Implementation of [Latent Consistency Models: Synthesizing High-Resolution Images with Few-Step Inference](https://arxiv.org/abs/2310.04378) | [Latent Consistency Pipeline](#latent-consistency-pipeline) | - | [Simian Luo](https://github.com/luosiallen) | | Latent Consistency Img2img Pipeline | Img2img pipeline for Latent Consistency Models | [Latent Consistency Img2Img Pipeline](#latent-consistency-img2img-pipeline) | - | [Logan Zoellner](https://github.com/nagolinc) | | Latent Consistency Interpolation Pipeline | Interpolate the latent space of Latent Consistency Models with multiple prompts | [Latent Consistency Interpolation Pipeline](#latent-consistency-interpolation-pipeline) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1pK3NrLWJSiJsBynLns1K1-IDTW9zbPvl?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) | +| SDE Drag Pipeline | The pipeline supports drag editing of images using stochastic differential equations | [SDE Drag Pipeline](#sde-drag-pipeline) | - | [NieShen](https://github.com/NieShenRuc) [Fengqi Zhu](https://github.com/Monohydroxides) | | Regional Prompting Pipeline | Assign multiple prompts for different regions | [Regional Prompting Pipeline](#regional-prompting-pipeline) | - | [hako-mikan](https://github.com/hako-mikan) | | LDM3D-sr (LDM3D upscaler) | Upscale low resolution RGB and depth inputs to high resolution | [StableDiffusionUpscaleLDM3D Pipeline](https://github.com/estelleafl/diffusers/tree/ldm3d_upscaler_community/examples/community#stablediffusionupscaleldm3d-pipeline) | - | [Estelle Aflalo](https://github.com/estelleafl) | | AnimateDiff ControlNet Pipeline | Combines AnimateDiff with precise motion control using ControlNets | [AnimateDiff ControlNet Pipeline](#animatediff-controlnet-pipeline) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1SKboYeGjEQmQPWoFC0aLYpBlYdHXkvAu?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) and [Edoardo Botta](https://github.com/EdoardoBotta) | @@ -2986,3 +2987,42 @@ def image_grid(imgs, save_path=None): image_grid(images, save_path="./outputs/") ``` ![output_example](https://github.com/PRIS-CV/DemoFusion/blob/main/output_example.png) + +### SDE Drag pipeline + +This pipeline provides drag-and-drop image editing using stochastic differential equations. It enables image editing by inputting prompt, image, mask_image, source_points, and target_points. 
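+
+The mask is a grayscale image in which white pixels mark the region that is allowed to change and black pixels are preserved, and each drag point is given as `[x, y]` pixel coordinates on the input image (the pipeline divides the points by 8 internally to map them into latent space). A minimal sketch of one way to prepare these inputs (the `make_drag_mask` helper below is illustrative and not part of the pipeline; the full usage example follows further below):
+
+```py
+import numpy as np
+import PIL.Image
+
+
+def make_drag_mask(width, height, left, top, right, bottom):
+    # White (255) marks the editable region, black (0) is preserved;
+    # the pipeline thresholds the (resized) mask, so a binary mask is enough.
+    mask = np.zeros((height, width), dtype=np.uint8)
+    mask[top:bottom, left:right] = 255
+    return PIL.Image.fromarray(mask)
+
+
+# e.g. allow edits only in the center of a 512x512 image
+mask_image = make_drag_mask(512, 512, left=128, top=128, right=384, bottom=384)
+
+# drag a point 80 pixels to the right; points are [x, y] pixel coordinates
+source_points = [[200, 256]]
+target_points = [[280, 256]]
+```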
+
+![SDE Drag Image](https://github.com/huggingface/diffusers/assets/75928535/bd54f52f-f002-4951-9934-b2a4592771a5)
+
+See the [paper](https://arxiv.org/abs/2311.01410), [paper page](https://ml-gsai.github.io/SDE-Drag-demo/), and [original repo](https://github.com/ML-GSAI/SDE-Drag) for more information.
+
+```py
+import PIL
+import torch
+from diffusers import DDIMScheduler, DiffusionPipeline
+
+# Load the pipeline
+model_path = "runwayml/stable-diffusion-v1-5"
+scheduler = DDIMScheduler.from_pretrained(model_path, subfolder="scheduler")
+pipe = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler, custom_pipeline="sde_drag")
+pipe.to('cuda')
+
+# To save GPU memory, torch.float16 can be used, but it may compromise image quality.
+# If not training LoRA, please avoid using torch.float16
+# pipe.to(torch.float16)
+
+# Provide prompt, image, mask image, and the starting and target points for drag editing.
+prompt = "prompt of the image"
+image = PIL.Image.open('/path/to/image')
+mask_image = PIL.Image.open('/path/to/mask_image')
+source_points = [[123, 456]]
+target_points = [[234, 567]]
+
+# train_lora is optional, and in most cases, using train_lora can better preserve consistency with the original image.
+pipe.train_lora(prompt, image)
+
+output = pipe(prompt, image, mask_image, source_points, target_points)
+output_image = PIL.Image.fromarray(output)
+output_image.save("./output.png")
+
+```
diff --git a/examples/community/sde_drag.py b/examples/community/sde_drag.py
new file mode 100644
index 000000000000..08e865b9c350
--- /dev/null
+++ b/examples/community/sde_drag.py
@@ -0,0 +1,594 @@
+import math
+import tempfile
+from typing import List, Optional
+
+import numpy as np
+import PIL.Image
+import torch
+from accelerate import Accelerator
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import CLIPTextModel, CLIPTokenizer
+
+from diffusers import AutoencoderKL, DiffusionPipeline, DPMSolverMultistepScheduler, UNet2DConditionModel
+from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin
+from diffusers.models.attention_processor import (
+    AttnAddedKVProcessor,
+    AttnAddedKVProcessor2_0,
+    LoRAAttnAddedKVProcessor,
+    LoRAAttnProcessor,
+    LoRAAttnProcessor2_0,
+    SlicedAttnAddedKVProcessor,
+)
+from diffusers.optimization import get_scheduler
+
+
+class SdeDragPipeline(DiffusionPipeline):
+    r"""
+    Pipeline for image drag-and-drop editing using stochastic differential equations: https://arxiv.org/abs/2311.01410.
+    Please refer to the [official repository](https://github.com/ML-GSAI/SDE-Drag) for more information.
+
+    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+    Args:
+        vae ([`AutoencoderKL`]):
+            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+        text_encoder ([`CLIPTextModel`]):
+            Frozen text-encoder. Stable Diffusion uses the text portion of
+            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+        tokenizer (`CLIPTokenizer`):
+            Tokenizer of class
+            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+        scheduler ([`SchedulerMixin`]):
+            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Please use
+            [`DDIMScheduler`].
+    """
+
+    def __init__(
+        self,
+        vae: AutoencoderKL,
+        text_encoder: CLIPTextModel,
+        tokenizer: CLIPTokenizer,
+        unet: UNet2DConditionModel,
+        scheduler: DPMSolverMultistepScheduler,
+    ):
+        super().__init__()
+
+        self.register_modules(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
+
+    @torch.no_grad()
+    def __call__(
+        self,
+        prompt: str,
+        image: PIL.Image.Image,
+        mask_image: PIL.Image.Image,
+        source_points: List[List[int]],
+        target_points: List[List[int]],
+        t0: Optional[float] = 0.6,
+        steps: Optional[int] = 200,
+        step_size: Optional[int] = 2,
+        image_scale: Optional[float] = 0.3,
+        adapt_radius: Optional[int] = 5,
+        min_lora_scale: Optional[float] = 0.5,
+        generator: Optional[torch.Generator] = None,
+    ):
+        r"""
+        Function invoked when calling the pipeline for image editing.
+        Args:
+            prompt (`str`, *required*):
+                The prompt to guide the image editing.
+            image (`PIL.Image.Image`, *required*):
+                The image to be edited; parts of the image will be masked out with `mask_image` and edited
+                according to `prompt`.
+            mask_image (`PIL.Image.Image`, *required*):
+                To mask `image`. White pixels in the mask will be edited, while black pixels will be preserved.
+            source_points (`List[List[int]]`, *required*):
+                Used to mark the starting positions of drag editing in the image, with each pixel represented as a
+                `List[int]` of length 2.
+            target_points (`List[List[int]]`, *required*):
+                Used to mark the target positions of drag editing in the image, with each pixel represented as a
+                `List[int]` of length 2.
+            t0 (`float`, *optional*, defaults to 0.6):
+                The time parameter. A higher t0 improves the fidelity while lowering the faithfulness of the edited
+                images, and vice versa.
+            steps (`int`, *optional*, defaults to 200):
+                The number of sampling iterations.
+            step_size (`int`, *optional*, defaults to 2):
+                The drag distance of each drag step.
+            image_scale (`float`, *optional*, defaults to 0.3):
+                To avoid duplicating the content, use image_scale to perturb the source.
+            adapt_radius (`int`, *optional*, defaults to 5):
+                The size of the region for copy and paste operations during each step of the drag process.
+            min_lora_scale (`float`, *optional*, defaults to 0.5):
+                A LoRA scale that will be applied to all LoRA layers of the UNet if LoRA layers are loaded.
+                min_lora_scale specifies the minimum LoRA scale during the image drag-editing process.
+            generator (`torch.Generator`, *optional*, defaults to None):
+                To make generation deterministic (https://pytorch.org/docs/stable/generated/torch.Generator.html).
+        Examples:
+        ```py
+        >>> import PIL
+        >>> import torch
+        >>> from diffusers import DDIMScheduler, DiffusionPipeline
+
+        >>> # Load the pipeline
+        >>> model_path = "runwayml/stable-diffusion-v1-5"
+        >>> scheduler = DDIMScheduler.from_pretrained(model_path, subfolder="scheduler")
+        >>> pipe = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler, custom_pipeline="sde_drag")
+        >>> pipe.to('cuda')
+
+        >>> # To save GPU memory, torch.float16 can be used, but it may compromise image quality.
+        >>> # If not training LoRA, please avoid using torch.float16
+        >>> # pipe.to(torch.float16)
+
+        >>> # Provide prompt, image, mask image, and the starting and target points for drag editing.
+ >>> prompt = "prompt of the image" + >>> image = PIL.Image.open('/path/to/image') + >>> mask_image = PIL.Image.open('/path/to/mask_image') + >>> source_points = [[123, 456]] + >>> target_points = [[234, 567]] + + >>> # train_lora is optional, and in most cases, using train_lora can better preserve consistency with the original image. + >>> pipe.train_lora(prompt, image) + + >>> output = pipe(prompt, image, mask_image, source_points, target_points) + >>> output_image = PIL.Image.fromarray(output) + >>> output_image.save("./output.png") + ``` + """ + + self.scheduler.set_timesteps(steps) + + noise_scale = (1 - image_scale**2) ** (0.5) + + text_embeddings = self._get_text_embed(prompt) + uncond_embeddings = self._get_text_embed([""]) + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + latent = self._get_img_latent(image) + + mask = mask_image.resize((latent.shape[3], latent.shape[2])) + mask = torch.tensor(np.array(mask)) + mask = mask.unsqueeze(0).expand_as(latent).to(self.device) + + source_points = torch.tensor(source_points).div(torch.tensor([8]), rounding_mode="trunc") + target_points = torch.tensor(target_points).div(torch.tensor([8]), rounding_mode="trunc") + + distance = target_points - source_points + distance_norm_max = torch.norm(distance.float(), dim=1, keepdim=True).max() + + if distance_norm_max <= step_size: + drag_num = 1 + else: + drag_num = distance_norm_max.div(torch.tensor([step_size]), rounding_mode="trunc") + if (distance_norm_max / drag_num - step_size).abs() > ( + distance_norm_max / (drag_num + 1) - step_size + ).abs(): + drag_num += 1 + + latents = [] + for i in tqdm(range(int(drag_num)), desc="SDE Drag"): + source_new = source_points + (i / drag_num * distance).to(torch.int) + target_new = source_points + ((i + 1) / drag_num * distance).to(torch.int) + + latent, noises, hook_latents, lora_scales, cfg_scales = self._forward( + latent, steps, t0, min_lora_scale, text_embeddings, generator + ) + latent = self._copy_and_paste( + latent, + source_new, + target_new, + adapt_radius, + latent.shape[2] - 1, + latent.shape[3] - 1, + image_scale, + noise_scale, + generator, + ) + latent = self._backward( + latent, mask, steps, t0, noises, hook_latents, lora_scales, cfg_scales, text_embeddings, generator + ) + + latents.append(latent) + + result_image = 1 / 0.18215 * latents[-1] + + with torch.no_grad(): + result_image = self.vae.decode(result_image).sample + + result_image = (result_image / 2 + 0.5).clamp(0, 1) + result_image = result_image.cpu().permute(0, 2, 3, 1).numpy()[0] + result_image = (result_image * 255).astype(np.uint8) + + return result_image + + def train_lora(self, prompt, image, lora_step=100, lora_rank=16, generator=None): + accelerator = Accelerator(gradient_accumulation_steps=1, mixed_precision="fp16") + + self.vae.requires_grad_(False) + self.text_encoder.requires_grad_(False) + self.unet.requires_grad_(False) + + unet_lora_attn_procs = {} + for name, attn_processor in self.unet.attn_processors.items(): + cross_attention_dim = None if name.endswith("attn1.processor") else self.unet.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = self.unet.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(self.unet.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = self.unet.config.block_out_channels[block_id] + else: + raise NotImplementedError("name 
must start with up_blocks, mid_blocks, or down_blocks") + + if isinstance(attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)): + lora_attn_processor_class = LoRAAttnAddedKVProcessor + else: + lora_attn_processor_class = ( + LoRAAttnProcessor2_0 + if hasattr(torch.nn.functional, "scaled_dot_product_attention") + else LoRAAttnProcessor + ) + unet_lora_attn_procs[name] = lora_attn_processor_class( + hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=lora_rank + ) + + self.unet.set_attn_processor(unet_lora_attn_procs) + unet_lora_layers = AttnProcsLayers(self.unet.attn_processors) + params_to_optimize = unet_lora_layers.parameters() + + optimizer = torch.optim.AdamW( + params_to_optimize, + lr=2e-4, + betas=(0.9, 0.999), + weight_decay=1e-2, + eps=1e-08, + ) + + lr_scheduler = get_scheduler( + "constant", + optimizer=optimizer, + num_warmup_steps=0, + num_training_steps=lora_step, + num_cycles=1, + power=1.0, + ) + + unet_lora_layers = accelerator.prepare_model(unet_lora_layers) + optimizer = accelerator.prepare_optimizer(optimizer) + lr_scheduler = accelerator.prepare_scheduler(lr_scheduler) + + with torch.no_grad(): + text_inputs = self._tokenize_prompt(prompt, tokenizer_max_length=None) + text_embedding = self._encode_prompt( + text_inputs.input_ids, text_inputs.attention_mask, text_encoder_use_attention_mask=False + ) + + image_transforms = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + image = image_transforms(image).to(self.device, dtype=self.vae.dtype) + image = image.unsqueeze(dim=0) + latents_dist = self.vae.encode(image).latent_dist + + for _ in tqdm(range(lora_step), desc="Train LoRA"): + self.unet.train() + model_input = latents_dist.sample() * self.vae.config.scaling_factor + + # Sample noise that we'll add to the latents + noise = torch.randn( + model_input.size(), + dtype=model_input.dtype, + layout=model_input.layout, + device=model_input.device, + generator=generator, + ) + bsz, channels, height, width = model_input.shape + + # Sample a random timestep for each image + timesteps = torch.randint( + 0, self.scheduler.config.num_train_timesteps, (bsz,), device=model_input.device, generator=generator + ) + timesteps = timesteps.long() + + # Add noise to the model input according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_model_input = self.scheduler.add_noise(model_input, noise, timesteps) + + # Predict the noise residual + model_pred = self.unet(noisy_model_input, timesteps, text_embedding).sample + + # Get the target for loss depending on the prediction type + if self.scheduler.config.prediction_type == "epsilon": + target = noise + elif self.scheduler.config.prediction_type == "v_prediction": + target = self.scheduler.get_velocity(model_input, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {self.scheduler.config.prediction_type}") + + loss = torch.nn.functional.mse_loss(model_pred.float(), target.float(), reduction="mean") + accelerator.backward(loss) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + with tempfile.TemporaryDirectory() as save_lora_dir: + LoraLoaderMixin.save_lora_weights( + save_directory=save_lora_dir, + unet_lora_layers=unet_lora_layers, + text_encoder_lora_layers=None, + ) + + self.unet.load_attn_procs(save_lora_dir) + + def _tokenize_prompt(self, prompt, tokenizer_max_length=None): + if tokenizer_max_length is not None: + max_length = tokenizer_max_length 
+ else: + max_length = self.tokenizer.model_max_length + + text_inputs = self.tokenizer( + prompt, + truncation=True, + padding="max_length", + max_length=max_length, + return_tensors="pt", + ) + + return text_inputs + + def _encode_prompt(self, input_ids, attention_mask, text_encoder_use_attention_mask=False): + text_input_ids = input_ids.to(self.device) + + if text_encoder_use_attention_mask: + attention_mask = attention_mask.to(self.device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids, + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + return prompt_embeds + + @torch.no_grad() + def _get_text_embed(self, prompt): + text_input = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0] + return text_embeddings + + def _copy_and_paste( + self, latent, source_new, target_new, adapt_radius, max_height, max_width, image_scale, noise_scale, generator + ): + def adaption_r(source, target, adapt_radius, max_height, max_width): + r_x_lower = min(adapt_radius, source[0], target[0]) + r_x_upper = min(adapt_radius, max_width - source[0], max_width - target[0]) + r_y_lower = min(adapt_radius, source[1], target[1]) + r_y_upper = min(adapt_radius, max_height - source[1], max_height - target[1]) + return r_x_lower, r_x_upper, r_y_lower, r_y_upper + + for source_, target_ in zip(source_new, target_new): + r_x_lower, r_x_upper, r_y_lower, r_y_upper = adaption_r( + source_, target_, adapt_radius, max_height, max_width + ) + + source_feature = latent[ + :, :, source_[1] - r_y_lower : source_[1] + r_y_upper, source_[0] - r_x_lower : source_[0] + r_x_upper + ].clone() + + latent[ + :, :, source_[1] - r_y_lower : source_[1] + r_y_upper, source_[0] - r_x_lower : source_[0] + r_x_upper + ] = image_scale * source_feature + noise_scale * torch.randn( + latent.shape[0], + 4, + r_y_lower + r_y_upper, + r_x_lower + r_x_upper, + device=self.device, + generator=generator, + ) + + latent[ + :, :, target_[1] - r_y_lower : target_[1] + r_y_upper, target_[0] - r_x_lower : target_[0] + r_x_upper + ] = source_feature * 1.1 + return latent + + @torch.no_grad() + def _get_img_latent(self, image, height=None, weight=None): + data = image.convert("RGB") + if height is not None: + data = data.resize((weight, height)) + transform = transforms.ToTensor() + data = transform(data).unsqueeze(0) + data = (data * 2.0) - 1.0 + data = data.to(self.device, dtype=self.vae.dtype) + latent = self.vae.encode(data).latent_dist.sample() + latent = 0.18215 * latent + return latent + + @torch.no_grad() + def _get_eps(self, latent, timestep, guidance_scale, text_embeddings, lora_scale=None): + latent_model_input = torch.cat([latent] * 2) if guidance_scale > 1.0 else latent + text_embeddings = text_embeddings if guidance_scale > 1.0 else text_embeddings.chunk(2)[1] + + cross_attention_kwargs = None if lora_scale is None else {"scale": lora_scale} + + with torch.no_grad(): + noise_pred = self.unet( + latent_model_input, + timestep, + encoder_hidden_states=text_embeddings, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + if guidance_scale > 1.0: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + elif guidance_scale == 1.0: + noise_pred_text = noise_pred + noise_pred_uncond = 0.0 + else: + raise NotImplementedError(guidance_scale) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - 
noise_pred_uncond) + + return noise_pred + + def _forward_sde( + self, timestep, sample, guidance_scale, text_embeddings, steps, eta=1.0, lora_scale=None, generator=None + ): + num_train_timesteps = len(self.scheduler) + alphas_cumprod = self.scheduler.alphas_cumprod + initial_alpha_cumprod = torch.tensor(1.0) + + prev_timestep = timestep + num_train_timesteps // steps + + alpha_prod_t = alphas_cumprod[timestep] if timestep >= 0 else initial_alpha_cumprod + alpha_prod_t_prev = alphas_cumprod[prev_timestep] + + beta_prod_t_prev = 1 - alpha_prod_t_prev + + x_prev = (alpha_prod_t_prev / alpha_prod_t) ** (0.5) * sample + (1 - alpha_prod_t_prev / alpha_prod_t) ** ( + 0.5 + ) * torch.randn( + sample.size(), dtype=sample.dtype, layout=sample.layout, device=self.device, generator=generator + ) + eps = self._get_eps(x_prev, prev_timestep, guidance_scale, text_embeddings, lora_scale) + + sigma_t_prev = ( + eta + * (1 - alpha_prod_t) ** (0.5) + * (1 - alpha_prod_t_prev / (1 - alpha_prod_t_prev) * (1 - alpha_prod_t) / alpha_prod_t) ** (0.5) + ) + + pred_original_sample = (x_prev - beta_prod_t_prev ** (0.5) * eps) / alpha_prod_t_prev ** (0.5) + pred_sample_direction_coeff = (1 - alpha_prod_t - sigma_t_prev**2) ** (0.5) + + noise = ( + sample - alpha_prod_t ** (0.5) * pred_original_sample - pred_sample_direction_coeff * eps + ) / sigma_t_prev + + return x_prev, noise + + def _sample( + self, + timestep, + sample, + guidance_scale, + text_embeddings, + steps, + sde=False, + noise=None, + eta=1.0, + lora_scale=None, + generator=None, + ): + num_train_timesteps = len(self.scheduler) + alphas_cumprod = self.scheduler.alphas_cumprod + final_alpha_cumprod = torch.tensor(1.0) + + eps = self._get_eps(sample, timestep, guidance_scale, text_embeddings, lora_scale) + + prev_timestep = timestep - num_train_timesteps // steps + + alpha_prod_t = alphas_cumprod[timestep] + alpha_prod_t_prev = alphas_cumprod[prev_timestep] if prev_timestep >= 0 else final_alpha_cumprod + + beta_prod_t = 1 - alpha_prod_t + + sigma_t = ( + eta + * ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) ** (0.5) + * (1 - alpha_prod_t / alpha_prod_t_prev) ** (0.5) + if sde + else 0 + ) + + pred_original_sample = (sample - beta_prod_t ** (0.5) * eps) / alpha_prod_t ** (0.5) + pred_sample_direction_coeff = (1 - alpha_prod_t_prev - sigma_t**2) ** (0.5) + + noise = ( + torch.randn( + sample.size(), dtype=sample.dtype, layout=sample.layout, device=self.device, generator=generator + ) + if noise is None + else noise + ) + latent = ( + alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction_coeff * eps + sigma_t * noise + ) + + return latent + + def _forward(self, latent, steps, t0, lora_scale_min, text_embeddings, generator): + def scale_schedule(begin, end, n, length, type="linear"): + if type == "constant": + return end + elif type == "linear": + return begin + (end - begin) * n / length + elif type == "cos": + factor = (1 - math.cos(n * math.pi / length)) / 2 + return (1 - factor) * begin + factor * end + else: + raise NotImplementedError(type) + + noises = [] + latents = [] + lora_scales = [] + cfg_scales = [] + latents.append(latent) + t0 = int(t0 * steps) + t_begin = steps - t0 + + length = len(self.scheduler.timesteps[t_begin - 1 : -1]) - 1 + index = 1 + for t in self.scheduler.timesteps[t_begin:].flip(dims=[0]): + lora_scale = scale_schedule(1, lora_scale_min, index, length, type="cos") + cfg_scale = scale_schedule(1, 3.0, index, length, type="linear") + latent, noise = self._forward_sde( + t, latent, cfg_scale, text_embeddings, 
steps, lora_scale=lora_scale, generator=generator + ) + + noises.append(noise) + latents.append(latent) + lora_scales.append(lora_scale) + cfg_scales.append(cfg_scale) + index += 1 + return latent, noises, latents, lora_scales, cfg_scales + + def _backward( + self, latent, mask, steps, t0, noises, hook_latents, lora_scales, cfg_scales, text_embeddings, generator + ): + t0 = int(t0 * steps) + t_begin = steps - t0 + + hook_latent = hook_latents.pop() + latent = torch.where(mask > 128, latent, hook_latent) + for t in self.scheduler.timesteps[t_begin - 1 : -1]: + latent = self._sample( + t, + latent, + cfg_scales.pop(), + text_embeddings, + steps, + sde=True, + noise=noises.pop(), + lora_scale=lora_scales.pop(), + generator=generator, + ) + hook_latent = hook_latents.pop() + latent = torch.where(mask > 128, latent, hook_latent) + return latent