diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index 468fdf61a9ef..eb40d79b9f60 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -548,16 +548,12 @@ def __len__(self): return self.config.num_train_timesteps def previous_timestep(self, timestep): - if self.custom_timesteps: + if self.custom_timesteps or self.num_inference_steps: index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0] if index == self.timesteps.shape[0] - 1: prev_t = torch.tensor(-1) else: prev_t = self.timesteps[index + 1] else: - num_inference_steps = ( - self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps - ) - prev_t = timestep - self.config.num_train_timesteps // num_inference_steps - + prev_t = timestep - 1 return prev_t diff --git a/src/diffusers/schedulers/scheduling_ddpm_parallel.py b/src/diffusers/schedulers/scheduling_ddpm_parallel.py index f377ee6e8c93..20ad7a4c927d 100644 --- a/src/diffusers/schedulers/scheduling_ddpm_parallel.py +++ b/src/diffusers/schedulers/scheduling_ddpm_parallel.py @@ -639,16 +639,12 @@ def __len__(self): # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep def previous_timestep(self, timestep): - if self.custom_timesteps: + if self.custom_timesteps or self.num_inference_steps: index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0] if index == self.timesteps.shape[0] - 1: prev_t = torch.tensor(-1) else: prev_t = self.timesteps[index + 1] else: - num_inference_steps = ( - self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps - ) - prev_t = timestep - self.config.num_train_timesteps // num_inference_steps - + prev_t = timestep - 1 return prev_t diff --git a/src/diffusers/schedulers/scheduling_lcm.py b/src/diffusers/schedulers/scheduling_lcm.py index f1aa09ab1723..686b686f6870 100644 --- a/src/diffusers/schedulers/scheduling_lcm.py +++ b/src/diffusers/schedulers/scheduling_lcm.py @@ -643,16 +643,12 @@ def __len__(self): # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep def previous_timestep(self, timestep): - if self.custom_timesteps: + if self.custom_timesteps or self.num_inference_steps: index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0] if index == self.timesteps.shape[0] - 1: prev_t = torch.tensor(-1) else: prev_t = self.timesteps[index + 1] else: - num_inference_steps = ( - self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps - ) - prev_t = timestep - self.config.num_train_timesteps // num_inference_steps - + prev_t = timestep - 1 return prev_t diff --git a/src/diffusers/schedulers/scheduling_tcd.py b/src/diffusers/schedulers/scheduling_tcd.py index 580224404c54..5d60383142a4 100644 --- a/src/diffusers/schedulers/scheduling_tcd.py +++ b/src/diffusers/schedulers/scheduling_tcd.py @@ -680,16 +680,12 @@ def __len__(self): # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep def previous_timestep(self, timestep): - if self.custom_timesteps: + if self.custom_timesteps or self.num_inference_steps: index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0] if index == self.timesteps.shape[0] - 1: prev_t = torch.tensor(-1) else: prev_t = self.timesteps[index + 1] else: - num_inference_steps = ( - self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps - ) - prev_t = timestep - self.config.num_train_timesteps // num_inference_steps - + prev_t = timestep - 1 return prev_t