diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index ffdd888ee14..a829316273d 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -280,6 +280,9 @@ def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None, @torch.no_grad() def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): + if isinstance(model.inner_model.inner_model.model_sampling, comfy.model_sampling.CONST): + return sample_dpm_2_ancestral_RF(model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler) + """Ancestral sampling with DPM-Solver second-order steps.""" extra_args = {} if extra_args is None else extra_args noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler @@ -306,6 +309,39 @@ def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, dis x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up return x +@torch.no_grad() +def sample_dpm_2_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): + """Ancestral sampling with DPM-Solver second-order steps.""" + extra_args = {} if extra_args is None else extra_args + noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler + s_in = x.new_ones([x.shape[0]]) + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta) + downstep_ratio = 1 + (sigmas[i+1]/sigmas[i] - 1) * eta + sigma_down = sigmas[i+1] * downstep_ratio + alpha_ip1 = 1 - sigmas[i+1] + alpha_down = 1 - sigma_down + renoise_coeff = (sigmas[i+1]**2 - sigma_down**2*alpha_ip1**2/alpha_down**2)**0.5 + + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + d = to_d(x, sigmas[i], denoised) + if sigma_down == 0: + # Euler method + dt = sigma_down - sigmas[i] + x = x + d * dt + else: + # DPM-Solver-2 + sigma_mid = sigmas[i].log().lerp(sigma_down.log(), 0.5).exp() + dt_1 = sigma_mid - sigmas[i] + dt_2 = sigma_down - sigmas[i] + x_2 = x + d * dt_1 + denoised_2 = model(x_2, sigma_mid * s_in, **extra_args) + d_2 = to_d(x_2, sigma_mid, denoised_2) + x = x + d_2 * dt_2 + x = (alpha_ip1/alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff + return x def linear_multistep_coeff(order, t, i, j): if order - 1 > i: