diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 6f70a8191629..db46dc1d8801 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -338,8 +338,8 @@ "StableDiffusion3ControlNetPipeline", "StableDiffusion3Img2ImgPipeline", "StableDiffusion3InpaintPipeline", - "StableDiffusion3PAGPipeline", "StableDiffusion3PAGImg2ImgPipeline", + "StableDiffusion3PAGPipeline", "StableDiffusion3Pipeline", "StableDiffusionAdapterPipeline", "StableDiffusionAttendAndExcitePipeline", diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index 4b21328dccb5..e7704f2ced19 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -889,6 +889,7 @@ def multistep_dpm_solver_third_order_update( model_output_list: List[torch.Tensor], *args, sample: torch.Tensor = None, + noise: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: """ @@ -967,6 +968,15 @@ def multistep_dpm_solver_third_order_update( - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2 ) + elif self.config.algorithm_type == "sde-dpmsolver++": + assert noise is not None + x_t = ( + (sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1.0 - torch.exp(-2.0 * h))) * D0 + + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1 + + (alpha_t * ((1.0 - torch.exp(-2.0 * h) - 2.0 * h) / (2.0 * h) ** 2 - 0.5)) * D2 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) return x_t def index_for_timestep(self, timestep, schedule_timesteps=None): @@ -1073,7 +1083,7 @@ def step( elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second: prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise) else: - prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample) + prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample, noise=noise) if self.lower_order_nums < self.config.solver_order: self.lower_order_nums += 1 diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py index 9f10d39ed40c..2968d0ef7b8e 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py @@ -764,6 +764,7 @@ def multistep_dpm_solver_third_order_update( model_output_list: List[torch.Tensor], *args, sample: torch.Tensor = None, + noise: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: """ @@ -842,6 +843,15 @@ def multistep_dpm_solver_third_order_update( - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2 ) + elif self.config.algorithm_type == "sde-dpmsolver++": + assert noise is not None + x_t = ( + (sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1.0 - torch.exp(-2.0 * h))) * D0 + + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1 + + (alpha_t * ((1.0 - torch.exp(-2.0 * h) - 2.0 * h) / (2.0 * h) ** 2 - 0.5)) * D2 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) return x_t def _init_step_index(self, timestep): diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py index 868122971e40..02af15ae5c6a 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py @@ -264,6 +264,10 @@ def get_order_list(self, num_inference_steps: int) -> List[int]: orders = [1, 2] * (steps // 2) elif order == 1: orders = [1] * steps + + if self.config.final_sigmas_type == "zero": + orders[-1] = 1 + return orders @property @@ -812,6 +816,7 @@ def singlestep_dpm_solver_third_order_update( model_output_list: List[torch.Tensor], *args, sample: torch.Tensor = None, + noise: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: """ @@ -909,6 +914,23 @@ def singlestep_dpm_solver_third_order_update( - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2 ) + elif self.config.algorithm_type == "sde-dpmsolver++": + assert noise is not None + if self.config.solver_type == "midpoint": + x_t = ( + (sigma_t / sigma_s2 * torch.exp(-h)) * sample + + (alpha_t * (1.0 - torch.exp(-2.0 * h))) * D0 + + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1_1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + elif self.config.solver_type == "heun": + x_t = ( + (sigma_t / sigma_s2 * torch.exp(-h)) * sample + + (alpha_t * (1.0 - torch.exp(-2.0 * h))) * D0 + + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1 + + (alpha_t * ((1.0 - torch.exp(-2.0 * h) + (-2.0 * h)) / (-2.0 * h) ** 2 - 0.5)) * D2 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) return x_t def singlestep_dpm_solver_update( @@ -970,7 +992,7 @@ def singlestep_dpm_solver_update( elif order == 2: return self.singlestep_dpm_solver_second_order_update(model_output_list, sample=sample, noise=noise) elif order == 3: - return self.singlestep_dpm_solver_third_order_update(model_output_list, sample=sample) + return self.singlestep_dpm_solver_third_order_update(model_output_list, sample=sample, noise=noise) else: raise ValueError(f"Order must be 1, 2, 3, got {order}")