Skip to content

Commit

Permalink
Add rescale_betas_zero_snr Argument to DDPMScheduler (#6305)
Browse files Browse the repository at this point in the history
* Add rescale_betas_zero_snr argument to DDPMScheduler.

* Propagate rescale_betas_zero_snr changes to DDPMParallelScheduler.

---------

Co-authored-by: Sayak Paul <[email protected]>
  • Loading branch information
dg845 and sayakpaul authored Dec 26, 2023
1 parent d4f10ea commit 3706aa3
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 0 deletions.
46 changes: 46 additions & 0 deletions src/diffusers/schedulers/scheduling_ddpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,43 @@ def alpha_bar_fn(t):
return torch.tensor(betas, dtype=torch.float32)


# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
def rescale_zero_terminal_snr(betas):
"""
Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
Args:
betas (`torch.FloatTensor`):
the betas that the scheduler is being initialized with.
Returns:
`torch.FloatTensor`: rescaled betas with zero terminal SNR
"""
# Convert betas to alphas_bar_sqrt
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
alphas_bar_sqrt = alphas_cumprod.sqrt()

# Store old values.
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()

# Shift so the last timestep is zero.
alphas_bar_sqrt -= alphas_bar_sqrt_T

# Scale so the first timestep is back to the old value.
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)

# Convert alphas_bar_sqrt to betas
alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
alphas = torch.cat([alphas_bar[0:1], alphas])
betas = 1 - alphas

return betas


class DDPMScheduler(SchedulerMixin, ConfigMixin):
"""
`DDPMScheduler` explores the connections between denoising score matching and Langevin dynamics sampling.
Expand Down Expand Up @@ -131,6 +168,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
An offset added to the inference steps. You can use a combination of `offset=1` and
`set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
Diffusion.
rescale_betas_zero_snr (`bool`, defaults to `False`):
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
dark samples instead of limiting it to samples with medium brightness. Loosely related to
[`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
"""

_compatibles = [e.name for e in KarrasDiffusionSchedulers]
Expand All @@ -153,6 +194,7 @@ def __init__(
sample_max_value: float = 1.0,
timestep_spacing: str = "leading",
steps_offset: int = 0,
rescale_betas_zero_snr: int = False,
):
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
Expand All @@ -171,6 +213,10 @@ def __init__(
else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")

# Rescale for zero SNR
if rescale_betas_zero_snr:
self.betas = rescale_zero_terminal_snr(self.betas)

self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
self.one = torch.tensor(1.0)
Expand Down
46 changes: 46 additions & 0 deletions src/diffusers/schedulers/scheduling_ddpm_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,43 @@ def alpha_bar_fn(t):
return torch.tensor(betas, dtype=torch.float32)


# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
def rescale_zero_terminal_snr(betas):
"""
Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
Args:
betas (`torch.FloatTensor`):
the betas that the scheduler is being initialized with.
Returns:
`torch.FloatTensor`: rescaled betas with zero terminal SNR
"""
# Convert betas to alphas_bar_sqrt
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
alphas_bar_sqrt = alphas_cumprod.sqrt()

# Store old values.
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()

# Shift so the last timestep is zero.
alphas_bar_sqrt -= alphas_bar_sqrt_T

# Scale so the first timestep is back to the old value.
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)

# Convert alphas_bar_sqrt to betas
alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
alphas = torch.cat([alphas_bar[0:1], alphas])
betas = 1 - alphas

return betas


class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
"""
Denoising diffusion probabilistic models (DDPMs) explores the connections between denoising score matching and
Expand Down Expand Up @@ -139,6 +176,10 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
an offset added to the inference steps. You can use a combination of `offset=1` and
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
stable diffusion.
rescale_betas_zero_snr (`bool`, defaults to `False`):
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
dark samples instead of limiting it to samples with medium brightness. Loosely related to
[`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
"""

_compatibles = [e.name for e in KarrasDiffusionSchedulers]
Expand All @@ -163,6 +204,7 @@ def __init__(
sample_max_value: float = 1.0,
timestep_spacing: str = "leading",
steps_offset: int = 0,
rescale_betas_zero_snr: int = False,
):
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
Expand All @@ -181,6 +223,10 @@ def __init__(
else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")

# Rescale for zero SNR
if rescale_betas_zero_snr:
self.betas = rescale_zero_terminal_snr(self.betas)

self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
self.one = torch.tensor(1.0)
Expand Down
4 changes: 4 additions & 0 deletions tests/schedulers/test_scheduler_ddpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ def test_variance(self):
assert torch.sum(torch.abs(scheduler._get_variance(487) - 0.00979)) < 1e-5
assert torch.sum(torch.abs(scheduler._get_variance(999) - 0.02)) < 1e-5

def test_rescale_betas_zero_snr(self):
for rescale_betas_zero_snr in [True, False]:
self.check_over_configs(rescale_betas_zero_snr=rescale_betas_zero_snr)

def test_full_loop_no_noise(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
Expand Down
4 changes: 4 additions & 0 deletions tests/schedulers/test_scheduler_ddpm_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@ def test_variance(self):
assert torch.sum(torch.abs(scheduler._get_variance(487) - 0.00979)) < 1e-5
assert torch.sum(torch.abs(scheduler._get_variance(999) - 0.02)) < 1e-5

def test_rescale_betas_zero_snr(self):
for rescale_betas_zero_snr in [True, False]:
self.check_over_configs(rescale_betas_zero_snr=rescale_betas_zero_snr)

def test_batch_step_no_noise(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
Expand Down

0 comments on commit 3706aa3

Please sign in to comment.