
Commit

fix a bug in 2nd order schedulers when using in ensemble of experts config (huggingface#5511)

* fix

* fix copies

* remove heun from tests

* add back heun and fix the tests to include 2nd order

* fix the other test too

* Apply suggestions from code review

* Apply suggestions from code review

* Apply suggestions from code review

* make style

* add more comments

---------

Co-authored-by: yiyixuxu <yixu310@gmail.com>
Co-authored-by: Patrick von Platen <[email protected]>
3 people authored and kashif committed Nov 11, 2023
1 parent d6ee760 commit 52f964d
Showing 5 changed files with 78 additions and 24 deletions.
@@ -896,8 +896,20 @@ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None):
                     - (denoising_start * self.scheduler.config.num_train_timesteps)
                 )
             )
-            timesteps = list(filter(lambda ts: ts < discrete_timestep_cutoff, timesteps))
-            return torch.tensor(timesteps), len(timesteps)
+
+            num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
+            if self.scheduler.order == 2:
+                # if the scheduler is a 2nd order scheduler we ALWAYS have to do +1
+                # because `num_inference_steps` will always be even given that every timestep
+                # (except the highest one) is duplicated. If `num_inference_steps` is even it would
+                # mean that we cut the timesteps in the middle of a denoising step
+                # (between the 1st and 2nd derivative), which leads to incorrect results. By adding 1
+                # we ensure that the denoising process always ends after the 2nd derivative step of the scheduler.
+                num_inference_steps = num_inference_steps + 1
+
+            # the timesteps are in descending order (t_n >= t_n+1), so the timesteps below the cutoff sit at the end and we slice from the end
+            timesteps = timesteps[-num_inference_steps:]
+            return timesteps, num_inference_steps
 
         return timesteps, num_inference_steps - t_start

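To see what the new branch guards against, here is a toy re-enactment of the slicing logic above (a minimal sketch; the timestep values and cutoff are invented for illustration, not taken from a real scheduler):

import torch

# Toy timesteps of a 2nd order (Heun-style) scheduler: descending order,
# every timestep except the highest one duplicated.
timesteps = torch.tensor([999, 749, 749, 499, 499, 249, 249])

discrete_timestep_cutoff = 600  # in the pipeline this comes from `denoising_start`

num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
print(num_inference_steps)               # 4 -- even, as the comment predicts

print(timesteps[-num_inference_steps:])  # tensor([499, 499, 249, 249])
# Starting here, the scheduler pairs (499, 499) as one step: the cut fell
# between the 1st and 2nd derivative evaluations.

num_inference_steps += 1                 # the +1 from the fix
print(timesteps[-num_inference_steps:])  # tensor([749, 499, 499, 249, 249])
# Now the run begins at the 1st derivative evaluation of the 749 -> 499 step
# and ends after a completed 2nd derivative evaluation.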
@@ -553,8 +553,20 @@ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None):
                     - (denoising_start * self.scheduler.config.num_train_timesteps)
                 )
             )
-            timesteps = list(filter(lambda ts: ts < discrete_timestep_cutoff, timesteps))
-            return torch.tensor(timesteps), len(timesteps)
+
+            num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
+            if self.scheduler.order == 2:
+                # if the scheduler is a 2nd order scheduler we ALWAYS have to do +1
+                # because `num_inference_steps` will always be even given that every timestep
+                # (except the highest one) is duplicated. If `num_inference_steps` is even it would
+                # mean that we cut the timesteps in the middle of a denoising step
+                # (between the 1st and 2nd derivative), which leads to incorrect results. By adding 1
+                # we ensure that the denoising process always ends after the 2nd derivative step of the scheduler.
+                num_inference_steps = num_inference_steps + 1
+
+            # the timesteps are in descending order (t_n >= t_n+1), so the timesteps below the cutoff sit at the end and we slice from the end
+            timesteps = timesteps[-num_inference_steps:]
+            return timesteps, num_inference_steps
 
         return timesteps, num_inference_steps - t_start

@@ -838,8 +838,20 @@ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None):
                     - (denoising_start * self.scheduler.config.num_train_timesteps)
                 )
             )
-            timesteps = list(filter(lambda ts: ts < discrete_timestep_cutoff, timesteps))
-            return torch.tensor(timesteps), len(timesteps)
+
+            num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
+            if self.scheduler.order == 2:
+                # if the scheduler is a 2nd order scheduler we ALWAYS have to do +1
+                # because `num_inference_steps` will always be even given that every timestep
+                # (except the highest one) is duplicated. If `num_inference_steps` is even it would
+                # mean that we cut the timesteps in the middle of a denoising step
+                # (between the 1st and 2nd derivative), which leads to incorrect results. By adding 1
+                # we ensure that the denoising process always ends after the 2nd derivative step of the scheduler.
+                num_inference_steps = num_inference_steps + 1
+
+            # the timesteps are in descending order (t_n >= t_n+1), so the timesteps below the cutoff sit at the end and we slice from the end
+            timesteps = timesteps[-num_inference_steps:]
+            return timesteps, num_inference_steps
 
         return timesteps, num_inference_steps - t_start

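The branch above keys off `self.scheduler.order`, which diffusers schedulers expose as a class attribute: 1 for single-evaluation schedulers, 2 for schedulers such as Heun that evaluate the model twice per step and therefore duplicate their timesteps. A quick check (assuming a recent diffusers install):

from diffusers import DDIMScheduler, EulerDiscreteScheduler, HeunDiscreteScheduler

for cls in (DDIMScheduler, EulerDiscreteScheduler, HeunDiscreteScheduler):
    print(cls.__name__, cls.order)
# DDIMScheduler 1
# EulerDiscreteScheduler 1
# HeunDiscreteScheduler 2

# The duplication that makes `num_inference_steps` even below a cutoff:
scheduler = HeunDiscreteScheduler()
scheduler.set_timesteps(num_inference_steps=4)
print(scheduler.timesteps)  # 7 entries, descending; each interior value appears twice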
26 changes: 18 additions & 8 deletions tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
@@ -328,8 +328,13 @@ class scheduler_cls(scheduler_cls_orig):
         pipe_1.scheduler.set_timesteps(num_steps)
         expected_steps = pipe_1.scheduler.timesteps.tolist()
 
-        expected_steps_1 = list(filter(lambda ts: ts >= split, expected_tss))
-        expected_steps_2 = list(filter(lambda ts: ts < split, expected_tss))
+        if pipe_1.scheduler.order == 2:
+            expected_steps_1 = list(filter(lambda ts: ts >= split, expected_tss))
+            expected_steps_2 = expected_steps_1[-1:] + list(filter(lambda ts: ts < split, expected_tss))
+            expected_steps = expected_steps_1 + expected_steps_2
+        else:
+            expected_steps_1 = list(filter(lambda ts: ts >= split, expected_tss))
+            expected_steps_2 = list(filter(lambda ts: ts < split, expected_tss))
 
         # now we monkey patch step `done_steps`
         # list into the step function for testing
@@ -611,13 +616,18 @@ class scheduler_cls(scheduler_cls_orig):
 
         split_1_ts = num_train_timesteps - int(round(num_train_timesteps * split_1))
         split_2_ts = num_train_timesteps - int(round(num_train_timesteps * split_2))
-        expected_steps_1 = expected_steps[:split_1_ts]
-        expected_steps_2 = expected_steps[split_1_ts:split_2_ts]
-        expected_steps_3 = expected_steps[split_2_ts:]
-
-        expected_steps_1 = list(filter(lambda ts: ts >= split_1_ts, expected_steps))
-        expected_steps_2 = list(filter(lambda ts: ts >= split_2_ts and ts < split_1_ts, expected_steps))
-        expected_steps_3 = list(filter(lambda ts: ts < split_2_ts, expected_steps))
+
+        if pipe_1.scheduler.order == 2:
+            expected_steps_1 = list(filter(lambda ts: ts >= split_1_ts, expected_steps))
+            expected_steps_2 = expected_steps_1[-1:] + list(
+                filter(lambda ts: ts >= split_2_ts and ts < split_1_ts, expected_steps)
+            )
+            expected_steps_3 = expected_steps_2[-1:] + list(filter(lambda ts: ts < split_2_ts, expected_steps))
+            expected_steps = expected_steps_1 + expected_steps_2 + expected_steps_3
+        else:
+            expected_steps_1 = list(filter(lambda ts: ts >= split_1_ts, expected_steps))
+            expected_steps_2 = list(filter(lambda ts: ts >= split_2_ts and ts < split_1_ts, expected_steps))
+            expected_steps_3 = list(filter(lambda ts: ts < split_2_ts, expected_steps))
 
         # now we monkey patch step `done_steps`
         # list into the step function for testing
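The updated expectations mirror the pipeline change: with a 2nd order scheduler, the expert that resumes at the split must re-visit the boundary timestep, which is why the tests prepend `expected_steps_1[-1:]` to the next expert's steps. A toy version of the two-way split (invented values, consistent with the sketch after the first diff):

expected_steps = [999, 749, 749, 499, 499, 249, 249]
split = 600

expected_steps_1 = [ts for ts in expected_steps if ts >= split]
# [999, 749, 749] -- handled by the first expert

expected_steps_2 = expected_steps_1[-1:] + [ts for ts in expected_steps if ts < split]
# [749, 499, 499, 249, 249] -- the second expert repeats the boundary 749,
# exactly the slice that the patched get_timesteps() returns thanks to the +1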
@@ -318,11 +318,14 @@ class scheduler_cls(scheduler_cls_orig):
         expected_steps = pipe_1.scheduler.timesteps.tolist()
 
         split_ts = num_train_timesteps - int(round(num_train_timesteps * split))
-        expected_steps_1 = expected_steps[:split_ts]
-        expected_steps_2 = expected_steps[split_ts:]
-
-        expected_steps_1 = list(filter(lambda ts: ts >= split_ts, expected_steps))
-        expected_steps_2 = list(filter(lambda ts: ts < split_ts, expected_steps))
+
+        if pipe_1.scheduler.order == 2:
+            expected_steps_1 = list(filter(lambda ts: ts >= split_ts, expected_steps))
+            expected_steps_2 = expected_steps_1[-1:] + list(filter(lambda ts: ts < split_ts, expected_steps))
+            expected_steps = expected_steps_1 + expected_steps_2
+        else:
+            expected_steps_1 = list(filter(lambda ts: ts >= split_ts, expected_steps))
+            expected_steps_2 = list(filter(lambda ts: ts < split_ts, expected_steps))
 
         # now we monkey patch step `done_steps`
         # list into the step function for testing
@@ -389,13 +392,18 @@ class scheduler_cls(scheduler_cls_orig):
 
         split_1_ts = num_train_timesteps - int(round(num_train_timesteps * split_1))
         split_2_ts = num_train_timesteps - int(round(num_train_timesteps * split_2))
-        expected_steps_1 = expected_steps[:split_1_ts]
-        expected_steps_2 = expected_steps[split_1_ts:split_2_ts]
-        expected_steps_3 = expected_steps[split_2_ts:]
-
-        expected_steps_1 = list(filter(lambda ts: ts >= split_1_ts, expected_steps))
-        expected_steps_2 = list(filter(lambda ts: ts >= split_2_ts and ts < split_1_ts, expected_steps))
-        expected_steps_3 = list(filter(lambda ts: ts < split_2_ts, expected_steps))
+
+        if pipe_1.scheduler.order == 2:
+            expected_steps_1 = list(filter(lambda ts: ts >= split_1_ts, expected_steps))
+            expected_steps_2 = expected_steps_1[-1:] + list(
+                filter(lambda ts: ts >= split_2_ts and ts < split_1_ts, expected_steps)
+            )
+            expected_steps_3 = expected_steps_2[-1:] + list(filter(lambda ts: ts < split_2_ts, expected_steps))
+            expected_steps = expected_steps_1 + expected_steps_2 + expected_steps_3
+        else:
+            expected_steps_1 = list(filter(lambda ts: ts >= split_1_ts, expected_steps))
+            expected_steps_2 = list(filter(lambda ts: ts >= split_2_ts and ts < split_1_ts, expected_steps))
+            expected_steps_3 = list(filter(lambda ts: ts < split_2_ts, expected_steps))
 
         # now we monkey patch step `done_steps`
         # list into the step function for testing
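For context, the code path being fixed is exercised by the SDXL ensemble-of-experts flow, roughly as follows (a usage sketch: model IDs, the 0.8 split, and the prompt are illustrative, not part of this commit):

import torch
from diffusers import DiffusionPipeline, HeunDiscreteScheduler

base = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
base.scheduler = HeunDiscreteScheduler.from_config(base.scheduler.config)  # order == 2

refiner = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0",
    text_encoder_2=base.text_encoder_2,
    vae=base.vae,
    torch_dtype=torch.float16,
).to("cuda")
refiner.scheduler = HeunDiscreteScheduler.from_config(refiner.scheduler.config)

prompt = "an astronaut riding a horse on the moon"

# The base expert denoises the first 80% of the schedule and hands off latents ...
latents = base(prompt=prompt, denoising_end=0.8, output_type="latent").images

# ... and the refiner resumes at denoising_start=0.8, where the patched
# get_timesteps() now slices the duplicated 2nd order timesteps correctly.
image = refiner(prompt=prompt, denoising_start=0.8, image=latents).images[0]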
