diff --git a/src/diffusers/schedulers/scheduling_lcm_flax.py b/src/diffusers/schedulers/scheduling_lcm_flax.py index f45ca18ef190..51a8c0711eb3 100644 --- a/src/diffusers/schedulers/scheduling_lcm_flax.py +++ b/src/diffusers/schedulers/scheduling_lcm_flax.py @@ -2,8 +2,9 @@ from dataclasses import dataclass from typing import List, Optional, Tuple, Union -import flax +import jax import jax.numpy as jnp +import flax from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput, logging @@ -341,17 +342,14 @@ def step( (step_index,) = jnp.where(state.timesteps == timestep, size=1) step_index = step_index[0] - # 1. get previous step value + # 1. get previous step value prev_step_index = step_index + 1 - if prev_step_index < len(state.timesteps): - prev_timestep = state.timesteps[prev_step_index] - else: - prev_timestep = timestep + prev_timestep = jax.lax.select(prev_step_index < len(state.timesteps), state.timesteps[prev_step_index], timestep) # 2. compute alphas, betas - alpha_prod_t = state.alphas_cumprod[timestep] - alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod - + alpha_prod_t = state.common.alphas_cumprod[timestep] + alpha_prod_t_prev = jax.lax.select(prev_timestep >=0, state.common.alphas_cumprod[prev_timestep], state.final_alpha_cumprod) + beta_prod_t = 1 - alpha_prod_t beta_prod_t_prev = 1 - alpha_prod_t_prev @@ -360,15 +358,14 @@ def step( # 4. 
Compute the predicted original sample x_0 based on the model parameterization if self.config.prediction_type == "epsilon": # noise-prediction - predicted_original_sample = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt() + predicted_original_sample = (sample - jnp.sqrt(beta_prod_t) * model_output) / jnp.sqrt(alpha_prod_t) elif self.config.prediction_type == "sample": # x-prediction predicted_original_sample = model_output elif self.config.prediction_type == "v_prediction": # v-prediction - predicted_original_sample = alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output + predicted_original_sample = jnp.sqrt(alpha_prod_t) * sample - jnp.sqrt(beta_prod_t) * model_output else: raise ValueError( - f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or" - " `v_prediction` for `LCMScheduler`." + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or `v_prediction` for `LCMScheduler`." ) # 5. Clip or threshold "predicted x_0" @@ -400,7 +397,7 @@ def step( if not return_dict: return (prev_sample, state) - return FlaxLCMSchedulerOutput(prev_sample=prev_sample, denoised=denoised) + return FlaxLCMSchedulerOutput(prev_sample=prev_sample, state=state) # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise def add_noise(