Skip to content

Commit

Permalink
Implement noise augmentation for SD 4X upscale model.
Browse files Browse the repository at this point in the history
  • Loading branch information
comfyanonymous committed Jan 3, 2024
1 parent ef4f603 commit 8c64935
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 14 deletions.
2 changes: 1 addition & 1 deletion comfy/ldm/modules/diffusionmodules/openaimodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,7 @@ def __init__(

if self.num_classes is not None:
if isinstance(self.num_classes, int):
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
self.label_emb = nn.Embedding(num_classes, time_embed_dim, dtype=self.dtype, device=device)
elif self.num_classes == "continuous":
print("setting up linear c_adm embedding layer")
self.label_emb = nn.Linear(1, time_embed_dim)
Expand Down
12 changes: 8 additions & 4 deletions comfy/ldm/modules/diffusionmodules/upscaling.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,12 @@ def register_schedule(self, beta_schedule="linear", timesteps=1000,
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))

def q_sample(self, x_start, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start))
def q_sample(self, x_start, t, noise=None, seed=None):
if noise is None:
if seed is None:
noise = torch.randn_like(x_start)
else:
noise = torch.randn(x_start.size(), dtype=x_start.dtype, layout=x_start.layout, generator=torch.manual_seed(seed)).to(x_start.device)
return (extract_into_tensor(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start +
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise)

Expand All @@ -69,12 +73,12 @@ def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False):
super().__init__(noise_schedule_config=noise_schedule_config)
self.max_noise_level = max_noise_level

def forward(self, x, noise_level=None):
def forward(self, x, noise_level=None, seed=None):
if noise_level is None:
noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
else:
assert isinstance(noise_level, torch.Tensor)
z = self.q_sample(x, noise_level)
z = self.q_sample(x, noise_level, seed=seed)
return z, noise_level


Expand Down
22 changes: 17 additions & 5 deletions comfy/model_base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep
from comfy.ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
import comfy.model_management
import comfy.conds
import comfy.ops
Expand Down Expand Up @@ -78,8 +78,9 @@ def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, trans
extra_conds = {}
for o in kwargs:
extra = kwargs[o]
if hasattr(extra, "to"):
extra = extra.to(dtype)
if hasattr(extra, "dtype"):
if extra.dtype != torch.int and extra.dtype != torch.long:
extra = extra.to(dtype)
extra_conds[o] = extra

model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
Expand Down Expand Up @@ -368,20 +369,31 @@ def extra_conds(self, **kwargs):
class SD_X4Upscaler(BaseModel):
def __init__(self, model_config, model_type=ModelType.V_PREDICTION, device=None):
super().__init__(model_config, model_type, device=device)
self.noise_augmentor = ImageConcatWithNoiseAugmentation(noise_schedule_config={"linear_start": 0.0001, "linear_end": 0.02}, max_noise_level=350)

def extra_conds(self, **kwargs):
out = {}

image = kwargs.get("concat_image", None)
noise = kwargs.get("noise", None)
noise_augment = kwargs.get("noise_augmentation", 0.0)
device = kwargs["device"]
seed = kwargs["seed"] - 10

noise_level = round((self.noise_augmentor.max_noise_level) * noise_augment)

if image is None:
image = torch.zeros_like(noise)[:,:3]

if image.shape[1:] != noise.shape[1:]:
image = utils.common_upscale(image, noise.shape[-1], noise.shape[-2], "bilinear", "center")
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")

noise_level = torch.tensor([noise_level], device=device)
if noise_augment > 0:
image, noise_level = self.noise_augmentor(image.to(device), noise_level=noise_level, seed=seed)

image = utils.resize_to_batch_size(image, noise.shape[0])

out['c_concat'] = comfy.conds.CONDNoiseShape(image)
out['y'] = comfy.conds.CONDRegular(noise_level)
return out
4 changes: 2 additions & 2 deletions comfy/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,8 +603,8 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
latent_image = model.process_latent_in(latent_image)

if hasattr(model, 'extra_conds'):
positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask)
negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask)
positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask, seed=seed)
negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask, seed=seed)

#make sure each cond area has an opposite one with the same area
for c in positive:
Expand Down
1 change: 1 addition & 0 deletions comfy/supported_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,7 @@ class SD_X4Upscaler(SD20):

unet_extra_config = {
"disable_self_attentions": [True, True, True, False],
"num_classes": 1000,
"num_heads": 8,
"num_head_channels": -1,
}
Expand Down
6 changes: 4 additions & 2 deletions comfy_extras/nodes_sdupscale.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def INPUT_TYPES(s):
"positive": ("CONDITIONING",),
"negative": ("CONDITIONING",),
"scale_ratio": ("FLOAT", {"default": 4.0, "min": 0.0, "max": 10.0, "step": 0.01}),
# "noise_augmentation": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.01}), #TODO
"noise_augmentation": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
Expand All @@ -18,7 +18,7 @@ def INPUT_TYPES(s):

CATEGORY = "conditioning/upscale_diffusion"

def encode(self, images, positive, negative, scale_ratio):
def encode(self, images, positive, negative, scale_ratio, noise_augmentation):
width = max(1, round(images.shape[-2] * scale_ratio))
height = max(1, round(images.shape[-3] * scale_ratio))

Expand All @@ -30,11 +30,13 @@ def encode(self, images, positive, negative, scale_ratio):
for t in positive:
n = [t[0], t[1].copy()]
n[1]['concat_image'] = pixels
n[1]['noise_augmentation'] = noise_augmentation
out_cp.append(n)

for t in negative:
n = [t[0], t[1].copy()]
n[1]['concat_image'] = pixels
n[1]['noise_augmentation'] = noise_augmentation
out_cn.append(n)

latent = torch.zeros([images.shape[0], 4, height // 4, width // 4])
Expand Down

0 comments on commit 8c64935

Please sign in to comment.