Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

EulerDiscreteScheduler add rescale_betas_zero_snr #6024

Merged
merged 5 commits into from
Dec 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions src/diffusers/schedulers/scheduling_euler_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,43 @@ def alpha_bar_fn(t):
return torch.tensor(betas, dtype=torch.float32)


# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
def rescale_zero_terminal_snr(betas):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this copied from ddim? if so let's add a # Copy from statement

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is copied from DDIM which itself is copied 1:1 (and renamed) from Page 3 Alrogithm 1 of the ZSNR paper.

"""
Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)


Args:
betas (`torch.FloatTensor`):
the betas that the scheduler is being initialized with.

Returns:
`torch.FloatTensor`: rescaled betas with zero terminal SNR
"""
# Convert betas to alphas_bar_sqrt
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
alphas_bar_sqrt = alphas_cumprod.sqrt()

# Store old values.
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()

# Shift so the last timestep is zero.
alphas_bar_sqrt -= alphas_bar_sqrt_T

# Scale so the first timestep is back to the old value.
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)

# Convert alphas_bar_sqrt to betas
alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
alphas = torch.cat([alphas_bar[0:1], alphas])
betas = 1 - alphas

return betas


class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
Euler scheduler.
Expand Down Expand Up @@ -128,6 +165,10 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
An offset added to the inference steps. You can use a combination of `offset=1` and
`set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
Diffusion.
rescale_betas_zero_snr (`bool`, defaults to `False`):
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
dark samples instead of limiting it to samples with medium brightness. Loosely related to
[`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
"""

_compatibles = [e.name for e in KarrasDiffusionSchedulers]
Expand All @@ -149,6 +190,7 @@ def __init__(
timestep_spacing: str = "linspace",
timestep_type: str = "discrete", # can be "discrete" or "continuous"
steps_offset: int = 0,
rescale_betas_zero_snr: bool = False,
):
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
Expand All @@ -163,9 +205,17 @@ def __init__(
else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")

if rescale_betas_zero_snr:
self.betas = rescale_zero_terminal_snr(self.betas)

self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

if rescale_betas_zero_snr:
# Close to 0 without being 0 so first sigma is not inf
# FP16 smallest positive subnormal works well here
self.alphas_cumprod[-1] = 2**-24

sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()

Expand Down Expand Up @@ -419,6 +469,9 @@ def step(
if self.step_index is None:
self._init_step_index(timestep)

# Upcast to avoid precision issues when computing prev_sample
sample = sample.to(torch.float32)

sigma = self.sigmas[self.step_index]

gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
Expand Down Expand Up @@ -455,6 +508,9 @@ def step(

prev_sample = sample + derivative * dt

# Cast sample back to model compatible dtype
prev_sample = prev_sample.to(model_output.dtype)

# upon completion increase step index by one
self._step_index += 1

Expand Down
4 changes: 4 additions & 0 deletions tests/schedulers/test_scheduler_euler.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ def test_timestep_type(self):
def test_karras_sigmas(self):
self.check_over_configs(use_karras_sigmas=True, sigma_min=0.02, sigma_max=700.0)

def test_rescale_betas_zero_snr(self):
for rescale_betas_zero_snr in [True, False]:
self.check_over_configs(rescale_betas_zero_snr=rescale_betas_zero_snr)

def test_full_loop_no_noise(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
Expand Down
Loading