Skip to content

Commit

Permalink
[Bug fix] "previous_timestep()" in DDPM scheduling compatible with "t…
Browse files Browse the repository at this point in the history
…railing" and "linspace" options (#9384)

* Update scheduling_ddpm.py

* fix copies

---------

Co-authored-by: YiYi Xu <[email protected]>
Co-authored-by: hlky <[email protected]>
  • Loading branch information
3 people authored Dec 3, 2024
1 parent 619b965 commit 5effcd3
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 24 deletions.
8 changes: 2 additions & 6 deletions src/diffusers/schedulers/scheduling_ddpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,16 +548,12 @@ def __len__(self):
return self.config.num_train_timesteps

def previous_timestep(self, timestep):
if self.custom_timesteps:
if self.custom_timesteps or self.num_inference_steps:
index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
if index == self.timesteps.shape[0] - 1:
prev_t = torch.tensor(-1)
else:
prev_t = self.timesteps[index + 1]
else:
num_inference_steps = (
self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
)
prev_t = timestep - self.config.num_train_timesteps // num_inference_steps

prev_t = timestep - 1
return prev_t
8 changes: 2 additions & 6 deletions src/diffusers/schedulers/scheduling_ddpm_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,16 +639,12 @@ def __len__(self):

# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
def previous_timestep(self, timestep):
if self.custom_timesteps:
if self.custom_timesteps or self.num_inference_steps:
index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
if index == self.timesteps.shape[0] - 1:
prev_t = torch.tensor(-1)
else:
prev_t = self.timesteps[index + 1]
else:
num_inference_steps = (
self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
)
prev_t = timestep - self.config.num_train_timesteps // num_inference_steps

prev_t = timestep - 1
return prev_t
8 changes: 2 additions & 6 deletions src/diffusers/schedulers/scheduling_lcm.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,16 +643,12 @@ def __len__(self):

# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
def previous_timestep(self, timestep):
if self.custom_timesteps:
if self.custom_timesteps or self.num_inference_steps:
index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
if index == self.timesteps.shape[0] - 1:
prev_t = torch.tensor(-1)
else:
prev_t = self.timesteps[index + 1]
else:
num_inference_steps = (
self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
)
prev_t = timestep - self.config.num_train_timesteps // num_inference_steps

prev_t = timestep - 1
return prev_t
8 changes: 2 additions & 6 deletions src/diffusers/schedulers/scheduling_tcd.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,16 +680,12 @@ def __len__(self):

# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
def previous_timestep(self, timestep):
if self.custom_timesteps:
if self.custom_timesteps or self.num_inference_steps:
index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
if index == self.timesteps.shape[0] - 1:
prev_t = torch.tensor(-1)
else:
prev_t = self.timesteps[index + 1]
else:
num_inference_steps = (
self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
)
prev_t = timestep - self.config.num_train_timesteps // num_inference_steps

prev_t = timestep - 1
return prev_t

0 comments on commit 5effcd3

Please sign in to comment.