Skip to content

Commit

Permalink
DPM++ third order fixes (#9104)
Browse files Browse the repository at this point in the history
* Fix wrong output on 3n-1 steps count

* Add sde handling to 3 order

* make

* copies

---------

Co-authored-by: hlky <[email protected]>
  • Loading branch information
StAlKeR7779 and hlky authored Dec 3, 2024
1 parent 2be66e6 commit 8ac6de9
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,8 +338,8 @@
"StableDiffusion3ControlNetPipeline",
"StableDiffusion3Img2ImgPipeline",
"StableDiffusion3InpaintPipeline",
"StableDiffusion3PAGPipeline",
"StableDiffusion3PAGImg2ImgPipeline",
"StableDiffusion3PAGPipeline",
"StableDiffusion3Pipeline",
"StableDiffusionAdapterPipeline",
"StableDiffusionAttendAndExcitePipeline",
Expand Down
12 changes: 11 additions & 1 deletion src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
Original file line number Diff line number Diff line change
Expand Up @@ -889,6 +889,7 @@ def multistep_dpm_solver_third_order_update(
model_output_list: List[torch.Tensor],
*args,
sample: torch.Tensor = None,
noise: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
"""
Expand Down Expand Up @@ -967,6 +968,15 @@ def multistep_dpm_solver_third_order_update(
- (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
- (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
)
elif self.config.algorithm_type == "sde-dpmsolver++":
assert noise is not None
x_t = (
(sigma_t / sigma_s0 * torch.exp(-h)) * sample
+ (alpha_t * (1.0 - torch.exp(-2.0 * h))) * D0
+ (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
+ (alpha_t * ((1.0 - torch.exp(-2.0 * h) - 2.0 * h) / (2.0 * h) ** 2 - 0.5)) * D2
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
)
return x_t

def index_for_timestep(self, timestep, schedule_timesteps=None):
Expand Down Expand Up @@ -1073,7 +1083,7 @@ def step(
elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise)
else:
prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample)
prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample, noise=noise)

if self.lower_order_nums < self.config.solver_order:
self.lower_order_nums += 1
Expand Down
10 changes: 10 additions & 0 deletions src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -764,6 +764,7 @@ def multistep_dpm_solver_third_order_update(
model_output_list: List[torch.Tensor],
*args,
sample: torch.Tensor = None,
noise: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
"""
Expand Down Expand Up @@ -842,6 +843,15 @@ def multistep_dpm_solver_third_order_update(
- (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
- (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
)
elif self.config.algorithm_type == "sde-dpmsolver++":
assert noise is not None
x_t = (
(sigma_t / sigma_s0 * torch.exp(-h)) * sample
+ (alpha_t * (1.0 - torch.exp(-2.0 * h))) * D0
+ (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
+ (alpha_t * ((1.0 - torch.exp(-2.0 * h) - 2.0 * h) / (2.0 * h) ** 2 - 0.5)) * D2
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
)
return x_t

def _init_step_index(self, timestep):
Expand Down
24 changes: 23 additions & 1 deletion src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,10 @@ def get_order_list(self, num_inference_steps: int) -> List[int]:
orders = [1, 2] * (steps // 2)
elif order == 1:
orders = [1] * steps

if self.config.final_sigmas_type == "zero":
orders[-1] = 1

return orders

@property
Expand Down Expand Up @@ -812,6 +816,7 @@ def singlestep_dpm_solver_third_order_update(
model_output_list: List[torch.Tensor],
*args,
sample: torch.Tensor = None,
noise: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
"""
Expand Down Expand Up @@ -909,6 +914,23 @@ def singlestep_dpm_solver_third_order_update(
- (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
- (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
)
elif self.config.algorithm_type == "sde-dpmsolver++":
assert noise is not None
if self.config.solver_type == "midpoint":
x_t = (
(sigma_t / sigma_s2 * torch.exp(-h)) * sample
+ (alpha_t * (1.0 - torch.exp(-2.0 * h))) * D0
+ (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1_1
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
)
elif self.config.solver_type == "heun":
x_t = (
(sigma_t / sigma_s2 * torch.exp(-h)) * sample
+ (alpha_t * (1.0 - torch.exp(-2.0 * h))) * D0
+ (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
+ (alpha_t * ((1.0 - torch.exp(-2.0 * h) + (-2.0 * h)) / (-2.0 * h) ** 2 - 0.5)) * D2
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
)
return x_t

def singlestep_dpm_solver_update(
Expand Down Expand Up @@ -970,7 +992,7 @@ def singlestep_dpm_solver_update(
elif order == 2:
return self.singlestep_dpm_solver_second_order_update(model_output_list, sample=sample, noise=noise)
elif order == 3:
return self.singlestep_dpm_solver_third_order_update(model_output_list, sample=sample)
return self.singlestep_dpm_solver_third_order_update(model_output_list, sample=sample, noise=noise)
else:
raise ValueError(f"Order must be 1, 2, 3, got {order}")

Expand Down

0 comments on commit 8ac6de9

Please sign in to comment.