diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
index 6c5d9a3993d40..1423920191d7d 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
@@ -896,8 +896,20 @@ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=N
                     - (denoising_start * self.scheduler.config.num_train_timesteps)
                 )
             )
-            timesteps = list(filter(lambda ts: ts < discrete_timestep_cutoff, timesteps))
-            return torch.tensor(timesteps), len(timesteps)
+
+            num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
+            if self.scheduler.order == 2:
+                # if the scheduler is a 2nd order scheduler we ALWAYS have to do +1
+                # because `num_inference_steps` will always be even given that every timestep
+                # (except the highest one) is duplicated. If `num_inference_steps` is even it would
+                # mean that we cut the timesteps in the middle of the denoising step
+                # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1
+                # we ensure that the denoising process always ends after the 2nd derivative step of the scheduler
+                num_inference_steps = num_inference_steps + 1
+
+            # because t_n+1 >= t_n, we slice the timesteps starting from the end
+            timesteps = timesteps[-num_inference_steps:]
+            return timesteps, num_inference_steps
 
         return timesteps, num_inference_steps - t_start
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
index 825c74ce07074..ff9d669a802bf 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
@@ -553,8 +553,20 @@ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=N
                     - (denoising_start * self.scheduler.config.num_train_timesteps)
                 )
             )
-            timesteps = list(filter(lambda ts: ts < discrete_timestep_cutoff, timesteps))
-            return torch.tensor(timesteps), len(timesteps)
+
+            num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
+            if self.scheduler.order == 2:
+                # if the scheduler is a 2nd order scheduler we ALWAYS have to do +1
+                # because `num_inference_steps` will always be even given that every timestep
+                # (except the highest one) is duplicated. If `num_inference_steps` is even it would
+                # mean that we cut the timesteps in the middle of the denoising step
+                # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1
+                # we ensure that the denoising process always ends after the 2nd derivative step of the scheduler
+                num_inference_steps = num_inference_steps + 1
+
+            # because t_n+1 >= t_n, we slice the timesteps starting from the end
+            timesteps = timesteps[-num_inference_steps:]
+            return timesteps, num_inference_steps
 
         return timesteps, num_inference_steps - t_start
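As a reviewer aid, here is a minimal standalone sketch of the new cutoff logic. It is not code from the patch: the timestep values and the `denoising_start` value are made up, and `num_inference_steps += 1` is applied unconditionally just as in the hunks above.

```python
import torch

# descending schedule as produced by a 2nd order scheduler: every timestep
# except the highest appears twice (values are made up for illustration)
timesteps = torch.tensor([999, 799, 799, 599, 599, 399, 399, 199, 199])
num_train_timesteps = 1000
denoising_start = 0.5  # the second pipeline takes over halfway through

discrete_timestep_cutoff = int(
    round(num_train_timesteps - denoising_start * num_train_timesteps)
)  # 500

num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()  # 4, even
# extend the slice by one so it starts at the boundary timestep instead of
# cutting between the two halves of a duplicated timestep
num_inference_steps += 1  # 5

print(timesteps[-num_inference_steps:])  # tensor([599, 399, 399, 199, 199])
```

Note that the slice now begins with 599, the last timestep of the first pipeline's range, which is exactly the overlap the updated tests below expect.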
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
index 535cc7268305b..200f5a7bf45a0 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
@@ -838,8 +838,20 @@ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=N
                     - (denoising_start * self.scheduler.config.num_train_timesteps)
                 )
             )
-            timesteps = list(filter(lambda ts: ts < discrete_timestep_cutoff, timesteps))
-            return torch.tensor(timesteps), len(timesteps)
+
+            num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
+            if self.scheduler.order == 2:
+                # if the scheduler is a 2nd order scheduler we ALWAYS have to do +1
+                # because `num_inference_steps` will always be even given that every timestep
+                # (except the highest one) is duplicated. If `num_inference_steps` is even it would
+                # mean that we cut the timesteps in the middle of the denoising step
+                # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1
+                # we ensure that the denoising process always ends after the 2nd derivative step of the scheduler
+                num_inference_steps = num_inference_steps + 1
+
+            # because t_n+1 >= t_n, we slice the timesteps starting from the end
+            timesteps = timesteps[-num_inference_steps:]
+            return timesteps, num_inference_steps
 
         return timesteps, num_inference_steps - t_start
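The evenness the comment relies on can be checked directly. This sketch assumes `HeunDiscreteScheduler` as the 2nd order scheduler; the lengths in the comments are what its duplication scheme should produce, shown here for illustration.

```python
from diffusers import HeunDiscreteScheduler

scheduler = HeunDiscreteScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(num_inference_steps=5)

print(scheduler.order)           # 2
print(len(scheduler.timesteps))  # 9, i.e. 2 * 5 - 1: only the highest timestep is unique
```

Because only the highest timestep is unique, any cutoff strictly below it catches whole duplicated pairs, so the raw count is always even and the `+1` lands on a pair boundary.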
diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
index 4906670890e82..f628ad741d840 100644
--- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
+++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
@@ -328,8 +328,13 @@ class scheduler_cls(scheduler_cls_orig):
             pipe_1.scheduler.set_timesteps(num_steps)
             expected_steps = pipe_1.scheduler.timesteps.tolist()
 
-            expected_steps_1 = list(filter(lambda ts: ts >= split, expected_tss))
-            expected_steps_2 = list(filter(lambda ts: ts < split, expected_tss))
+            if pipe_1.scheduler.order == 2:
+                expected_steps_1 = list(filter(lambda ts: ts >= split, expected_tss))
+                expected_steps_2 = expected_steps_1[-1:] + list(filter(lambda ts: ts < split, expected_tss))
+                expected_steps = expected_steps_1 + expected_steps_2
+            else:
+                expected_steps_1 = list(filter(lambda ts: ts >= split, expected_tss))
+                expected_steps_2 = list(filter(lambda ts: ts < split, expected_tss))
 
             # now we monkey patch step `done_steps`
             # list into the step function for testing
@@ -611,13 +616,18 @@ class scheduler_cls(scheduler_cls_orig):
             split_1_ts = num_train_timesteps - int(round(num_train_timesteps * split_1))
             split_2_ts = num_train_timesteps - int(round(num_train_timesteps * split_2))
 
-            expected_steps_1 = expected_steps[:split_1_ts]
-            expected_steps_2 = expected_steps[split_1_ts:split_2_ts]
-            expected_steps_3 = expected_steps[split_2_ts:]
-            expected_steps_1 = list(filter(lambda ts: ts >= split_1_ts, expected_steps))
-            expected_steps_2 = list(filter(lambda ts: ts >= split_2_ts and ts < split_1_ts, expected_steps))
-            expected_steps_3 = list(filter(lambda ts: ts < split_2_ts, expected_steps))
+            if pipe_1.scheduler.order == 2:
+                expected_steps_1 = list(filter(lambda ts: ts >= split_1_ts, expected_steps))
+                expected_steps_2 = expected_steps_1[-1:] + list(
+                    filter(lambda ts: ts >= split_2_ts and ts < split_1_ts, expected_steps)
+                )
+                expected_steps_3 = expected_steps_2[-1:] + list(filter(lambda ts: ts < split_2_ts, expected_steps))
+                expected_steps = expected_steps_1 + expected_steps_2 + expected_steps_3
+            else:
+                expected_steps_1 = list(filter(lambda ts: ts >= split_1_ts, expected_steps))
+                expected_steps_2 = list(filter(lambda ts: ts >= split_2_ts and ts < split_1_ts, expected_steps))
+                expected_steps_3 = list(filter(lambda ts: ts < split_2_ts, expected_steps))
 
             # now we monkey patch step `done_steps`
             # list into the step function for testing
diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py
index 7e3698d8ca167..898fda0d7bb4e 100644
--- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py
+++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py
@@ -318,11 +318,14 @@ class scheduler_cls(scheduler_cls_orig):
             expected_steps = pipe_1.scheduler.timesteps.tolist()
 
             split_ts = num_train_timesteps - int(round(num_train_timesteps * split))
-            expected_steps_1 = expected_steps[:split_ts]
-            expected_steps_2 = expected_steps[split_ts:]
-            expected_steps_1 = list(filter(lambda ts: ts >= split_ts, expected_steps))
-            expected_steps_2 = list(filter(lambda ts: ts < split_ts, expected_steps))
+            if pipe_1.scheduler.order == 2:
+                expected_steps_1 = list(filter(lambda ts: ts >= split_ts, expected_steps))
+                expected_steps_2 = expected_steps_1[-1:] + list(filter(lambda ts: ts < split_ts, expected_steps))
+                expected_steps = expected_steps_1 + expected_steps_2
+            else:
+                expected_steps_1 = list(filter(lambda ts: ts >= split_ts, expected_steps))
+                expected_steps_2 = list(filter(lambda ts: ts < split_ts, expected_steps))
 
             # now we monkey patch step `done_steps`
             # list into the step function for testing
@@ -389,13 +392,18 @@ class scheduler_cls(scheduler_cls_orig):
             split_1_ts = num_train_timesteps - int(round(num_train_timesteps * split_1))
             split_2_ts = num_train_timesteps - int(round(num_train_timesteps * split_2))
 
-            expected_steps_1 = expected_steps[:split_1_ts]
-            expected_steps_2 = expected_steps[split_1_ts:split_2_ts]
-            expected_steps_3 = expected_steps[split_2_ts:]
-            expected_steps_1 = list(filter(lambda ts: ts >= split_1_ts, expected_steps))
-            expected_steps_2 = list(filter(lambda ts: ts >= split_2_ts and ts < split_1_ts, expected_steps))
-            expected_steps_3 = list(filter(lambda ts: ts < split_2_ts, expected_steps))
+            if pipe_1.scheduler.order == 2:
+                expected_steps_1 = list(filter(lambda ts: ts >= split_1_ts, expected_steps))
+                expected_steps_2 = expected_steps_1[-1:] + list(
+                    filter(lambda ts: ts >= split_2_ts and ts < split_1_ts, expected_steps)
+                )
+                expected_steps_3 = expected_steps_2[-1:] + list(filter(lambda ts: ts < split_2_ts, expected_steps))
+                expected_steps = expected_steps_1 + expected_steps_2 + expected_steps_3
+            else:
+                expected_steps_1 = list(filter(lambda ts: ts >= split_1_ts, expected_steps))
+                expected_steps_2 = list(filter(lambda ts: ts >= split_2_ts and ts < split_1_ts, expected_steps))
+                expected_steps_3 = list(filter(lambda ts: ts < split_2_ts, expected_steps))
 
             # now we monkey patch step `done_steps`
             # list into the step function for testing
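For completeness, this sketch shows the overlap the updated test expectations encode for 2nd order schedulers. The schedule and `split_ts` are the same hypothetical values used in the first sketch above, not taken from the tests.

```python
# hypothetical order-2 schedule: every timestep except the highest is duplicated
expected_steps = [999, 799, 799, 599, 599, 399, 399, 199, 199]
split_ts = 500

expected_steps_1 = [ts for ts in expected_steps if ts >= split_ts]
# the boundary timestep is run by both pipelines, so part 2 repeats part 1's last entry
expected_steps_2 = expected_steps_1[-1:] + [ts for ts in expected_steps if ts < split_ts]

print(expected_steps_1)  # [999, 799, 799, 599, 599]
print(expected_steps_2)  # [599, 399, 399, 199, 199]
```

`expected_steps_2` matches what the patched `get_timesteps` returns for the second pipeline, so concatenating the two parts reproduces the full schedule with the boundary timestep visited twice, once per pipeline.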