fix a bug in 2nd order schedulers when using in ensemble of experts config #5511

Merged (10 commits) on Oct 25, 2023

@@ -867,8 +867,12 @@ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=N
                     - (denoising_start * self.scheduler.config.num_train_timesteps)
                 )
             )
-            timesteps = list(filter(lambda ts: ts < discrete_timestep_cutoff, timesteps))
-            return torch.tensor(timesteps), len(timesteps)
+
+            num_inference_steps = (timesteps < discrete_timestep_cutoff).sum()
+            if self.scheduler.order == 2:
+                num_inference_steps = num_inference_steps + 1
Review thread (marked as resolved by patrickvonplaten):

I don't think num_inference_steps will always be even at this point. Consider the case where the scheduler has been set up by calling set_timesteps(25) and denoising_start = 0.6 is then supplied. The following code prints an even number:

sched = SCHEDULERS["DPM2"]
sched.set_timesteps(25)
denoising_start = 0.6
discrete_timestep_cutoff = int(
    round(
        sched.config.num_train_timesteps
        - (denoising_start * sched.config.num_train_timesteps)
    )
)
num_inference_steps = (sched.timesteps < discrete_timestep_cutoff).sum().item()
print(num_inference_steps)

This prints 20. For denoising_start=0.5, it prints 25. So the computed num_inference_steps may be either odd or even, depending on the num_inference_steps and denoising_start values passed into the pipeline call.

I think we should add 1 to it only if it's even, and not if it's odd. No?

@yiyixuxu @patrickvonplaten
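
A minimal sketch of what that conditional increment could look like, using a plain Python list as a stand-in for `scheduler.timesteps`. The function name `tail_step_count` and the toy schedule are hypothetical and are not taken from the PR; only the "add 1 when the count is even" rule comes from the comment above.

```python
def tail_step_count(timesteps, cutoff, order):
    """Count timesteps below `cutoff`; for order-2 schedulers, bump an even count by one."""
    count = sum(1 for t in timesteps if t < cutoff)
    # Suggested rule from the review comment: only add 1 when the count is even.
    if order == 2 and count % 2 == 0:
        count += 1
    return count


# Hypothetical descending schedule, not real scheduler output.
toy_timesteps = [900, 800, 700, 600, 500, 400, 300, 200, 100]

even_case = tail_step_count(toy_timesteps, cutoff=450, order=2)    # 4 below cutoff, bumped to 5
odd_case = tail_step_count(toy_timesteps, cutoff=550, order=2)     # 5 below cutoff, left at 5
first_order = tail_step_count(toy_timesteps, cutoff=450, order=1)  # 4, never bumped
```

With an unconditional `+ 1`, the second case would become 6 instead of staying at 5, which is the situation the comment is worried about.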

Reply (Contributor):

Ah, DPM2 can indeed be different if it interpolates (Heun can't).

Nice catch! I'll open a PR to fix it.

Reply:

This is the result I get with DPM2, hnf = 0.5 and steps = 25, using the code from this PR:
[image]

Reply:

Thanks @patrickvonplaten, all test cases I threw at it now produce good images.

I am just curious why we don't need any such logic in the base step (where denoising_end is supplied to the base SDXL pipeline without img2img). I would expect that we'd want to ensure that the final timestep there is a second-order timestep, and as far as I can tell it still has the odd/even issue when splitting. Maybe I am wrong!

In any case, latest main now produces good images for all my testcases like I said. Thank you! 🙏

+            timesteps = timesteps[-num_inference_steps:]
+            return timesteps, num_inference_steps

         return timesteps, num_inference_steps - t_start
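
The change above replaces an element-wise filter with a count followed by a slice from the end. The sketch below illustrates, under stated assumptions, why the two select the same timesteps for a descending schedule, and what the extra `+ 1` pulls in. The helper name `discrete_cutoff`, the default of 1000 training timesteps, and the toy schedule are assumptions made for illustration; plain lists stand in for the tensors used in the pipeline.

```python
def discrete_cutoff(denoising_start, num_train_timesteps=1000):
    """Map a fractional denoising_start to a discrete timestep cutoff (same formula as the diff)."""
    return int(round(num_train_timesteps - denoising_start * num_train_timesteps))


# Hypothetical descending schedule, not real scheduler output.
toy_timesteps = [900, 800, 700, 600, 500, 400, 300, 200, 100]

cutoff = discrete_cutoff(0.6)  # 1000 - 0.6 * 1000
count = sum(1 for t in toy_timesteps if t < cutoff)

# Old behaviour: keep every timestep strictly below the cutoff.
filtered = [t for t in toy_timesteps if t < cutoff]

# New behaviour: because the schedule is descending, the last `count`
# entries are exactly the ones below the cutoff.
sliced = toy_timesteps[-count:]

# Adding 1 to the count extends the slice by one earlier timestep,
# which is at or above the cutoff.
padded = toy_timesteps[-(count + 1):]
```

The slice only matches the filter because the timesteps are sorted in descending order; the sketch does not cover schedules that violate that assumption.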

@@ -515,8 +515,12 @@ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=N
                     - (denoising_start * self.scheduler.config.num_train_timesteps)
                 )
             )
-            timesteps = list(filter(lambda ts: ts < discrete_timestep_cutoff, timesteps))
-            return torch.tensor(timesteps), len(timesteps)
+
+            num_inference_steps = (timesteps < discrete_timestep_cutoff).sum()
+            if self.scheduler.order == 2:
+                num_inference_steps = num_inference_steps + 1
+            timesteps = timesteps[-num_inference_steps:]
+            return timesteps, num_inference_steps

         return timesteps, num_inference_steps - t_start

@@ -799,8 +799,12 @@ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=N
                     - (denoising_start * self.scheduler.config.num_train_timesteps)
                 )
             )
-            timesteps = list(filter(lambda ts: ts < discrete_timestep_cutoff, timesteps))
-            return torch.tensor(timesteps), len(timesteps)
+
+            num_inference_steps = (timesteps < discrete_timestep_cutoff).sum()
+            if self.scheduler.order == 2:
+                num_inference_steps = num_inference_steps + 1
+            timesteps = timesteps[-num_inference_steps:]
+            return timesteps, num_inference_steps

         return timesteps, num_inference_steps - t_start

26 changes: 18 additions & 8 deletions tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
@@ -324,8 +324,13 @@ class scheduler_cls(scheduler_cls_orig):
             pipe_1.scheduler.set_timesteps(num_steps)
             expected_steps = pipe_1.scheduler.timesteps.tolist()

-            expected_steps_1 = list(filter(lambda ts: ts >= split, expected_tss))
-            expected_steps_2 = list(filter(lambda ts: ts < split, expected_tss))
+            if pipe_1.scheduler.order == 2:
+                expected_steps_1 = list(filter(lambda ts: ts >= split, expected_tss))
+                expected_steps_2 = expected_steps_1[-1:] + list(filter(lambda ts: ts < split, expected_tss))
+                expected_steps = expected_steps_1 + expected_steps_2
+            else:
+                expected_steps_1 = list(filter(lambda ts: ts >= split, expected_tss))
+                expected_steps_2 = list(filter(lambda ts: ts < split, expected_tss))

             # now we monkey patch step `done_steps`
             # list into the step function for testing
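
To make the updated expectation concrete, here is a small sketch of the two-stage split the test encodes: for an order-2 scheduler, the second stage is expected to begin by repeating the last timestep of the first stage. The function name `expected_two_stage_steps` and the toy values are hypothetical; the list logic mirrors the filters in the hunk above.

```python
def expected_two_stage_steps(steps, split_ts, order):
    """Return the timesteps each of two pipeline stages is expected to run."""
    first = [t for t in steps if t >= split_ts]
    second = [t for t in steps if t < split_ts]
    if order == 2:
        # The second stage starts from the last timestep of the first stage.
        second = first[-1:] + second
    return first, second


# Hypothetical descending schedule and split point.
toy_steps = [900, 700, 500, 300, 100]

first_o2, second_o2 = expected_two_stage_steps(toy_steps, split_ts=400, order=2)
first_o1, second_o1 = expected_two_stage_steps(toy_steps, split_ts=400, order=1)
```

For order 2 the combined list contains one timestep twice (500 in this toy case), which is why the test also rebuilds `expected_steps` as the concatenation of the two stages.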
@@ -607,13 +612,18 @@ class scheduler_cls(scheduler_cls_orig):

             split_1_ts = num_train_timesteps - int(round(num_train_timesteps * split_1))
             split_2_ts = num_train_timesteps - int(round(num_train_timesteps * split_2))
-            expected_steps_1 = expected_steps[:split_1_ts]
-            expected_steps_2 = expected_steps[split_1_ts:split_2_ts]
-            expected_steps_3 = expected_steps[split_2_ts:]

-            expected_steps_1 = list(filter(lambda ts: ts >= split_1_ts, expected_steps))
-            expected_steps_2 = list(filter(lambda ts: ts >= split_2_ts and ts < split_1_ts, expected_steps))
-            expected_steps_3 = list(filter(lambda ts: ts < split_2_ts, expected_steps))
+            if pipe_1.scheduler.order == 2:
+                expected_steps_1 = list(filter(lambda ts: ts >= split_1_ts, expected_steps))
+                expected_steps_2 = expected_steps_1[-1:] + list(
+                    filter(lambda ts: ts >= split_2_ts and ts < split_1_ts, expected_steps)
+                )
+                expected_steps_3 = expected_steps_2[-1:] + list(filter(lambda ts: ts < split_2_ts, expected_steps))
+                expected_steps = expected_steps_1 + expected_steps_2 + expected_steps_3
+            else:
+                expected_steps_1 = list(filter(lambda ts: ts >= split_1_ts, expected_steps))
+                expected_steps_2 = list(filter(lambda ts: ts >= split_2_ts and ts < split_1_ts, expected_steps))
+                expected_steps_3 = list(filter(lambda ts: ts < split_2_ts, expected_steps))

             # now we monkey patch step `done_steps`
             # list into the step function for testing
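
The three-stage expectation follows the same pattern, and one way to see it is as a loop over descending cutoffs. The sketch below is a generalisation written for illustration only: `expected_stage_steps` and the toy values are hypothetical, and the test itself spells out the three stages explicitly rather than looping.

```python
def expected_stage_steps(steps, cutoffs, order):
    """Split `steps` at descending `cutoffs`; order-2 stages repeat the previous stage's last step."""
    bounds = [float("inf")] + list(cutoffs) + [float("-inf")]
    stages = []
    for upper, lower in zip(bounds, bounds[1:]):
        stage = [t for t in steps if lower <= t < upper]
        if order == 2 and stages:
            # Prepend the last timestep of the previous stage, as in the test.
            stage = stages[-1][-1:] + stage
        stages.append(stage)
    return stages


# Hypothetical descending schedule with two split points.
toy_steps = [900, 700, 500, 300, 100]

stages_o2 = expected_stage_steps(toy_steps, cutoffs=(600, 200), order=2)
stages_o1 = expected_stage_steps(toy_steps, cutoffs=(600, 200), order=1)
```

As in the test, the repeated timestep for stage three is taken from stage two after stage two has already had its own repeated timestep prepended.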
@@ -318,11 +318,14 @@ class scheduler_cls(scheduler_cls_orig):
             expected_steps = pipe_1.scheduler.timesteps.tolist()

             split_ts = num_train_timesteps - int(round(num_train_timesteps * split))
-            expected_steps_1 = expected_steps[:split_ts]
-            expected_steps_2 = expected_steps[split_ts:]

-            expected_steps_1 = list(filter(lambda ts: ts >= split_ts, expected_steps))
-            expected_steps_2 = list(filter(lambda ts: ts < split_ts, expected_steps))
+            if pipe_1.scheduler.order == 2:
+                expected_steps_1 = list(filter(lambda ts: ts >= split_ts, expected_steps))
+                expected_steps_2 = expected_steps_1[-1:] + list(filter(lambda ts: ts < split_ts, expected_steps))
+                expected_steps = expected_steps_1 + expected_steps_2
+            else:
+                expected_steps_1 = list(filter(lambda ts: ts >= split_ts, expected_steps))
+                expected_steps_2 = list(filter(lambda ts: ts < split_ts, expected_steps))

             # now we monkey patch step `done_steps`
             # list into the step function for testing
@@ -389,13 +392,18 @@ class scheduler_cls(scheduler_cls_orig):

             split_1_ts = num_train_timesteps - int(round(num_train_timesteps * split_1))
             split_2_ts = num_train_timesteps - int(round(num_train_timesteps * split_2))
-            expected_steps_1 = expected_steps[:split_1_ts]
-            expected_steps_2 = expected_steps[split_1_ts:split_2_ts]
-            expected_steps_3 = expected_steps[split_2_ts:]

-            expected_steps_1 = list(filter(lambda ts: ts >= split_1_ts, expected_steps))
-            expected_steps_2 = list(filter(lambda ts: ts >= split_2_ts and ts < split_1_ts, expected_steps))
-            expected_steps_3 = list(filter(lambda ts: ts < split_2_ts, expected_steps))
+            if pipe_1.scheduler.order == 2:
+                expected_steps_1 = list(filter(lambda ts: ts >= split_1_ts, expected_steps))
+                expected_steps_2 = expected_steps_1[-1:] + list(
+                    filter(lambda ts: ts >= split_2_ts and ts < split_1_ts, expected_steps)
+                )
+                expected_steps_3 = expected_steps_2[-1:] + list(filter(lambda ts: ts < split_2_ts, expected_steps))
+                expected_steps = expected_steps_1 + expected_steps_2 + expected_steps_3
+            else:
+                expected_steps_1 = list(filter(lambda ts: ts >= split_1_ts, expected_steps))
+                expected_steps_2 = list(filter(lambda ts: ts >= split_2_ts and ts < split_1_ts, expected_steps))
+                expected_steps_3 = list(filter(lambda ts: ts < split_2_ts, expected_steps))

             # now we monkey patch step `done_steps`
             # list into the step function for testing