From 87e24cbb30985499b5292c32a29acaadb9592b6e Mon Sep 17 00:00:00 2001
From: manuelbrack <manuel.brack@stud.tu-darmstadt.de>
Date: Thu, 14 Mar 2024 10:13:41 +0100
Subject: [PATCH] template commit for LEdits++ refactoring

---
 .../pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py  | 2 ++
 .../ledits_pp/pipeline_leditspp_stable_diffusion_xl.py         | 3 ++-
 2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py b/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py
index a6357c4cd3a1..01c5920b9a86 100644
--- a/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py
+++ b/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py
@@ -844,6 +844,7 @@ def __call__(
 
         org_prompt = ""
 
+        # TODO: Check LEdits++ inputs + verify that invert has been run properly
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
             negative_prompt,
@@ -933,6 +934,7 @@ def __call__(
 
         # 7. Denoising loop
         num_warmup_steps = 0
+        # TODO: Refactor out SEGA/LEdits++ functionality
         with self.progress_bar(total=len(timesteps)) as progress_bar:
             for i, t in enumerate(timesteps):
                 # expand the latents if we are doing classifier free guidance
diff --git a/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py b/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py
index 874a10a7ccd5..705c11a2af4b 100644
--- a/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py
+++ b/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py
@@ -974,7 +974,7 @@ def __call__(
         if user_mask is not None:
             user_mask = user_mask.to(self.device)
 
-        # TODO: Check inputs
+        # TODO: Check LEdits++ inputs + verify that invert has been run properly
         # 1. Check inputs. Raise error if not correct
         # self.check_inputs(
         #    callback_steps,
@@ -1115,6 +1115,7 @@ def __call__(
                 guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
             ).to(device=device, dtype=latents.dtype)
 
+        # TODO: Refactor out SEGA/LEdits++ functionality
         self._num_timesteps = len(timesteps)
         with self.progress_bar(total=self._num_timesteps) as progress_bar:
             for i, t in enumerate(timesteps):