Skip to content

Commit

Permalink
Merge branch 'comfyanonymous:master' into master
Browse files Browse the repository at this point in the history
  • Loading branch information
WillReynolds5 authored Oct 26, 2024
2 parents 4e26292 + 5cbb01b commit 8001605
Show file tree
Hide file tree
Showing 31 changed files with 1,975 additions and 120 deletions.
10 changes: 3 additions & 7 deletions app/frontend_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,20 +168,16 @@ def init_frontend_unsafe(cls, version_string: str, provider: Optional[FrontEndPr
Path(cls.CUSTOM_FRONTENDS_ROOT) / provider.folder_name / semantic_version
)
if not os.path.exists(web_root):
# Use tmp path until complete to avoid path exists check passing from interrupted downloads
tmp_path = web_root + ".tmp"
try:
os.makedirs(tmp_path, exist_ok=True)
os.makedirs(web_root, exist_ok=True)
logging.info(
"Downloading frontend(%s) version(%s) to (%s)",
provider.folder_name,
semantic_version,
tmp_path,
web_root,
)
logging.debug(release)
download_release_asset_zip(release, destination_path=tmp_path)
if os.listdir(tmp_path):
os.rename(tmp_path, web_root)
download_release_asset_zip(release, destination_path=web_root)
finally:
# Clean up the directory if it is empty, i.e. the download failed
if not os.listdir(web_root):
Expand Down
21 changes: 10 additions & 11 deletions comfy/controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class StrengthType(Enum):
LINEAR_UP = 2

class ControlBase:
def __init__(self, device=None):
def __init__(self):
self.cond_hint_original = None
self.cond_hint = None
self.strength = 1.0
Expand All @@ -72,10 +72,6 @@ def __init__(self, device=None):
self.compression_ratio = 8
self.upscale_algorithm = 'nearest-exact'
self.extra_args = {}

if device is None:
device = comfy.model_management.get_torch_device()
self.device = device
self.previous_controlnet = None
self.extra_conds = []
self.strength_type = StrengthType.CONSTANT
Expand Down Expand Up @@ -185,8 +181,8 @@ def set_extra_arg(self, argument, value=None):


class ControlNet(ControlBase):
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT, concat_mask=False):
super().__init__(device)
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT, concat_mask=False):
super().__init__()
self.control_model = control_model
self.load_device = load_device
if control_model is not None:
Expand Down Expand Up @@ -242,7 +238,7 @@ def get_control(self, x_noisy, t, cond, batched_number):
to_concat.append(comfy.utils.repeat_to_batch_size(c, self.cond_hint.shape[0]))
self.cond_hint = torch.cat([self.cond_hint] + to_concat, dim=1)

self.cond_hint = self.cond_hint.to(device=self.device, dtype=dtype)
self.cond_hint = self.cond_hint.to(device=x_noisy.device, dtype=dtype)
if x_noisy.shape[0] != self.cond_hint.shape[0]:
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)

Expand Down Expand Up @@ -341,8 +337,8 @@ def forward(self, input):


class ControlLora(ControlNet):
def __init__(self, control_weights, global_average_pooling=False, device=None, model_options={}): #TODO? model_options
ControlBase.__init__(self, device)
def __init__(self, control_weights, global_average_pooling=False, model_options={}): #TODO? model_options
ControlBase.__init__(self)
self.control_weights = control_weights
self.global_average_pooling = global_average_pooling
self.extra_conds += ["y"]
Expand Down Expand Up @@ -662,12 +658,15 @@ def load_controlnet(ckpt_path, model=None, model_options={}):

class T2IAdapter(ControlBase):
def __init__(self, t2i_model, channels_in, compression_ratio, upscale_algorithm, device=None):
super().__init__(device)
super().__init__()
self.t2i_model = t2i_model
self.channels_in = channels_in
self.control_input = None
self.compression_ratio = compression_ratio
self.upscale_algorithm = upscale_algorithm
if device is None:
device = comfy.model_management.get_torch_device()
self.device = device

