Skip to content

Commit

Permalink
EulerAncestral add rescale_betas_zero_snr (huggingface#6187)
Browse files Browse the repository at this point in the history
* EulerAncestral add `rescale_betas_zero_snr`

Uses same infinite sigma fix from EulerDiscrete. Interestingly the
ancestral version had the opposite problem: too much contrast instead of
too little.

* UT for EulerAncestral `rescale_betas_zero_snr`

* EulerAncestral upcast samples during step()

It helps this scheduler too, particularly when the model is using bf16.

While the noise dtype is still the model's it's automatically upcasted
for the add so all it affects is determinism.

---------

Co-authored-by: Sayak Paul <[email protected]>
  • Loading branch information
2 people authored and donhardman committed Dec 29, 2023
1 parent ed3fcd8 commit 1d1c873
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 0 deletions.
56 changes: 56 additions & 0 deletions src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,43 @@ def alpha_bar_fn(t):
return torch.tensor(betas, dtype=torch.float32)


# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
def rescale_zero_terminal_snr(betas):
"""
Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
Args:
betas (`torch.FloatTensor`):
the betas that the scheduler is being initialized with.
Returns:
`torch.FloatTensor`: rescaled betas with zero terminal SNR
"""
# Convert betas to alphas_bar_sqrt
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
alphas_bar_sqrt = alphas_cumprod.sqrt()

# Store old values.
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()

# Shift so the last timestep is zero.
alphas_bar_sqrt -= alphas_bar_sqrt_T

# Scale so the first timestep is back to the old value.
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)

# Convert alphas_bar_sqrt to betas
alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
alphas = torch.cat([alphas_bar[0:1], alphas])
betas = 1 - alphas

return betas


class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
Ancestral sampling with Euler method steps.
Expand Down Expand Up @@ -122,6 +159,10 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
An offset added to the inference steps. You can use a combination of `offset=1` and
`set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
Diffusion.
rescale_betas_zero_snr (`bool`, defaults to `False`):
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
dark samples instead of limiting it to samples with medium brightness. Loosely related to
[`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
"""

_compatibles = [e.name for e in KarrasDiffusionSchedulers]
Expand All @@ -138,6 +179,7 @@ def __init__(
prediction_type: str = "epsilon",
timestep_spacing: str = "linspace",
steps_offset: int = 0,
rescale_betas_zero_snr: bool = False,
):
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
Expand All @@ -152,9 +194,17 @@ def __init__(
else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")

if rescale_betas_zero_snr:
self.betas = rescale_zero_terminal_snr(self.betas)

self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

if rescale_betas_zero_snr:
# Close to 0 without being 0 so first sigma is not inf
# FP16 smallest positive subnormal works well here
self.alphas_cumprod[-1] = 2**-24

sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
self.sigmas = torch.from_numpy(sigmas)
Expand Down Expand Up @@ -327,6 +377,9 @@ def step(

sigma = self.sigmas[self.step_index]

# Upcast to avoid precision issues when computing prev_sample
sample = sample.to(torch.float32)

# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
if self.config.prediction_type == "epsilon":
pred_original_sample = sample - sigma * model_output
Expand Down Expand Up @@ -357,6 +410,9 @@ def step(

prev_sample = prev_sample + noise * sigma_up

# Cast sample back to model compatible dtype
prev_sample = prev_sample.to(model_output.dtype)

# upon completion increase step index by one
self._step_index += 1

Expand Down
4 changes: 4 additions & 0 deletions tests/schedulers/test_scheduler_euler_ancestral.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ def test_prediction_type(self):
for prediction_type in ["epsilon", "v_prediction"]:
self.check_over_configs(prediction_type=prediction_type)

def test_rescale_betas_zero_snr(self):
for rescale_betas_zero_snr in [True, False]:
self.check_over_configs(rescale_betas_zero_snr=rescale_betas_zero_snr)

def test_full_loop_no_noise(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
Expand Down

0 comments on commit 1d1c873

Please sign in to comment.