From cd088a4ae3c65fdf059eb2482658191737bf4af6 Mon Sep 17 00:00:00 2001 From: dg845 <58458699+dg845@users.noreply.github.com> Date: Tue, 26 Dec 2023 08:54:30 -0800 Subject: [PATCH] Add rescale_betas_zero_snr Argument to DDPMScheduler (#6305) * Add rescale_betas_zero_snr argument to DDPMScheduler. * Propagate rescale_betas_zero_snr changes to DDPMParallelScheduler. --------- Co-authored-by: Sayak Paul --- src/diffusers/schedulers/scheduling_ddpm.py | 46 +++++++++++++++++++ .../schedulers/scheduling_ddpm_parallel.py | 46 +++++++++++++++++++ tests/schedulers/test_scheduler_ddpm.py | 4 ++ .../test_scheduler_ddpm_parallel.py | 4 ++ 4 files changed, 100 insertions(+) diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index c4a3eb43577c4..868cf1c2d8f23 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -89,6 +89,43 @@ def alpha_bar_fn(t): return torch.tensor(betas, dtype=torch.float32) +# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr +def rescale_zero_terminal_snr(betas): + """ + Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) + + + Args: + betas (`torch.FloatTensor`): + the betas that the scheduler is being initialized with. + + Returns: + `torch.FloatTensor`: rescaled betas with zero terminal SNR + """ + # Convert betas to alphas_bar_sqrt + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. 
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + + return betas + + class DDPMScheduler(SchedulerMixin, ConfigMixin): """ `DDPMScheduler` explores the connections between denoising score matching and Langevin dynamics sampling. @@ -131,6 +168,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): An offset added to the inference steps. You can use a combination of `offset=1` and `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable Diffusion. + rescale_betas_zero_snr (`bool`, defaults to `False`): + Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and + dark samples instead of limiting it to samples with medium brightness. Loosely related to + [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). 
""" _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -153,6 +194,7 @@ def __init__( sample_max_value: float = 1.0, timestep_spacing: str = "leading", steps_offset: int = 0, + rescale_betas_zero_snr: bool = False, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -171,6 +213,10 @@ def __init__( else: raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + # Rescale for zero SNR + if rescale_betas_zero_snr: + self.betas = rescale_zero_terminal_snr(self.betas) + self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) self.one = torch.tensor(1.0) diff --git a/src/diffusers/schedulers/scheduling_ddpm_parallel.py b/src/diffusers/schedulers/scheduling_ddpm_parallel.py index 6f2bebfb5a381..9a84bfdf285dc 100644 --- a/src/diffusers/schedulers/scheduling_ddpm_parallel.py +++ b/src/diffusers/schedulers/scheduling_ddpm_parallel.py @@ -91,6 +91,43 @@ def alpha_bar_fn(t): return torch.tensor(betas, dtype=torch.float32) +# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr +def rescale_zero_terminal_snr(betas): + """ + Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) + + + Args: + betas (`torch.FloatTensor`): + the betas that the scheduler is being initialized with. + + Returns: + `torch.FloatTensor`: rescaled betas with zero terminal SNR + """ + # Convert betas to alphas_bar_sqrt + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. 
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + + return betas + + class DDPMParallelScheduler(SchedulerMixin, ConfigMixin): """ Denoising diffusion probabilistic models (DDPMs) explores the connections between denoising score matching and @@ -139,6 +176,10 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin): an offset added to the inference steps. You can use a combination of `offset=1` and `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in stable diffusion. + rescale_betas_zero_snr (`bool`, defaults to `False`): + Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and + dark samples instead of limiting it to samples with medium brightness. Loosely related to + [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). 
""" _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -163,6 +204,7 @@ def __init__( sample_max_value: float = 1.0, timestep_spacing: str = "leading", steps_offset: int = 0, + rescale_betas_zero_snr: bool = False, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -181,6 +223,10 @@ def __init__( else: raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + # Rescale for zero SNR + if rescale_betas_zero_snr: + self.betas = rescale_zero_terminal_snr(self.betas) + self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) self.one = torch.tensor(1.0) diff --git a/tests/schedulers/test_scheduler_ddpm.py b/tests/schedulers/test_scheduler_ddpm.py index 4e2a3c74d8772..056b5d83350e4 100644 --- a/tests/schedulers/test_scheduler_ddpm.py +++ b/tests/schedulers/test_scheduler_ddpm.py @@ -68,6 +68,10 @@ def test_variance(self): assert torch.sum(torch.abs(scheduler._get_variance(487) - 0.00979)) < 1e-5 assert torch.sum(torch.abs(scheduler._get_variance(999) - 0.02)) < 1e-5 + def test_rescale_betas_zero_snr(self): + for rescale_betas_zero_snr in [True, False]: + self.check_over_configs(rescale_betas_zero_snr=rescale_betas_zero_snr) + def test_full_loop_no_noise(self): scheduler_class = self.scheduler_classes[0] scheduler_config = self.get_scheduler_config() diff --git a/tests/schedulers/test_scheduler_ddpm_parallel.py b/tests/schedulers/test_scheduler_ddpm_parallel.py index b25f7151e1df0..4c33c090b0674 100644 --- a/tests/schedulers/test_scheduler_ddpm_parallel.py +++ b/tests/schedulers/test_scheduler_ddpm_parallel.py @@ -82,6 +82,10 @@ def test_variance(self): assert torch.sum(torch.abs(scheduler._get_variance(487) - 0.00979)) < 1e-5 assert torch.sum(torch.abs(scheduler._get_variance(999) - 0.02)) < 1e-5 + def test_rescale_betas_zero_snr(self): + for rescale_betas_zero_snr in [True, False]: + 
self.check_over_configs(rescale_betas_zero_snr=rescale_betas_zero_snr) + def test_batch_step_no_noise(self): scheduler_class = self.scheduler_classes[0] scheduler_config = self.get_scheduler_config()