From 1c84fe972ae04dc9fc4d6c1f743b62ccfd87830b Mon Sep 17 00:00:00 2001
From: Juan Acevedo
Date: Tue, 5 Dec 2023 23:48:53 +0000
Subject: [PATCH] implement flax lcm scheduler set_timesteps

---
 .../schedulers/scheduling_lcm_flax.py | 148 ++++++++++--------
 1 file changed, 81 insertions(+), 67 deletions(-)

diff --git a/src/diffusers/schedulers/scheduling_lcm_flax.py b/src/diffusers/schedulers/scheduling_lcm_flax.py
index 51a8c0711eb3..297334dd9fe6 100644
--- a/src/diffusers/schedulers/scheduling_lcm_flax.py
+++ b/src/diffusers/schedulers/scheduling_lcm_flax.py
@@ -66,7 +66,7 @@ class LCMSchedulerState:
     timesteps: jnp.ndarray
     num_inference_steps: Optional[int] = None
     custom_timesteps: Optional[bool] = False
-    step_index: Optional[int] = None
+    step_index: Optional[int] = -1
 
     @classmethod
     def create(
@@ -204,16 +204,13 @@ def create_state(self, common: Optional[CommonSchedulerState] = None) -> LCMSche
         # Rescale for zero SNR
         if self.config.rescale_betas_zero_snr:
             common.betas = rescale_zero_terminal_snr
 
         final_alpha_cumprod = jnp.array(1.0) if self.config.set_alpha_to_one else common.alphas_cumprod[0]
 
         # standard deviation of the initial noise distribution
         init_noise_sigma = jnp.array(1.0, dtype=self.dtype)
 
         timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1]
 
         custom_timesteps = False
-        step_index = None
-
+        step_index = -1
 
         return LCMSchedulerState.create(
             common=common,
@@ -223,7 +220,11 @@ def create_state(self, common: Optional[CommonSchedulerState] = None) -> LCMSche
             custom_timesteps=custom_timesteps,
             step_index=step_index
         )
-
+
+    def _init_step_index(self, state, timestep):
+        # `size=2` pads the result with the fill value, so check the second candidate
+        # directly; `len(step_index)` is always 2 and cannot detect duplicates.
+        (step_index,) = jnp.where(state.timesteps == timestep, size=2, fill_value=-1)
+        step_index = jax.lax.select(step_index[1] >= 0, step_index[1], step_index[0])
+        return step_index
 
     def scale_model_input(self, state: LCMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None) -> jnp.ndarray:
         """
@@ -241,10 +242,16 @@ def scale_model_input(self, state: LCMSchedulerState, sample: jnp.ndarray, times
         """
         return sample
 
-    # Copied from diffusers.schedulers.scheduling_ddim_flax
     def set_timesteps(
-        self, state: LCMSchedulerState, num_inference_steps: int, shape: Tuple = ()
-    ) -> LCMSchedulerState:
+        self,
+        state: LCMSchedulerState,
+        shape: Tuple = (),
+        num_inference_steps: Optional[int] = None,
+        original_inference_steps: Optional[int] = None,
+        timesteps: Optional[List[int]] = None,
+        strength: float = 1.0,
+    ) -> LCMSchedulerState:
         """
         Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
 
@@ -254,29 +261,62 @@ def set_timesteps(
         Args:
             state (`LCMSchedulerState`):
                 the `FlaxLCMScheduler` state data class instance.
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
        """
+
+        # 0. Check inputs
+        if num_inference_steps is None and timesteps is None:
+            raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.")
+
+        if num_inference_steps is not None and timesteps is not None:
+            raise ValueError("Can only pass one of `num_inference_steps` or `timesteps`.")
+
+        # 1. Calculate the LCM original training/distillation timestep schedule.
+        original_steps = (
+            original_inference_steps if original_inference_steps is not None else self.config.original_inference_steps
+        )
+
+        if original_steps > self.config.num_train_timesteps:
+            raise ValueError(
+                f"`original_steps`: {original_steps} cannot be larger than `self.config.num_train_timesteps`:"
+                f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
+                f" maximal {self.config.num_train_timesteps} timesteps."
+            )
+
+        # LCM Timesteps Setting
+        # The skipping step parameter k from the paper
+        k = self.config.num_train_timesteps // original_steps
+        # LCM Training/Distillation Steps Schedule
+        # Currently, only a linearly-spaced schedule is supported (same as in the LCM distillation scripts).
+        lcm_origin_timesteps = jnp.array(list(range(1, int(original_steps * strength) + 1))) * k - 1
+
+        skipping_step = len(lcm_origin_timesteps) // num_inference_steps
+
+        if skipping_step < 1:
+            raise ValueError(
+                f"The combination of `original_steps x strength`: {original_steps} x {strength} is smaller than `num_inference_steps`: {num_inference_steps}. Make sure to either reduce `num_inference_steps` to a value smaller than {int(original_steps * strength)} or increase `strength` to a value higher than {float(num_inference_steps / original_steps)}."
+            )
+
+        if num_inference_steps > original_steps:
+            raise ValueError(
+                f"`num_inference_steps`: {num_inference_steps} cannot be larger than `original_inference_steps`:"
+                f" {original_steps} because the final timestep schedule will be a subset of the"
+                f" `original_inference_steps`-sized initial timestep schedule."
+            )
+
+        # LCM Inference Steps Schedule
+        lcm_origin_timesteps = lcm_origin_timesteps[::-1]
+        inference_indices = jnp.linspace(0, len(lcm_origin_timesteps), num=num_inference_steps, endpoint=False)
+        inference_indices = jnp.floor(inference_indices).astype(jnp.int32)
+        timesteps = lcm_origin_timesteps[inference_indices]
 
-        step_ratio = self.config.num_train_timesteps // num_inference_steps
-        # creates integer timesteps by multiplying by ratio
-        # rounding to avoid issues when num_inference_step is power of 3
-        timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1] + self.config.steps_offset
 
         return state.replace(
             num_inference_steps=num_inference_steps,
             timesteps=timesteps,
+            step_index=-1
         )
-
-    # Copied from diffusers.schedulers.scheduling_ddim_flax
-    def _get_variance(self, state: LCMSchedulerState, timestep, prev_timestep):
-        alpha_prod_t = state.common.alphas_cumprod[timestep]
-        alpha_prod_t_prev = jnp.where(
-            prev_timestep >= 0, state.common.alphas_cumprod[prev_timestep], state.final_alpha_cumprod
-        )
-        beta_prod_t = 1 - alpha_prod_t
-        beta_prod_t_prev = 1 - alpha_prod_t_prev
-
-        variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
-
-        return variance
-
     def get_scalings_for_boundary_condition_discrete(self, timestep):
         self.sigma_data = 0.5  # Default: 0.5
 
@@ -286,25 +326,6 @@ def get_scalings_for_boundary_condition_discrete(self, timestep):
         c_out = scaled_timestep / (scaled_timestep**2 + self.sigma_data**2) ** 0.5
         return c_skip, c_out
 
-    def set_timesteps(
-        self, state: LCMSchedulerState, num_inference_steps: int, shape: Tuple = ()
-    ) -> LCMSchedulerState:
-        """
-        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
-
-        Args:
-            state (`DDIMSchedulerState`):
-                the `FlaxDDPMScheduler` state data class instance.
-            num_inference_steps (`int`):
-                the number of diffusion steps used when generating samples with a pre-trained model.
-        """
-        timesteps = (jnp.arange(0, num_inference_steps))[::-1]
-
-        return state.replace(
-            num_inference_steps=num_inference_steps,
-            timesteps=timesteps,
-        )
-
     def step(
         self,
         state: LCMSchedulerState,
@@ -334,20 +355,22 @@ def step(
             If return_dict is `True`, [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] is returned, otherwise a tuple is
             returned where the first element is the sample tensor.
         """
+
         if state.num_inference_steps is None:
             raise ValueError(
                 "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
             )
-
-        (step_index,) = jnp.where(state.timesteps == timestep, size=1)
-        step_index = step_index[0]
-
+
+        # Resolve the step index lazily from the timestep on the first call after set_timesteps.
+        step_index = jax.lax.select(state.step_index == -1, self._init_step_index(state, timestep), state.step_index)
+
         # 1. get previous step value
         prev_step_index = step_index + 1
         prev_timestep = jax.lax.select(prev_step_index < len(state.timesteps), state.timesteps[prev_step_index], timestep)
 
         # 2. compute alphas, betas
         alpha_prod_t = state.common.alphas_cumprod[timestep]
         alpha_prod_t_prev = jax.lax.select(prev_timestep >= 0, state.common.alphas_cumprod[prev_timestep], state.final_alpha_cumprod)
 
         beta_prod_t = 1 - alpha_prod_t
@@ -392,7 +415,12 @@ def step(
 
         prev_sample = denoised
 
-        # # upon completion increase step index by one
-        # self._step_index += 1
+        # upon completion, increase the step index by one in the returned state
+        state = state.replace(
+            step_index=step_index + 1
+        )
 
         if not return_dict:
             return (prev_sample, state)
 
@@ -418,20 +448,4 @@ def get_velocity(
 
 
     def __len__(self):
-        return self.config.num_train_timesteps
-
-    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
-    # def previous_timestep(self, timestep):
-    #     if self.custom_timesteps:
-    #         index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
-    #         if index == self.timesteps.shape[0] - 1:
-    #             prev_t = torch.tensor(-1)
-    #         else:
-    #             prev_t = self.timesteps[index + 1]
-    #     else:
-    #         num_inference_steps = (
-    #             self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
-    #         )
-    #         prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
-
-    #     return prev_t
\ No newline at end of file
+        return self.config.num_train_timesteps
\ No newline at end of file
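
A minimal usage sketch of the new Flax LCM flow, for orientation only: it assumes the scheduler class is `FlaxLCMScheduler`, that it can be constructed with `num_train_timesteps` and `original_inference_steps` config values, and that `denoise_fn` stands in for the UNet forward pass; none of these names are defined by this patch.

    import jax

    # Hypothetical construction; the real config names and defaults live in the scheduler's __init__.
    scheduler = FlaxLCMScheduler(num_train_timesteps=1000, original_inference_steps=50)
    state = scheduler.create_state()
    state = scheduler.set_timesteps(state, shape=(), num_inference_steps=4)

    # Start from Gaussian noise scaled by the state's init_noise_sigma.
    sample = jax.random.normal(jax.random.PRNGKey(0), (1, 4, 64, 64)) * state.init_noise_sigma

    for t in state.timesteps:
        model_output = denoise_fn(sample, t)  # stand-in for the UNet call
        # step() returns the denoised sample and the updated state (with step_index advanced).
        sample, state = scheduler.step(state, model_output, t, sample, return_dict=False)

Because Flax schedulers are stateless Python objects, the incremented `step_index` travels in the returned `LCMSchedulerState` rather than on the scheduler instance, which is why `step` hands back the updated state alongside the sample.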