From ccf6f62e2f4ff2e02fbcdeb952aa3f6a0b39d1ab Mon Sep 17 00:00:00 2001
From: Beinsezii <39478211+Beinsezii@users.noreply.github.com>
Date: Wed, 6 Dec 2023 23:51:04 -0800
Subject: [PATCH] EulerDiscreteScheduler add `rescale_betas_zero_snr` (#6024)

* EulerDiscreteScheduler add `rescale_betas_zero_snr`

---
 schedulers/scheduling_euler_discrete.py | 56 +++++++++++++++++++++++++
 1 file changed, 56 insertions(+)

diff --git a/schedulers/scheduling_euler_discrete.py b/schedulers/scheduling_euler_discrete.py
index 0e2dd5c983e3..802ba0f099f9 100644
--- a/schedulers/scheduling_euler_discrete.py
+++ b/schedulers/scheduling_euler_discrete.py
@@ -92,6 +92,43 @@ def alpha_bar_fn(t):
 
     return torch.tensor(betas, dtype=torch.float32)
 
 
+# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
+def rescale_zero_terminal_snr(betas):
+    """
+    Rescales betas to have zero terminal SNR, based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1).
+
+
+    Args:
+        betas (`torch.FloatTensor`):
+            The betas that the scheduler is being initialized with.
+
+    Returns:
+        `torch.FloatTensor`: rescaled betas with zero terminal SNR
+    """
+    # Convert betas to alphas_bar_sqrt
+    alphas = 1.0 - betas
+    alphas_cumprod = torch.cumprod(alphas, dim=0)
+    alphas_bar_sqrt = alphas_cumprod.sqrt()
+
+    # Store old values.
+    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
+    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
+
+    # Shift so the last timestep is zero.
+    alphas_bar_sqrt -= alphas_bar_sqrt_T
+
+    # Scale so the first timestep is back to the old value.
+    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
+
+    # Convert alphas_bar_sqrt to betas
+    alphas_bar = alphas_bar_sqrt**2  # Revert sqrt
+    alphas = alphas_bar[1:] / alphas_bar[:-1]  # Revert cumprod
+    alphas = torch.cat([alphas_bar[0:1], alphas])
+    betas = 1 - alphas
+
+    return betas
+
+
 class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
     """
     Euler scheduler.
@@ -128,6 +165,10 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
             An offset added to the inference steps. You can use a combination of `offset=1` and
             `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
             Diffusion.
+        rescale_betas_zero_snr (`bool`, defaults to `False`):
+            Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
+            dark samples instead of limiting it to samples with medium brightness. Loosely related to
+            [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
""" _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -149,6 +190,7 @@ def __init__( timestep_spacing: str = "linspace", timestep_type: str = "discrete", # can be "discrete" or "continuous" steps_offset: int = 0, + rescale_betas_zero_snr: bool = False, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -163,9 +205,17 @@ def __init__( else: raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + if rescale_betas_zero_snr: + self.betas = rescale_zero_terminal_snr(self.betas) + self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + if rescale_betas_zero_snr: + # Close to 0 without being 0 so first sigma is not inf + # FP16 smallest positive subnormal works well here + self.alphas_cumprod[-1] = 2**-24 + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy() @@ -420,6 +470,9 @@ def step( if self.step_index is None: self._init_step_index(timestep) + # Upcast to avoid precision issues when computing prev_sample + sample = sample.to(torch.float32) + sigma = self.sigmas[self.step_index] gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0 @@ -456,6 +509,9 @@ def step( prev_sample = sample + derivative * dt + # Cast sample back to model compatible dtype + prev_sample = prev_sample.to(model_output.dtype) + # upon completion increase step index by one self._step_index += 1