Skip to content

Commit

Permalink
Fix some issues with sampling precision.
Browse files Browse the repository at this point in the history
  • Loading branch information
comfyanonymous committed Nov 1, 2023
1 parent 7c0f255 commit 111f1b5
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
4 changes: 2 additions & 2 deletions comfy/model_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps
else:
betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
alphas = 1. - betas
alphas_cumprod = np.cumprod(alphas, axis=0)
alphas_cumprod = torch.tensor(np.cumprod(alphas, axis=0), dtype=torch.float32)
# alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])

timesteps, = betas.shape
Expand All @@ -56,7 +56,7 @@ def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps
# self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
# self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))

sigmas = torch.tensor(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, dtype=torch.float32)
sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5

self.register_buffer('sigmas', sigmas)
self.register_buffer('log_sigmas', sigmas.log())
Expand Down
6 changes: 4 additions & 2 deletions comfy/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,10 @@ def cond_cat(c_list):

def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, model_options):
out_cond = torch.zeros_like(x_in)
out_count = torch.ones_like(x_in)/100000.0
out_count = torch.zeros_like(x_in)

out_uncond = torch.zeros_like(x_in)
out_uncond_count = torch.ones_like(x_in)/100000.0
out_uncond_count = torch.zeros_like(x_in)

COND = 0
UNCOND = 1
Expand Down Expand Up @@ -241,6 +241,8 @@ def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_tot
out_uncond /= out_uncond_count
del out_uncond_count

torch.nan_to_num(out_cond, nan=0.0, posinf=0.0, neginf=0.0, out=out_cond) #in case out_count or out_uncond_count had some zeros
torch.nan_to_num(out_uncond, nan=0.0, posinf=0.0, neginf=0.0, out=out_uncond)
return out_cond, out_uncond


Expand Down

0 comments on commit 111f1b5

Please sign in to comment.