From daa2a411fcde579b5b4be154d05a4edf9d4cd02f Mon Sep 17 00:00:00 2001
From: Beinsezii
Date: Fri, 1 Dec 2023 13:31:48 -0800
Subject: [PATCH 1/5] EulerDiscreteScheduler add `rescale_betas_zero_snr`

Adds support for the `rescale_betas_zero_snr` config option to
EulerDiscreteScheduler. Currently it works equivalently to DDIM, with the
concession that the final alpha_cumprod is patched to 2 ** -16 to resolve
the `inf` issue, similar to what ComfyUI does. This does not follow the
'curve', so to speak, and a value closer to 0 would be more appropriate;
however, lower values such as 2 ** -24 run into precision issues on
fp16/bf16 inference.

A proper fix would be to find exactly where the precision issues lie in the
pipeline code and resolve them (upcast?). I don't think I know the math well
enough for that; 2 ** -16 is satisfactory judging from my benchmark images
on the model "ptx0/terminus-xl-gamma-training".
---
 .../schedulers/scheduling_euler_discrete.py   | 49 +++++++++++++++++++
 1 file changed, 49 insertions(+)

diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py
index 53dc2ae15432..99f341533b89 100644
--- a/src/diffusers/schedulers/scheduling_euler_discrete.py
+++ b/src/diffusers/schedulers/scheduling_euler_discrete.py
@@ -92,6 +92,42 @@ def alpha_bar_fn(t):
     return torch.tensor(betas, dtype=torch.float32)
 
 
+def rescale_zero_terminal_snr(betas):
+    """
+    Rescales betas to have zero terminal SNR. Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
+
+
+    Args:
+        betas (`torch.FloatTensor`):
+            the betas that the scheduler is being initialized with.
+
+    Returns:
+        `torch.FloatTensor`: rescaled betas with zero terminal SNR
+    """
+    # Convert betas to alphas_bar_sqrt
+    alphas = 1.0 - betas
+    alphas_cumprod = torch.cumprod(alphas, dim=0)
+    alphas_bar_sqrt = alphas_cumprod.sqrt()
+
+    # Store old values.
+    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
+    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
+
+    # Shift so the last timestep is zero.
+    alphas_bar_sqrt -= alphas_bar_sqrt_T
+
+    # Scale so the first timestep is back to the old value.
+    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
+
+    # Convert alphas_bar_sqrt to betas
+    alphas_bar = alphas_bar_sqrt**2  # Revert sqrt
+    alphas = alphas_bar[1:] / alphas_bar[:-1]  # Revert cumprod
+    alphas = torch.cat([alphas_bar[0:1], alphas])
+    betas = 1 - alphas
+
+    return betas
+
+
 class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
     """
     Euler scheduler.
@@ -128,6 +164,10 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
             An offset added to the inference steps. You can use a combination of `offset=1` and
             `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
             Diffusion.
+        rescale_betas_zero_snr (`bool`, defaults to `False`):
+            Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
+            dark samples instead of limiting it to samples with medium brightness. Loosely related to
+            [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
""" _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -149,6 +189,7 @@ def __init__( timestep_spacing: str = "linspace", timestep_type: str = "discrete", # can be "discrete" or "continuous" steps_offset: int = 0, + rescale_betas_zero_snr: bool = False, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -163,9 +204,17 @@ def __init__( else: raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + if rescale_betas_zero_snr: + self.betas = rescale_zero_terminal_snr(self.betas) + self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + if rescale_betas_zero_snr: + # Bandaid so first sigma isn't inf + # Lower values that follow the 'proper' curve have precision issues on fp16/bf16 + self.alphas_cumprod[-1] = 2 ** -16 + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy() From 0de1ef8c10c4e46c7877e06db658ffbca5bc965c Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Fri, 1 Dec 2023 16:49:09 -0800 Subject: [PATCH 2/5] EulerDiscrete upcast samples during step() Fixes the ZSNR precision issues on fp16/bf16 with no measureable performance loss. Now using the full 2 ** -24, the results are effectively equivalent to DDIM's ZSNR rescaling --- .../schedulers/scheduling_euler_discrete.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index 99f341533b89..10ade4455517 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -211,9 +211,9 @@ def __init__( self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) if rescale_betas_zero_snr: - # Bandaid so first sigma isn't inf - # Lower values that follow the 'proper' curve have precision issues on fp16/bf16 - self.alphas_cumprod[-1] = 2 ** -16 + # Close to 0 without being 0 so first sigma is not inf + # FP16 smallest positive subnormal works well here + self.alphas_cumprod[-1] = 2 ** -24 sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy() @@ -468,6 +468,9 @@ def step( if self.step_index is None: self._init_step_index(timestep) + # Upcast to avoid precision issues when computing prev_sample + sample = sample.to(torch.float32) + sigma = self.sigmas[self.step_index] gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0 @@ -504,6 +507,9 @@ def step( prev_sample = sample + derivative * dt + # Cast sample back to model compatible dtype + prev_sample = prev_sample.to(model_output.dtype) + # upon completion increase step index by one self._step_index += 1 From b39f2847132202c3eb52d408d67d84bcf5f2efd6 Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Fri, 1 Dec 2023 16:55:39 -0800 Subject: [PATCH 3/5] EulerDiscrete run `ruff` for PR #6024 --- src/diffusers/schedulers/scheduling_euler_discrete.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index 10ade4455517..4b5917c41d2c 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -213,7 +213,7 @@ def __init__( 
         if rescale_betas_zero_snr:
             # Close to 0 without being 0 so first sigma is not inf
             # FP16 smallest positive subnormal works well here
-            self.alphas_cumprod[-1] = 2 ** -24
+            self.alphas_cumprod[-1] = 2**-24
 
         sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
         timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()

From e143868cd6d40c4e8216752d68872285ac1e88d6 Mon Sep 17 00:00:00 2001
From: Beinsezii
Date: Tue, 5 Dec 2023 01:26:56 -0800
Subject: [PATCH 4/5] Add `# Copied from` to euler_discrete's rescale fn

---
 src/diffusers/schedulers/scheduling_euler_discrete.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py
index 4b5917c41d2c..8f2efc7d467f 100644
--- a/src/diffusers/schedulers/scheduling_euler_discrete.py
+++ b/src/diffusers/schedulers/scheduling_euler_discrete.py
@@ -92,6 +92,7 @@ def alpha_bar_fn(t):
     return torch.tensor(betas, dtype=torch.float32)
 
 
+# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
 def rescale_zero_terminal_snr(betas):
     """
     Rescales betas to have zero terminal SNR. Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)

From 14d5a4cc034b98fc5612692402e721ace7c4122b Mon Sep 17 00:00:00 2001
From: Beinsezii
Date: Tue, 5 Dec 2023 01:35:46 -0800
Subject: [PATCH 5/5] Test `rescale_betas_zero_snr` for euler scheduler

---
 tests/schedulers/test_scheduler_euler.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/tests/schedulers/test_scheduler_euler.py b/tests/schedulers/test_scheduler_euler.py
index 3249d7032bad..41c418c5064c 100644
--- a/tests/schedulers/test_scheduler_euler.py
+++ b/tests/schedulers/test_scheduler_euler.py
@@ -45,6 +45,10 @@ def test_timestep_type(self):
     def test_karras_sigmas(self):
         self.check_over_configs(use_karras_sigmas=True, sigma_min=0.02, sigma_max=700.0)
 
+    def test_rescale_betas_zero_snr(self):
+        for rescale_betas_zero_snr in [True, False]:
+            self.check_over_configs(rescale_betas_zero_snr=rescale_betas_zero_snr)
+
     def test_full_loop_no_noise(self):
         scheduler_class = self.scheduler_classes[0]
         scheduler_config = self.get_scheduler_config()
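
A quick usage sketch for the new option follows. It is not part of the
patches, and the model id is only illustrative (borrowed from the benchmark
note in PATCH 1/5); the flag is a scheduler config option, so it can be
enabled when rebuilding a pipeline's scheduler from its existing config:

    import torch
    from diffusers import EulerDiscreteScheduler, StableDiffusionXLPipeline

    # Any SDXL-style checkpoint; this id is illustrative only.
    pipe = StableDiffusionXLPipeline.from_pretrained(
        "ptx0/terminus-xl-gamma-training", torch_dtype=torch.float16
    ).to("cuda")

    # Rebuild the scheduler with zero terminal SNR beta rescaling enabled
    # (the config option added in PATCH 1/5).
    pipe.scheduler = EulerDiscreteScheduler.from_config(
        pipe.scheduler.config, rescale_betas_zero_snr=True
    )

    image = pipe("a pitch-black room lit by a single candle").images[0]

And a small numeric check of why __init__ patches the final alpha_cumprod to
2 ** -24 rather than leaving it at a true zero: the scheduler computes
sigma = sqrt((1 - alphas_cumprod) / alphas_cumprod), which divides by zero at
exactly zero terminal SNR, while the FP16 smallest positive subnormal yields
a large but finite, fp16-representable sigma:

    import torch

    abar_T = torch.tensor(0.0)
    print(((1 - abar_T) / abar_T) ** 0.5)  # inf

    abar_T = torch.tensor(2 ** -24)  # FP16 smallest positive subnormal
    print(((1 - abar_T) / abar_T) ** 0.5)  # ~4096, finite and within fp16 range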