From 2cf481af1d88ea76c3dc5cc422b134547e0a38af Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Fri, 15 Dec 2023 01:34:21 -0800 Subject: [PATCH 1/3] EulerAncestral add `rescale_betas_zero_snr` Uses same infinite sigma fix from EulerDiscrete. Interestingly the ancestral version had the opposite problem: too much contrast instead of too little. --- .../scheduling_euler_ancestral_discrete.py | 50 +++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py index e476c329455e..424fee4b6e78 100644 --- a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py @@ -92,6 +92,43 @@ def alpha_bar_fn(t): return torch.tensor(betas, dtype=torch.float32) +# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr +def rescale_zero_terminal_snr(betas): + """ + Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) + + + Args: + betas (`torch.FloatTensor`): + the betas that the scheduler is being initialized with. + + Returns: + `torch.FloatTensor`: rescaled betas with zero terminal SNR + """ + # Convert betas to alphas_bar_sqrt + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + + return betas + + class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): """ Ancestral sampling with Euler method steps. @@ -122,6 +159,10 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): An offset added to the inference steps. You can use a combination of `offset=1` and `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable Diffusion. + rescale_betas_zero_snr (`bool`, defaults to `False`): + Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and + dark samples instead of limiting it to samples with medium brightness. Loosely related to + [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -138,6 +179,7 @@ def __init__( prediction_type: str = "epsilon", timestep_spacing: str = "linspace", steps_offset: int = 0, + rescale_betas_zero_snr: bool = False, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -152,9 +194,17 @@ def __init__( else: raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + if rescale_betas_zero_snr: + self.betas = rescale_zero_terminal_snr(self.betas) + self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + if rescale_betas_zero_snr: + # Close to 0 without being 0 so first sigma is not inf + # FP16 smallest positive subnormal works well here + self.alphas_cumprod[-1] = 2**-24 + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32) self.sigmas = torch.from_numpy(sigmas) From 0adbbb683364089a2a95d00f89d9f2050a453fe7 Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Fri, 15 Dec 2023 01:39:25 -0800 Subject: [PATCH 2/3] UT for EulerAncestral `rescale_betas_zero_snr` --- tests/schedulers/test_scheduler_euler_ancestral.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/schedulers/test_scheduler_euler_ancestral.py b/tests/schedulers/test_scheduler_euler_ancestral.py index a0818042fad9..9f22ab38ddaf 100644 --- a/tests/schedulers/test_scheduler_euler_ancestral.py +++ b/tests/schedulers/test_scheduler_euler_ancestral.py @@ -37,6 +37,10 @@ def test_prediction_type(self): for prediction_type in ["epsilon", "v_prediction"]: self.check_over_configs(prediction_type=prediction_type) + def test_rescale_betas_zero_snr(self): + for rescale_betas_zero_snr in [True, False]: + self.check_over_configs(rescale_betas_zero_snr=rescale_betas_zero_snr) + def test_full_loop_no_noise(self): scheduler_class = self.scheduler_classes[0] scheduler_config = self.get_scheduler_config() From 04a763dd316c5f614b2b8a0deff1069ef0bd4d3a Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Fri, 15 Dec 2023 02:01:02 -0800 Subject: [PATCH 3/3] EulerAncestral upcast samples during step() It helps this scheduler too, particularly when the model is using bf16. While the noise dtype is still the model's it's automatically upcasted for the add so all it affects is determinism. --- .../schedulers/scheduling_euler_ancestral_discrete.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py index 424fee4b6e78..ca188378a38f 100644 --- a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py @@ -377,6 +377,9 @@ def step( sigma = self.sigmas[self.step_index] + # Upcast to avoid precision issues when computing prev_sample + sample = sample.to(torch.float32) + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise if self.config.prediction_type == "epsilon": pred_original_sample = sample - sigma * model_output @@ -407,6 +410,9 @@ def step( prev_sample = prev_sample + noise * sigma_up + # Cast sample back to model compatible dtype + prev_sample = prev_sample.to(model_output.dtype) + # upon completion increase step index by one self._step_index += 1