From 87e24cbb30985499b5292c32a29acaadb9592b6e Mon Sep 17 00:00:00 2001 From: manuelbrack <manuel.brack@stud.tu-darmstadt.de> Date: Thu, 14 Mar 2024 10:13:41 +0100 Subject: [PATCH] template commit for LEdits++ refactoring --- .../pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py | 2 ++ .../ledits_pp/pipeline_leditspp_stable_diffusion_xl.py | 3 ++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py b/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py index a6357c4cd3a1..01c5920b9a86 100644 --- a/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +++ b/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py @@ -844,6 +844,7 @@ def __call__( org_prompt = "" + # TODO: Check LEdits++ inputs + verify that invert has been run properly # 1. Check inputs. Raise error if not correct self.check_inputs( negative_prompt, @@ -933,6 +934,7 @@ def __call__( # 7. Denoising loop num_warmup_steps = 0 + # TODO: Refactor out SEGA/LEdits++ functionality with self.progress_bar(total=len(timesteps)) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance diff --git a/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py b/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py index 874a10a7ccd5..705c11a2af4b 100644 --- a/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py @@ -974,7 +974,7 @@ def __call__( if user_mask is not None: user_mask = user_mask.to(self.device) - # TODO: Check inputs + # TODO: Check LEdits++ inputs + verify that invert has been run properly # 1. Check inputs. Raise error if not correct # self.check_inputs( # callback_steps, @@ -1115,6 +1115,7 @@ def __call__( guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim ).to(device=device, dtype=latents.dtype) + # TODO: Refactor out SEGA/LEdits++ functionality self._num_timesteps = len(timesteps) with self.progress_bar(total=self._num_timesteps) as progress_bar: for i, t in enumerate(timesteps):