From cef0e3677e9a32fcf27cdfccfdf809f689a6f908 Mon Sep 17 00:00:00 2001 From: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com> Date: Fri, 13 Dec 2024 11:04:26 +0200 Subject: [PATCH] [RF inversion community pipeline] add eta_decay (#10199) * add decay * add decay * style --- examples/community/pipeline_flux_rf_inversion.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/examples/community/pipeline_flux_rf_inversion.py b/examples/community/pipeline_flux_rf_inversion.py index f09160c4571d..c8a87a426dc0 100644 --- a/examples/community/pipeline_flux_rf_inversion.py +++ b/examples/community/pipeline_flux_rf_inversion.py @@ -648,6 +648,8 @@ def __call__( height: Optional[int] = None, width: Optional[int] = None, eta: float = 1.0, + decay_eta: Optional[bool] = False, + eta_decay_power: Optional[float] = 1.0, strength: float = 1.0, start_timestep: float = 0, stop_timestep: float = 0.25, @@ -880,12 +882,9 @@ def __call__( v_t = -noise_pred v_t_cond = (y_0 - latents) / (1 - t_i) eta_t = eta if start_timestep <= i < stop_timestep else 0.0 - if start_timestep <= i < stop_timestep: - # controlled vector field - v_hat_t = v_t + eta * (v_t_cond - v_t) - - else: - v_hat_t = v_t + if decay_eta: + eta_t = eta_t * (1 - i / num_inference_steps) ** eta_decay_power # Decay eta over the loop + v_hat_t = v_t + eta_t * (v_t_cond - v_t) # SDE Eq: 17 from https://arxiv.org/pdf/2410.10792 latents = latents + v_hat_t * (sigmas[i] - sigmas[i + 1])