Skip to content

Commit

Permalink
EulerDiscrete upcast samples during step()
Browse files Browse the repository at this point in the history
Fixes the ZSNR precision issues on fp16/bf16 with no measurable
performance loss. Now using the full 2 ** -24, the results are
effectively equivalent to DDIM's ZSNR rescaling.
  • Loading branch information
Beinsezii committed Dec 2, 2023
1 parent 4f580e4 commit 6d78118
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions src/diffusers/schedulers/scheduling_euler_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,9 +211,9 @@ def __init__(
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

if rescale_betas_zero_snr:
# Bandaid so first sigma isn't inf
# Lower values that follow the 'proper' curve have precision issues on fp16/bf16
self.alphas_cumprod[-1] = 2 ** -16
# Close to 0 without being 0 so first sigma is not inf
# FP16 smallest positive subnormal works well here
self.alphas_cumprod[-1] = 2 ** -24

sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
Expand Down Expand Up @@ -468,6 +468,9 @@ def step(
if self.step_index is None:
self._init_step_index(timestep)

# Upcast to avoid precision issues when computing prev_sample
sample = sample.to(torch.float32)

sigma = self.sigmas[self.step_index]

gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
Expand Down Expand Up @@ -504,6 +507,9 @@ def step(

prev_sample = sample + derivative * dt

# Cast sample back to model compatible dtype
prev_sample = prev_sample.to(model_output.dtype)

# upon completion increase step index by one
self._step_index += 1

Expand Down

0 comments on commit 6d78118

Please sign in to comment.