Skip to content

Commit

Permalink
Fix a bug in add_noise function (huggingface#6085)
Browse files Browse the repository at this point in the history
* fix

* copies

---------

Co-authored-by: yiyixuxu <yixu310@gmail,com>
  • Loading branch information
2 people authored and Jimmy committed Apr 26, 2024
1 parent 8c558ac commit 36a587a
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 5 deletions.
11 changes: 10 additions & 1 deletion src/diffusers/schedulers/scheduling_deis_multistep.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,7 +734,16 @@ def add_noise(
schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)

step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
step_indices = []
for timestep in timesteps:
index_candidates = (schedule_timesteps == timestep).nonzero()
if len(index_candidates) == 0:
step_index = len(schedule_timesteps) - 1
elif len(index_candidates) > 1:
step_index = index_candidates[1].item()
else:
step_index = index_candidates[0].item()
step_indices.append(step_index)

sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape):
Expand Down
11 changes: 10 additions & 1 deletion src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
Original file line number Diff line number Diff line change
Expand Up @@ -896,7 +896,16 @@ def add_noise(
schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)

step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
step_indices = []
for timestep in timesteps:
index_candidates = (schedule_timesteps == timestep).nonzero()
if len(index_candidates) == 0:
step_index = len(schedule_timesteps) - 1
elif len(index_candidates) > 1:
step_index = index_candidates[1].item()
else:
step_index = index_candidates[0].item()
step_indices.append(step_index)

sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -891,7 +891,16 @@ def add_noise(
schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)

step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
step_indices = []
for timestep in timesteps:
index_candidates = (schedule_timesteps == timestep).nonzero()
if len(index_candidates) == 0:
step_index = len(schedule_timesteps) - 1
elif len(index_candidates) > 1:
step_index = index_candidates[1].item()
else:
step_index = index_candidates[0].item()
step_indices.append(step_index)

sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape):
Expand Down
11 changes: 10 additions & 1 deletion src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
Original file line number Diff line number Diff line change
Expand Up @@ -897,7 +897,16 @@ def add_noise(
schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)

step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
step_indices = []
for timestep in timesteps:
index_candidates = (schedule_timesteps == timestep).nonzero()
if len(index_candidates) == 0:
step_index = len(schedule_timesteps) - 1
elif len(index_candidates) > 1:
step_index = index_candidates[1].item()
else:
step_index = index_candidates[0].item()
step_indices.append(step_index)

sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape):
Expand Down
11 changes: 10 additions & 1 deletion src/diffusers/schedulers/scheduling_unipc_multistep.py
Original file line number Diff line number Diff line change
Expand Up @@ -828,7 +828,16 @@ def add_noise(
schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)

step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
step_indices = []
for timestep in timesteps:
index_candidates = (schedule_timesteps == timestep).nonzero()
if len(index_candidates) == 0:
step_index = len(schedule_timesteps) - 1
elif len(index_candidates) > 1:
step_index = index_candidates[1].item()
else:
step_index = index_candidates[0].item()
step_indices.append(step_index)

sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape):
Expand Down

0 comments on commit 36a587a

Please sign in to comment.