Skip to content

Commit

Permalink
lcm_scheduler runs, but bad images
Browse files Browse the repository at this point in the history
  • Loading branch information
entrpn committed Dec 5, 2023
1 parent 111be88 commit a66445c
Showing 1 changed file with 11 additions and 14 deletions.
25 changes: 11 additions & 14 deletions src/diffusers/schedulers/scheduling_lcm_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import flax
import jax
import jax.numpy as jnp
import flax

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, logging
Expand Down Expand Up @@ -341,17 +342,14 @@ def step(
(step_index,) = jnp.where(state.timesteps == timestep, size=1)
step_index = step_index[0]

# 1. get previous step value
# 1. get previous step value
prev_step_index = step_index + 1
if prev_step_index < len(state.timesteps):
prev_timestep = state.timesteps[prev_step_index]
else:
prev_timestep = timestep
prev_timestep = jax.lax.select(prev_step_index < len(state.timesteps), state.timesteps[prev_step_index], timestep)

# 2. compute alphas, betas
alpha_prod_t = state.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod

alpha_prod_t = state.common.alphas_cumprod[timestep]
alpha_prod_t_prev = jax.lax.select(prev_timestep >=0, state.common.alphas_cumprod[prev_timestep], state.final_alpha_cumprod)
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev

Expand All @@ -360,15 +358,14 @@ def step(

# 4. Compute the predicted original sample x_0 based on the model parameterization
if self.config.prediction_type == "epsilon": # noise-prediction
predicted_original_sample = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt()
predicted_original_sample = (sample - jnp.sqrt(beta_prod_t) * model_output) / jnp.sqrt(alpha_prod_t)
elif self.config.prediction_type == "sample": # x-prediction
predicted_original_sample = model_output
elif self.config.prediction_type == "v_prediction": # v-prediction
predicted_original_sample = alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output
predicted_original_sample = jnp.sqrt(alpha_prod_t) * sample - jnp.sqrt(beta_prod_t.sqrt) * model_output
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or"
" `v_prediction` for `LCMScheduler`."
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or `v_prediction` for `LCMScheduler`."
)

# 5. Clip or threshold "predicted x_0"
Expand Down Expand Up @@ -400,7 +397,7 @@ def step(
if not return_dict:
return (prev_sample, state)

return FlaxLCMSchedulerOutput(prev_sample=prev_sample, denoised=denoised)
return FlaxLCMSchedulerOutput(prev_sample=prev_sample, state=state)

# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
def add_noise(
Expand Down

0 comments on commit a66445c

Please sign in to comment.