def scale_image_to(self, width, height):
unshuffle_amount = self.t2i_model.unshuffle_amount
Expand Down
2 changes: 2 additions & 0 deletions comfy/float.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ def manual_stochastic_round_to_float8(x, dtype, generator=None):
(2.0 ** (-EXPONENT_BIAS + 1)) * abs_x
)

inf = torch.finfo(dtype)
torch.clamp(sign, min=inf.min, max=inf.max, out=sign)
return sign


Expand Down
25 changes: 25 additions & 0 deletions comfy/k_diffusion/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,8 @@ def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None,

@torch.no_grad()
def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
if isinstance(model.inner_model.inner_model.model_sampling, comfy.model_sampling.CONST):
return sample_euler_ancestral_RF(model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler)
"""Ancestral sampling with Euler method steps."""
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
Expand All @@ -181,6 +183,29 @@ def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, dis
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
return x

@torch.no_grad()
def sample_euler_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.0, s_noise=1., noise_sampler=None):
"""Ancestral sampling with Euler method steps."""
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
# sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
downstep_ratio = 1 + (sigmas[i+1]/sigmas[i] - 1) * eta
sigma_down = sigmas[i+1] * downstep_ratio
alpha_ip1 = 1 - sigmas[i+1]
alpha_down = 1 - sigma_down
renoise_coeff = (sigmas[i+1]**2 - sigma_down**2*alpha_ip1**2/alpha_down**2)**0.5
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})

# Euler method
sigma_down_i_ratio = sigma_down / sigmas[i]
x = sigma_down_i_ratio * x + (1 - sigma_down_i_ratio) * denoised
if sigmas[i + 1] > 0 and eta > 0:
x = (alpha_ip1/alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff
return x

@torch.no_grad()
def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
Expand Down
27 changes: 27 additions & 0 deletions comfy/latent_formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,3 +175,30 @@ def process_in(self, latent):

def process_out(self, latent):
return (latent / self.scale_factor) + self.shift_factor

class Mochi(LatentFormat):
latent_channels = 12

def __init__(self):
self.scale_factor = 1.0
self.latents_mean = torch.tensor([-0.06730895953510081, -0.038011381506090416, -0.07477820912866141,
-0.05565264470995561, 0.012767231469026969, -0.04703542746246419,
0.043896967884726704, -0.09346305707025976, -0.09918314763016893,
-0.008729793427399178, -0.011931556316503654, -0.0321993391887285]).view(1, self.latent_channels, 1, 1, 1)
self.latents_std = torch.tensor([0.9263795028493863, 0.9248894543193766, 0.9393059390890617,
0.959253732819592, 0.8244560132752793, 0.917259975397747,
0.9294154431013696, 1.3720942357788521, 0.881393668867029,
0.9168315692124348, 0.9185249279345552, 0.9274757570805041]).view(1, self.latent_channels, 1, 1, 1)

self.latent_rgb_factors = None #TODO
self.taesd_decoder_name = None #TODO

def process_in(self, latent):
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
latents_std = self.latents_std.to(latent.device, latent.dtype)
return (latent - latents_mean) * self.scale_factor / latents_std

def process_out(self, latent):
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
latents_std = self.latents_std.to(latent.device, latent.dtype)
return latent * latents_std / self.scale_factor + latents_mean
14 changes: 10 additions & 4 deletions comfy/ldm/common_dit.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,15 @@ def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
except:
rms_norm_torch = None

def rms_norm(x, weight, eps=1e-6):
def rms_norm(x, weight=None, eps=1e-6):
if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()):
return rms_norm_torch(x, weight.shape, weight=comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
if weight is None:
return rms_norm_torch(x, (x.shape[-1],), eps=eps)
else:
return rms_norm_torch(x, weight.shape, weight=comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
else:
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
return (x * rrms) * comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device)
r = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
if weight is None:
return r
else:
return r * comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device)
Loading

0 comments on commit 8001605

Please sign in to comment.