diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py index f4eff39cc..a7c095b2e 100644 --- a/ldm/modules/attention.py +++ b/ldm/modules/attention.py @@ -174,23 +174,27 @@ def forward(self, x, context=None, mask=None): context = default(context, x) k = self.to_k(context) v = self.to_v(context) + del context, x q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) - sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + sim = einsum('b i d, b j d -> b i j', q, k) * self.scale # (8, 4096, 40) + del q, k if exists(mask): mask = rearrange(mask, 'b ... -> b (...)') max_neg_value = -torch.finfo(sim.dtype).max mask = repeat(mask, 'b j -> (b h) () j', h=h) sim.masked_fill_(~mask, max_neg_value) + del mask - # attention, what we cannot get enough of - attn = sim.softmax(dim=-1) + # attention, what we cannot get enough of, by halves + sim[4:] = sim[4:].softmax(dim=-1) + sim[:4] = sim[:4].softmax(dim=-1) - out = einsum('b i j, b j d -> b i d', attn, v) - out = rearrange(out, '(b h) n d -> b n (h d)', h=h) - return self.to_out(out) + sim = einsum('b i j, b j d -> b i d', sim, v) + sim = rearrange(sim, '(b h) n d -> b n (h d)', h=h) + return self.to_out(sim) class BasicTransformerBlock(nn.Module): @@ -258,4 +262,4 @@ def forward(self, x, context=None): x = block(x, context=context) x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) x = self.proj_out(x) - return x + x_in \ No newline at end of file + return x + x_in diff --git a/optimizedSD/inpaint_gradio.py b/optimizedSD/inpaint_gradio.py index 4b5653bc2..d42930244 100644 --- a/optimizedSD/inpaint_gradio.py +++ b/optimizedSD/inpaint_gradio.py @@ -1,29 +1,29 @@ +import argparse +import os +import re +import time +from contextlib import nullcontext +from itertools import islice +from random import randint + import gradio as gr import numpy as np import torch -from torchvision.utils import make_grid -import os, re from PIL import Image -import torch -import numpy as np -from random import randint +from einops import rearrange, repeat from omegaconf import OmegaConf -from PIL import Image -from tqdm import tqdm, trange -from itertools import islice -from einops import rearrange -from torchvision.utils import make_grid -import time from pytorch_lightning import seed_everything from torch import autocast -from einops import rearrange, repeat -from contextlib import nullcontext -from ldm.util import instantiate_from_config +from torchvision.utils import make_grid +from tqdm import tqdm, trange from transformers import logging -import pandas as pd + +from ldm.util import instantiate_from_config from optimUtils import split_weighted_subprompts, logger + logging.set_verbosity_error() import mimetypes + mimetypes.init() mimetypes.add_type("application/javascript", ".js") @@ -43,7 +43,6 @@ def load_model_from_config(ckpt, verbose=False): def load_img(image, h0, w0): - image = image.convert("RGB") w, h = image.size print(f"loaded input image of size ({w}, {h})") @@ -60,85 +59,49 @@ def load_img(image, h0, w0): return 2.0 * image - 1.0 -def load_mask(mask, h0, w0, invert=False): - +def load_mask(mask, h0, w0, newH, newW, invert=False): image = mask.convert("RGB") w, h = image.size - print(f"loaded input image of size ({w}, {h})") - if(h0 is not None and w0 is not None): + print(f"loaded input mask of size ({w}, {h})") + if h0 is not None and w0 is not None: h, w = h0, w0 - + w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 32 - print(f"New image size ({w}, {h})") - image = image.resize((64, 64), resample = Image.LANCZOS) + print(f"New mask size ({w}, {h})") + image = image.resize((newW, newH), resample=Image.LANCZOS) + # image = image.resize((64, 64), resample=Image.LANCZOS) image = np.array(image) if invert: print("inverted") - where_0, where_1 = np.where(image == 0),np.where(image == 255) + where_0, where_1 = np.where(image == 0), np.where(image == 255) image[where_0], image[where_1] = 255, 0 - image = image.astype(np.float32)/255.0 + image = image.astype(np.float32) / 255.0 image = image[None].transpose(0, 3, 1, 2) image = torch.from_numpy(image) return image -config = "optimizedSD/v1-inference.yaml" -ckpt = "models/ldm/stable-diffusion-v1/model.ckpt" -sd = load_model_from_config(f"{ckpt}") -li, lo = [], [] -for key, v_ in sd.items(): - sp = key.split(".") - if (sp[0]) == "model": - if "input_blocks" in sp: - li.append(key) - elif "middle_block" in sp: - li.append(key) - elif "time_embed" in sp: - li.append(key) - else: - lo.append(key) -for key in li: - sd["model1." + key[6:]] = sd.pop(key) -for key in lo: - sd["model2." + key[6:]] = sd.pop(key) - -config = OmegaConf.load(f"{config}") - -model = instantiate_from_config(config.modelUNet) -_, _ = model.load_state_dict(sd, strict=False) -model.eval() - -modelCS = instantiate_from_config(config.modelCondStage) -_, _ = modelCS.load_state_dict(sd, strict=False) -modelCS.eval() - -modelFS = instantiate_from_config(config.modelFirstStage) -_, _ = modelFS.load_state_dict(sd, strict=False) -modelFS.eval() -del sd - def generate( - image, - prompt, - strength, - ddim_steps, - n_iter, - batch_size, - Height, - Width, - scale, - ddim_eta, - unet_bs, - device, - seed, - outdir, - img_format, - turbo, - full_precision, + image, + prompt, + strength, + ddim_steps, + n_iter, + batch_size, + Height, + Width, + scale, + ddim_eta, + unet_bs, + device, + seed, + outdir, + img_format, + turbo, + full_precision, ): - if seed == "": seed = randint(0, 1000000) seed = int(seed) @@ -146,10 +109,9 @@ def generate( sampler = "ddim" # Logging - logger(locals(), log_csv = "logs/inpaint_gradio_logs.csv") + logger(locals(), log_csv="logs/inpaint_gradio_logs.csv") init_image = load_img(image['image'], Height, Width).to(device) - mask = load_mask(image['mask'], Height, Width, True).to(device) model.unet_bs = unet_bs model.turbo = turbo @@ -161,10 +123,7 @@ def generate( modelCS.half() modelFS.half() init_image = init_image.half() - mask.half() - - mask = mask[0][0].unsqueeze(0).repeat(4,1,1).unsqueeze(0) - mask = repeat(mask, '1 ... -> b ...', b=batch_size) + # mask.half() tic = time.time() os.makedirs(outdir, exist_ok=True) @@ -182,6 +141,10 @@ def generate( init_latent = modelFS.get_first_stage_encoding(modelFS.encode_first_stage(init_image)) # move to latent space init_latent = repeat(init_latent, "1 ... -> b ...", b=batch_size) + mask = load_mask(image['mask'], Height, Width, init_latent.shape[2], init_latent.shape[3], True).to(device) + mask = mask[0][0].unsqueeze(0).repeat(4, 1, 1).unsqueeze(0) + mask = repeat(mask, '1 ... -> b ...', b=batch_size) + if device != "cpu": mem = torch.cuda.memory_allocated() / 1e6 modelFS.to("cpu") @@ -237,7 +200,7 @@ def generate( z_enc = model.stochastic_encode( init_latent, torch.tensor([t_enc] * batch_size).to(device), seed, ddim_eta, ddim_steps) - + # decode it samples_ddim = model.sample( t_enc, @@ -245,15 +208,14 @@ def generate( z_enc, unconditional_guidance_scale=scale, unconditional_conditioning=uc, - mask = mask, - x_T = init_latent, - sampler = sampler, + mask=mask, + x_T=init_latent, + sampler=sampler, ) modelFS.to(device) print("saving images") for i in range(batch_size): - x_samples_ddim = modelFS.decode_first_stage(samples_ddim[i].unsqueeze(0)) x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) all_samples.append(x_sample.to("cpu")) @@ -284,37 +246,77 @@ def generate( grid = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy() txt = ( - "Samples finished in " - + str(round(time_taken, 3)) - + " minutes and exported to \n" - + sample_path - + "\nSeeds used = " - + seeds[:-1] + "Samples finished in " + + str(round(time_taken, 3)) + + " minutes and exported to \n" + + sample_path + + "\nSeeds used = " + + seeds[:-1] + ) + return Image.fromarray(grid.astype(np.uint8)), image['mask'], txt + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='txt2img using gradio') + parser.add_argument('--config_path', default="optimizedSD/v1-inference.yaml", type=str, help='config path') + parser.add_argument('--ckpt_path', default="models/ldm/stable-diffusion-v1/model.ckpt", type=str, help='ckpt path') + args = parser.parse_args() + config = args.config_path + ckpt = args.ckpt_path + sd = load_model_from_config(f"{ckpt}") + li, lo = [], [] + for key, v_ in sd.items(): + sp = key.split(".") + if (sp[0]) == "model": + if "input_blocks" in sp: + li.append(key) + elif "middle_block" in sp: + li.append(key) + elif "time_embed" in sp: + li.append(key) + else: + lo.append(key) + for key in li: + sd["model1." + key[6:]] = sd.pop(key) + for key in lo: + sd["model2." + key[6:]] = sd.pop(key) + + config = OmegaConf.load(f"{config}") + + model = instantiate_from_config(config.modelUNet) + _, _ = model.load_state_dict(sd, strict=False) + model.eval() + + modelCS = instantiate_from_config(config.modelCondStage) + _, _ = modelCS.load_state_dict(sd, strict=False) + modelCS.eval() + + modelFS = instantiate_from_config(config.modelFirstStage) + _, _ = modelFS.load_state_dict(sd, strict=False) + modelFS.eval() + del sd + + demo = gr.Interface( + fn=generate, + inputs=[ + gr.Image(tool="sketch", type="pil"), + "text", + gr.Slider(0, 0.99, value=0.99, step=0.01), + gr.Slider(1, 1000, value=50), + gr.Slider(1, 100, step=1), + gr.Slider(1, 100, step=1), + gr.Slider(64, 4096, value=512, step=64), + gr.Slider(64, 4096, value=512, step=64), + gr.Slider(0, 50, value=7.5, step=0.1), + gr.Slider(0, 1, step=0.01), + gr.Slider(1, 2, value=1, step=1), + gr.Text(value="cuda"), + "text", + gr.Text(value="outputs/inpaint-samples"), + gr.Radio(["png", "jpg"], value='png'), + "checkbox", + "checkbox", + ], + outputs=["image", "image", "text"], ) - return Image.fromarray(grid.astype(np.uint8)), image['mask'],txt - - -demo = gr.Interface( - fn=generate, - inputs=[ - gr.Image(tool="sketch", type="pil"), - "text", - gr.Slider(0, 0.99, value=0.99, step = 0.01), - gr.Slider(1, 1000, value=50), - gr.Slider(1, 100, step=1), - gr.Slider(1, 100, step=1), - gr.Slider(64, 4096, value=512, step=64), - gr.Slider(64, 4096, value=512, step=64), - gr.Slider(0, 50, value=7.5, step=0.1), - gr.Slider(0, 1, step=0.01), - gr.Slider(1, 2, value=1, step=1), - gr.Text(value="cuda"), - "text", - gr.Text(value="outputs/inpaint-samples"), - gr.Radio(["png", "jpg"], value='png'), - "checkbox", - "checkbox", - ], - outputs=["image", "image", "text"], -) -demo.launch() \ No newline at end of file + demo.launch()