diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py
index 1016ce69e450..e3b8e2f0f30c 100644
--- a/src/diffusers/pipelines/stable_diffusion/__init__.py
+++ b/src/diffusers/pipelines/stable_diffusion/__init__.py
@@ -56,5 +56,6 @@ class FlaxStableDiffusionPipelineOutput(BaseOutput):
         images: Union[List[PIL.Image.Image], np.ndarray]
         nsfw_content_detected: List[bool]

+    from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState
     from .pipeline_flax_stable_diffusion import FlaxStableDiffusionPipeline
     from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
index 675b61266285..974d77547e56 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
@@ -187,7 +187,9 @@ def loop_body(step, args):
             latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
             return latents, scheduler_state

-        scheduler_state = self.scheduler.set_timesteps(params["scheduler"], num_inference_steps=num_inference_steps)
+        scheduler_state = self.scheduler.set_timesteps(
+            params["scheduler"], num_inference_steps=num_inference_steps, shape=latents.shape
+        )

         if debug:
             # run with python for loop
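The pipeline change above is the visible half of the new contract: the scheduler state is now pre-allocated to a fixed shape, so `set_timesteps` must be told the latent shape up front. A minimal sketch of the new calling convention (the config value and latent shape below are illustrative, not taken from this diff):

```python
from diffusers import FlaxPNDMScheduler

scheduler = FlaxPNDMScheduler(num_train_timesteps=1000)
state = scheduler.create_state()

# Illustrative Stable Diffusion latent shape: (batch, channels, height/8, width/8).
latents_shape = (1, 4, 64, 64)
state = scheduler.set_timesteps(state, num_inference_steps=50, shape=latents_shape)

assert state.ets.shape == (4,) + latents_shape  # room for four buffered model outputs
assert state.cur_sample.shape == latents_shape  # pre-allocated instead of None
```

Because every field of the state now has a static shape and dtype, the whole state can flow through `jit`-compiled code as a regular pytree.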
""" - if state.counter < len(state.prk_timesteps) and not self.config.skip_prk_steps: - return self.step_prk( - state=state, model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict + if self.config.skip_prk_steps: + prev_sample, state = self.step_plms( + state=state, model_output=model_output, timestep=timestep, sample=sample ) else: - return self.step_plms( - state=state, model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict + prev_sample, state = jax.lax.switch( + jnp.where(state.counter < len(state.prk_timesteps), 0, 1), + (self.step_prk, self.step_plms), + # Args to either branch + state, + model_output, + timestep, + sample, ) + if not return_dict: + return (prev_sample, state) + + return FlaxSchedulerOutput(prev_sample=prev_sample, state=state) + def step_prk( self, state: PNDMSchedulerState, model_output: jnp.ndarray, timestep: int, sample: jnp.ndarray, - return_dict: bool = True, ) -> Union[FlaxSchedulerOutput, Tuple]: """ Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the @@ -266,34 +280,46 @@ def step_prk( "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" ) - diff_to_prev = 0 if state.counter % 2 else self.config.num_train_timesteps // state.num_inference_steps // 2 + diff_to_prev = jnp.where( + state.counter % 2, 0, self.config.num_train_timesteps // state.num_inference_steps // 2 + ) prev_timestep = timestep - diff_to_prev timestep = state.prk_timesteps[state.counter // 4 * 4] - if state.counter % 4 == 0: - state = state.replace( - cur_model_output=state.cur_model_output + 1 / 6 * model_output, - ets=state.ets.append(model_output), - cur_sample=sample, + def remainder_0(state: PNDMSchedulerState, model_output: jnp.ndarray, ets_at: int): + return ( + state.replace( + cur_model_output=state.cur_model_output + 1 / 6 * model_output, + ets=state.ets.at[ets_at].set(model_output), + cur_sample=sample, + ), + model_output, ) - elif (self.counter - 1) % 4 == 0: - state = state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output) - elif (self.counter - 2) % 4 == 0: - state = state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output) - elif (self.counter - 3) % 4 == 0: - model_output = state.cur_model_output + 1 / 6 * model_output - state = state.replace(cur_model_output=0) - # cur_sample should not be `None` - cur_sample = state.cur_sample if state.cur_sample is not None else sample + def remainder_1(state: PNDMSchedulerState, model_output: jnp.ndarray, ets_at: int): + return state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output), model_output + def remainder_2(state: PNDMSchedulerState, model_output: jnp.ndarray, ets_at: int): + return state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output), model_output + + def remainder_3(state: PNDMSchedulerState, model_output: jnp.ndarray, ets_at: int): + model_output = state.cur_model_output + 1 / 6 * model_output + return state.replace(cur_model_output=jnp.zeros_like(state.cur_model_output)), model_output + + state, model_output = jax.lax.switch( + state.counter % 4, + (remainder_0, remainder_1, remainder_2, remainder_3), + # Args to either branch + state, + model_output, + state.counter // 4, + ) + + cur_sample = state.cur_sample prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output) state = state.replace(counter=state.counter + 1) - if not return_dict: - return 
@@ -266,34 +280,46 @@ def step_prk(
                 "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
             )

-        diff_to_prev = 0 if state.counter % 2 else self.config.num_train_timesteps // state.num_inference_steps // 2
+        diff_to_prev = jnp.where(
+            state.counter % 2, 0, self.config.num_train_timesteps // state.num_inference_steps // 2
+        )
         prev_timestep = timestep - diff_to_prev
         timestep = state.prk_timesteps[state.counter // 4 * 4]

-        if state.counter % 4 == 0:
-            state = state.replace(
-                cur_model_output=state.cur_model_output + 1 / 6 * model_output,
-                ets=state.ets.append(model_output),
-                cur_sample=sample,
+        def remainder_0(state: PNDMSchedulerState, model_output: jnp.ndarray, ets_at: int):
+            return (
+                state.replace(
+                    cur_model_output=state.cur_model_output + 1 / 6 * model_output,
+                    ets=state.ets.at[ets_at].set(model_output),
+                    cur_sample=sample,
+                ),
+                model_output,
             )
-        elif (self.counter - 1) % 4 == 0:
-            state = state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output)
-        elif (self.counter - 2) % 4 == 0:
-            state = state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output)
-        elif (self.counter - 3) % 4 == 0:
-            model_output = state.cur_model_output + 1 / 6 * model_output
-            state = state.replace(cur_model_output=0)

-        # cur_sample should not be `None`
-        cur_sample = state.cur_sample if state.cur_sample is not None else sample
+        def remainder_1(state: PNDMSchedulerState, model_output: jnp.ndarray, ets_at: int):
+            return state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output), model_output

+        def remainder_2(state: PNDMSchedulerState, model_output: jnp.ndarray, ets_at: int):
+            return state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output), model_output
+
+        def remainder_3(state: PNDMSchedulerState, model_output: jnp.ndarray, ets_at: int):
+            model_output = state.cur_model_output + 1 / 6 * model_output
+            return state.replace(cur_model_output=jnp.zeros_like(state.cur_model_output)), model_output
+
+        state, model_output = jax.lax.switch(
+            state.counter % 4,
+            (remainder_0, remainder_1, remainder_2, remainder_3),
+            # Args to either branch
+            state,
+            model_output,
+            state.counter // 4,
+        )
+
+        cur_sample = state.cur_sample
         prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output)
         state = state.replace(counter=state.counter + 1)

-        if not return_dict:
-            return (prev_sample, state)
-
-        return FlaxSchedulerOutput(prev_sample=prev_sample, state=state)
+        return (prev_sample, state)

     def step_plms(
         self,
@@ -301,7 +327,6 @@ def step_plms(
         model_output: jnp.ndarray,
         timestep: int,
         sample: jnp.ndarray,
-        return_dict: bool = True,
     ) -> Union[FlaxSchedulerOutput, Tuple]:
         """
         Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple
@@ -334,36 +359,91 @@ def step_plms(
             )

         prev_timestep = timestep - self.config.num_train_timesteps // state.num_inference_steps
+        prev_timestep = jnp.where(prev_timestep > 0, prev_timestep, 0)
+
+        # Reference:
+        # if state.counter != 1:
+        #     state.ets.append(model_output)
+        # else:
+        #     prev_timestep = timestep
+        #     timestep = timestep + self.config.num_train_timesteps // state.num_inference_steps
+
+        prev_timestep = jnp.where(state.counter == 1, timestep, prev_timestep)
+        timestep = jnp.where(
+            state.counter == 1, timestep + self.config.num_train_timesteps // state.num_inference_steps, timestep
+        )

-        if state.counter != 1:
-            state = state.replace(ets=state.ets.append(model_output))
-        else:
-            prev_timestep = timestep
-            timestep = timestep + self.config.num_train_timesteps // state.num_inference_steps
-
-        if len(state.ets) == 1 and state.counter == 0:
-            model_output = model_output
-            state = state.replace(cur_sample=sample)
-        elif len(state.ets) == 1 and state.counter == 1:
-            model_output = (model_output + state.ets[-1]) / 2
-            sample = state.cur_sample
-            state = state.replace(cur_sample=None)
-        elif len(state.ets) == 2:
-            model_output = (3 * state.ets[-1] - state.ets[-2]) / 2
-        elif len(state.ets) == 3:
-            model_output = (23 * state.ets[-1] - 16 * state.ets[-2] + 5 * state.ets[-3]) / 12
-        else:
-            model_output = (1 / 24) * (
-                55 * state.ets[-1] - 59 * state.ets[-2] + 37 * state.ets[-3] - 9 * state.ets[-4]
+        # Reference:
+        # if len(state.ets) == 1 and state.counter == 0:
+        #     model_output = model_output
+        #     state.cur_sample = sample
+        # elif len(state.ets) == 1 and state.counter == 1:
+        #     model_output = (model_output + state.ets[-1]) / 2
+        #     sample = state.cur_sample
+        #     state.cur_sample = None
+        # elif len(state.ets) == 2:
+        #     model_output = (3 * state.ets[-1] - state.ets[-2]) / 2
+        # elif len(state.ets) == 3:
+        #     model_output = (23 * state.ets[-1] - 16 * state.ets[-2] + 5 * state.ets[-3]) / 12
+        # else:
+        #     model_output = (1 / 24) * (55 * state.ets[-1] - 59 * state.ets[-2] + 37 * state.ets[-3] - 9 * state.ets[-4])
+
+        def counter_0(state: PNDMSchedulerState):
+            ets = state.ets.at[0].set(model_output)
+            return state.replace(
+                ets=ets,
+                cur_sample=sample,
+                cur_model_output=jnp.array(model_output, dtype=jnp.float32),
+            )
+
+        def counter_1(state: PNDMSchedulerState):
+            return state.replace(
+                cur_model_output=(model_output + state.ets[0]) / 2,
             )

+        def counter_2(state: PNDMSchedulerState):
+            ets = state.ets.at[1].set(model_output)
+            return state.replace(
+                ets=ets,
+                cur_model_output=(3 * ets[1] - ets[0]) / 2,
+                cur_sample=sample,
+            )
+
+        def counter_3(state: PNDMSchedulerState):
+            ets = state.ets.at[2].set(model_output)
+            return state.replace(
+                ets=ets,
+                cur_model_output=(23 * ets[2] - 16 * ets[1] + 5 * ets[0]) / 12,
+                cur_sample=sample,
+            )
+
+        def counter_other(state: PNDMSchedulerState):
+            ets = state.ets.at[3].set(model_output)
+            next_model_output = (1 / 24) * (55 * ets[3] - 59 * ets[2] + 37 * ets[1] - 9 * ets[0])
+
+            ets = ets.at[0].set(ets[1])
+            ets = ets.at[1].set(ets[2])
+            ets = ets.at[2].set(ets[3])
+
+            return state.replace(
+                ets=ets,
+                cur_model_output=next_model_output,
+                cur_sample=sample,
+            )
+
+        counter = jnp.clip(state.counter, 0, 4)
+        state = jax.lax.switch(
+            counter,
+            [counter_0, counter_1, counter_2, counter_3, counter_other],
+            state,
+        )
+
+        sample = state.cur_sample
+        model_output = state.cur_model_output
         prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output)
         state = state.replace(counter=state.counter + 1)

-        if not return_dict:
-            return (prev_sample, state)
-
-        return FlaxSchedulerOutput(prev_sample=prev_sample, state=state)
+        return (prev_sample, state)

     def _get_prev_sample(self, sample, timestep, prev_timestep, model_output):
         # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
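The same constraint reshapes `ets`: under `jit` it cannot be a Python list that grows each call, so it becomes a `(4,) + shape` array indexed through a clipped counter, with `counter_other` shifting the window so the newest output always lands in the last slot. A standalone sketch of that fixed-capacity history (`push` is a hypothetical helper; `jnp.roll` stands in for the explicit `at[...]` copies above):

```python
import jax
import jax.numpy as jnp

def push(ets, counter, value):
    # The first four outputs fill slots 0..3; afterwards shift left and
    # append, so shapes stay static and the loop remains jittable.
    def fill(ets):
        return ets.at[jnp.clip(counter, 0, 3)].set(value)

    def shift(ets):
        return jnp.roll(ets, -1, axis=0).at[3].set(value)

    return jax.lax.cond(counter < 4, fill, shift, ets)

ets = jnp.zeros((4, 2))
for i in range(6):
    ets = push(ets, i, jnp.full((2,), float(i)))
print(ets[:, 0])  # [2. 3. 4. 5.] -- the four most recent outputs
```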
@@ -379,7 +459,7 @@ def _get_prev_sample(self, sample, timestep, prev_timestep, model_output):
         # model_output -> e_θ(x_t, t)
         # prev_sample -> x_(t−δ)
         alpha_prod_t = self.alphas_cumprod[timestep]
-        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+        alpha_prod_t_prev = jnp.where(prev_timestep >= 0, self.alphas_cumprod[prev_timestep], self.final_alpha_cumprod)
         beta_prod_t = 1 - alpha_prod_t
         beta_prod_t_prev = 1 - alpha_prod_t_prev
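The final hunk applies the same idea at the smallest scale: a value-dependent Python conditional cannot inspect a traced `prev_timestep`, so the lookup selects between the table entry and `final_alpha_cumprod` with `jnp.where`. A standalone sketch (the alpha table here is made up):

```python
import jax
import jax.numpy as jnp

alphas_cumprod = jnp.linspace(0.9999, 0.01, 1000)  # illustrative values
final_alpha_cumprod = alphas_cumprod[0]

@jax.jit
def alpha_prev(prev_timestep):
    # Both operands of `jnp.where` are evaluated, so the gather also runs
    # when prev_timestep < 0; JAX resolves the negative index without
    # raising, and `where` then discards that lane.
    return jnp.where(prev_timestep >= 0, alphas_cumprod[prev_timestep], final_alpha_cumprod)

print(alpha_prev(10))   # alphas_cumprod[10]
print(alpha_prev(-1))   # final_alpha_cumprod
```

With every value-dependent branch expressed as `lax.switch`, `lax.cond`, or `jnp.where`, the sampling loop can run fully under `jit`, and the pipeline's Python `debug` loop remains only as a fallback for inspection.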