diff --git a/schedulers/scheduling_ddim_flax.py b/schedulers/scheduling_ddim_flax.py index d81d66607147..a57f61e07097 100644 --- a/schedulers/scheduling_ddim_flax.py +++ b/schedulers/scheduling_ddim_flax.py @@ -156,7 +156,7 @@ def _get_variance(self, timestep, prev_timestep, alphas_cumprod): return variance - def set_timesteps(self, state: DDIMSchedulerState, num_inference_steps: int) -> DDIMSchedulerState: + def set_timesteps(self, state: DDIMSchedulerState, num_inference_steps: int, shape: Tuple) -> DDIMSchedulerState: """ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. diff --git a/schedulers/scheduling_ddpm_flax.py b/schedulers/scheduling_ddpm_flax.py index 7c7b8d29ab52..8f631403ff7d 100644 --- a/schedulers/scheduling_ddpm_flax.py +++ b/schedulers/scheduling_ddpm_flax.py @@ -133,7 +133,7 @@ def __init__( self.variance_type = variance_type - def set_timesteps(self, state: DDPMSchedulerState, num_inference_steps: int) -> DDPMSchedulerState: + def set_timesteps(self, state: DDPMSchedulerState, num_inference_steps: int, shape: Tuple) -> DDPMSchedulerState: """ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. diff --git a/schedulers/scheduling_karras_ve_flax.py b/schedulers/scheduling_karras_ve_flax.py index c320b79e6dcd..72ff69da0352 100644 --- a/schedulers/scheduling_karras_ve_flax.py +++ b/schedulers/scheduling_karras_ve_flax.py @@ -99,7 +99,9 @@ def __init__( ): self.state = KarrasVeSchedulerState.create() - def set_timesteps(self, state: KarrasVeSchedulerState, num_inference_steps: int) -> KarrasVeSchedulerState: + def set_timesteps( + self, state: KarrasVeSchedulerState, num_inference_steps: int, shape: Tuple + ) -> KarrasVeSchedulerState: """ Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference. diff --git a/schedulers/scheduling_lms_discrete_flax.py b/schedulers/scheduling_lms_discrete_flax.py index 4784e4fafccb..51b3dfc72f99 100644 --- a/schedulers/scheduling_lms_discrete_flax.py +++ b/schedulers/scheduling_lms_discrete_flax.py @@ -111,7 +111,9 @@ def lms_derivative(tau): return integrated_coeff - def set_timesteps(self, state: LMSDiscreteSchedulerState, num_inference_steps: int) -> LMSDiscreteSchedulerState: + def set_timesteps( + self, state: LMSDiscreteSchedulerState, num_inference_steps: int, shape: Tuple + ) -> LMSDiscreteSchedulerState: """ Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. diff --git a/schedulers/scheduling_pndm_flax.py b/schedulers/scheduling_pndm_flax.py index 9e2b19f01301..871a406f68f3 100644 --- a/schedulers/scheduling_pndm_flax.py +++ b/schedulers/scheduling_pndm_flax.py @@ -156,7 +156,7 @@ def __init__( def create_state(self): return PNDMSchedulerState.create(num_train_timesteps=self.config.num_train_timesteps) - def set_timesteps(self, state: PNDMSchedulerState, shape: Tuple, num_inference_steps: int) -> PNDMSchedulerState: + def set_timesteps(self, state: PNDMSchedulerState, num_inference_steps: int, shape: Tuple) -> PNDMSchedulerState: """ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. diff --git a/schedulers/scheduling_sde_ve_flax.py b/schedulers/scheduling_sde_ve_flax.py index 08fbe14732da..3c8834d7b178 100644 --- a/schedulers/scheduling_sde_ve_flax.py +++ b/schedulers/scheduling_sde_ve_flax.py @@ -95,7 +95,7 @@ def __init__( self.state = self.set_sigmas(state, num_train_timesteps, sigma_min, sigma_max, sampling_eps) def set_timesteps( - self, state: ScoreSdeVeSchedulerState, num_inference_steps: int, sampling_eps: float = None + self, state: ScoreSdeVeSchedulerState, num_inference_steps: int, shape: Tuple, sampling_eps: float = None ) -> ScoreSdeVeSchedulerState: """ Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.