implement flax lcm scheduler set_timesteps
entrpn committed Dec 5, 2023
1 parent a66445c commit 1c84fe9
Showing 1 changed file with 81 additions and 67 deletions.
148 changes: 81 additions & 67 deletions src/diffusers/schedulers/scheduling_lcm_flax.py
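For context, here is a minimal usage sketch of the new `set_timesteps` API added by this commit. The `FlaxLCMScheduler` class name, the default config values (e.g. num_train_timesteps=1000, original_inference_steps=50), and the latent `shape` are assumptions based on the file name and the PyTorch LCMScheduler; only the `set_timesteps` signature is taken from the diff below.

from diffusers.schedulers.scheduling_lcm_flax import FlaxLCMScheduler

scheduler = FlaxLCMScheduler()
state = scheduler.create_state()
state = scheduler.set_timesteps(
    state,
    shape=(1, 4, 64, 64),
    num_inference_steps=4,
    original_inference_steps=50,
    strength=1.0,
)
# with the assumed defaults, state.timesteps should come out as [999, 759, 499, 259]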
@@ -66,7 +66,7 @@ class LCMSchedulerState:
timesteps: jnp.ndarray
num_inference_steps: Optional[int] = None
custom_timesteps: Optional[bool] = False
step_index: Optional[int] = -1

@classmethod
def create(
@@ -204,16 +204,13 @@ def create_state(self, common: Optional[CommonSchedulerState] = None) -> LCMSche
# Rescale for zero SNR
if self.config.rescale_betas_zero_snr:
common = common.replace(betas=rescale_zero_terminal_snr(common.betas))

final_alpha_cumprod = jnp.array(1.0) if self.config.set_alpha_to_one else common.alphas_cumprod[0]

# standard deviation of the initial noise distribution
init_noise_sigma = jnp.array(1.0, dtype=self.dtype)

timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1]
custom_timesteps = False
step_index = -1

return LCMSchedulerState.create(
common=common,
@@ -223,7 +220,11 @@ def create_state(self, common: Optional[CommonSchedulerState] = None) -> LCMSche
custom_timesteps=custom_timesteps,
step_index=step_index
)


def _init_step_index(self, state, timestep):
# locate `timestep` in the schedule; if it appears twice (e.g. after add_noise for
# image-to-image), take the second occurrence, mirroring the PyTorch LCMScheduler.
# `size=2` with a fill value of -1 keeps the result shape static for jit.
(index_candidates,) = jnp.where(state.timesteps == timestep, size=2, fill_value=-1)
step_index = jnp.where(index_candidates[1] >= 0, index_candidates[1], index_candidates[0])
return step_index

def scale_model_input(self, state: LCMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None) -> jnp.ndarray:
"""
@@ -241,10 +242,16 @@ def scale_model_input(self, state: LCMSchedulerState, sample: jnp.ndarray, times
"""
return sample

# Based on diffusers.schedulers.scheduling_lcm.LCMScheduler.set_timesteps
def set_timesteps(
self,
state: LCMSchedulerState,
shape: Tuple = (),
num_inference_steps: Optional[int] = None,
original_inference_steps: Optional[int] = None,
timesteps: Optional[List[int]] = None,
strength: float = 1.0,
) -> LCMSchedulerState:
"""
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -254,29 +261,62 @@ def set_timesteps(
num_inference_steps (`int`, *optional*):
the number of diffusion steps used when generating samples with a pre-trained model.
original_inference_steps (`int`, *optional*):
the original number of inference steps used to build the LCM distillation schedule; defaults to the value in the scheduler config.
timesteps (`List[int]`, *optional*):
custom timesteps to use; mutually exclusive with `num_inference_steps`.
strength (`float`, defaults to 1.0):
denoising strength used to trim the schedule for image-to-image generation.
"""

# 0. Check inputs
if num_inference_steps is None and timesteps is None:
raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.")

if num_inference_steps is not None and timesteps is not None:
raise ValueError("Can only pass one of `num_inference_steps` or `timesteps`.")

if timesteps is not None:
# custom timestep schedules are not handled below yet, so fail loudly instead of silently ignoring them
raise NotImplementedError("Custom `timesteps` are not supported by this Flax LCM scheduler yet.")

# 1. Calculate the LCM original training/distillation timestep schedule.
original_steps = (
original_inference_steps if original_inference_steps is not None else self.config.original_inference_steps
)

if original_steps > self.config.num_train_timesteps:
raise ValueError(
f"`original_steps`: {original_steps} cannot be larger than `self.config.train_timesteps`:"
f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
f" maximal {self.config.num_train_timesteps} timesteps."
)

# LCM Timesteps Setting
# The skipping step parameter k from the paper
k = self.config.num_train_timesteps // original_steps
# LCM Training/Distillation Steps Schedule
# Currently, only a linearly-spaced schedule is supported (same as in the LCM distillation scripts).
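# e.g. with num_train_timesteps=1000 and original_steps=50, k = 20 and the distillation schedule is [19, 39, ..., 999]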
lcm_origin_timesteps = jnp.array(list(range(1, int(original_steps * strength) + 1))) * k - 1

skipping_step = len(lcm_origin_timesteps) // num_inference_steps

if skipping_step < 1:
raise ValueError(
f"The combination of `original_steps x strength`: {original_steps} x {strength} is smaller than `num_inference_steps`: {num_inference_steps}. Make sure to either reduce `num_inference_steps` to a value smaller than {int(original_steps * strength)} or increase `strength` to a value higher than {float(num_inference_steps / original_steps)}."
)

if num_inference_steps > original_steps:
raise ValueError(
f"`num_inference_steps`: {num_inference_steps} cannot be larger than `original_inference_steps`:"
f" {original_steps} because the final timestep schedule will be a subset of the"
f" `original_inference_steps`-sized initial timestep schedule."
)

# LCM Inference Steps Schedule
lcm_origin_timesteps = lcm_origin_timesteps[::-1]
inference_indices = jnp.linspace(0, len(lcm_origin_timesteps), num=num_inference_steps, endpoint=False)
inference_indices = jnp.floor(inference_indices).astype(jnp.int32)
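# e.g. original_steps=50, k=20, num_inference_steps=4 gives indices [0, 12, 25, 37] and timesteps [999, 759, 499, 259]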
timesteps = lcm_origin_timesteps[inference_indices]

return state.replace(
num_inference_steps=num_inference_steps,
timesteps=timesteps,
step_index=-1
)

# Copied from diffusers.schedulers.scheduling_ddim_flax
def _get_variance(self, state: LCMSchedulerState, timestep, prev_timestep):
alpha_prod_t = state.common.alphas_cumprod[timestep]
alpha_prod_t_prev = jnp.where(
prev_timestep >= 0, state.common.alphas_cumprod[prev_timestep], state.final_alpha_cumprod
)
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev
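# this is the DDIM posterior variance: sigma_t^2 = (1 - alpha_bar_prev) / (1 - alpha_bar_t) * (1 - alpha_bar_t / alpha_bar_prev)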

variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)

return variance


def get_scalings_for_boundary_condition_discrete(self, timestep):
self.sigma_data = 0.5 # Default: 0.5
@@ -286,25 +326,6 @@ def get_scalings_for_boundary_condition_discrete(self, timestep):
c_out = scaled_timestep / (scaled_timestep**2 + self.sigma_data**2) ** 0.5
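# at timestep 0, c_skip -> 1 and c_out -> 0 (the consistency-model boundary condition f(x, 0) = x), matching the PyTorch LCMScheduler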
return c_skip, c_out


def step(
self,
state: LCMSchedulerState,
@@ -334,20 +355,22 @@ def step(
If return_dict is `True`, [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] is returned, otherwise a
tuple is returned where the first element is the sample tensor.
"""

if state.num_inference_steps is None:
raise ValueError(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)

step_index = jax.lax.select(state.step_index == -1, self._init_step_index(state, timestep), state.step_index)

# 1. get previous step value
prev_step_index = step_index + 1
prev_timestep = jax.lax.select(prev_step_index < len(state.timesteps), state.timesteps[prev_step_index], timestep)
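# e.g. with timesteps [999, 759, 499, 259], step_index 1 gives prev_timestep 499; the final step falls back to `timestep`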

# 2. compute alphas, betas
jax.debug.print("timestep: {x}",x=timestep)
alpha_prod_t = state.common.alphas_cumprod[timestep]
jax.debug.print("alpha_prod_t: {x}",x=alpha_prod_t)
alpha_prod_t_prev = jax.lax.select(prev_timestep >=0, state.common.alphas_cumprod[prev_timestep], state.final_alpha_cumprod)

beta_prod_t = 1 - alpha_prod_t
@@ -392,7 +415,12 @@ def step(
prev_sample = denoised

# upon completion, increase the step index by one
state = state.replace(step_index=step_index + 1)

if not return_dict:
return (prev_sample, state)
@@ -407,6 +435,8 @@ def add_noise(
noise: jnp.ndarray,
timesteps: jnp.ndarray,
) -> jnp.ndarray:
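# add_noise_common forms sqrt(alphas_cumprod[t]) * original_samples + sqrt(1 - alphas_cumprod[t]) * noise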
jax.debug.print("add noise!!!!!!!!!!!!!!!!!!!!!!!")
print("add noise!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
return add_noise_common(state.common, original_samples, noise, timesteps)


@@ -418,20 +448,4 @@ def get_velocity(


def __len__(self):
return self.config.num_train_timesteps
