From f1bb248dbb585b78cf3e53c0d5713f6496a4e69c Mon Sep 17 00:00:00 2001
From: bironsecret <44505718+bironsecret@users.noreply.github.com>
Date: Sun, 4 Sep 2022 09:48:32 +0200
Subject: [PATCH] Fix inpaint gradio mask mode

The mask is now loaded after the init image has been encoded and is
resized to the latent resolution taken from init_latent.shape instead of
a hard-coded 64x64, so it stays aligned with the latent at any image
size. Model loading and the gradio interface now live under an
if __name__ == '__main__': guard, with the config and checkpoint paths
exposed as --config_path/--ckpt_path arguments.

(cherry picked from commit ddde264be6b81fa05e57d653ef080545d921850a)
---
 optimizedSD/inpaint_gradio.py | 246 +++++++++++++++++-----------------
 1 file changed, 124 insertions(+), 122 deletions(-)

diff --git a/optimizedSD/inpaint_gradio.py b/optimizedSD/inpaint_gradio.py
index 4b5653bc2..d42930244 100644
--- a/optimizedSD/inpaint_gradio.py
+++ b/optimizedSD/inpaint_gradio.py
@@ -1,29 +1,29 @@
+import argparse
+import os
+import re
+import time
+from contextlib import nullcontext
+from itertools import islice
+from random import randint
+
 import gradio as gr
 import numpy as np
 import torch
-from torchvision.utils import make_grid
-import os, re
 from PIL import Image
-import torch
-import numpy as np
-from random import randint
+from einops import rearrange, repeat
 from omegaconf import OmegaConf
-from PIL import Image
-from tqdm import tqdm, trange
-from itertools import islice
-from einops import rearrange
-from torchvision.utils import make_grid
-import time
 from pytorch_lightning import seed_everything
 from torch import autocast
-from einops import rearrange, repeat
-from contextlib import nullcontext
-from ldm.util import instantiate_from_config
+from torchvision.utils import make_grid
+from tqdm import tqdm, trange
 from transformers import logging
-import pandas as pd
+
+from ldm.util import instantiate_from_config
 from optimUtils import split_weighted_subprompts, logger
+
 logging.set_verbosity_error()
 import mimetypes
+
 mimetypes.init()
 mimetypes.add_type("application/javascript", ".js")
 
@@ -43,7 +43,6 @@ def load_model_from_config(ckpt, verbose=False):
 
 
 def load_img(image, h0, w0):
-
     image = image.convert("RGB")
     w, h = image.size
     print(f"loaded input image of size ({w}, {h})")
@@ -60,85 +59,49 @@ def load_img(image, h0, w0):
     return 2.0 * image - 1.0
 
 
-def load_mask(mask, h0, w0, invert=False):
-
+def load_mask(mask, h0, w0, newH, newW, invert=False):
    image = mask.convert("RGB")
     w, h = image.size
-    print(f"loaded input image of size ({w}, {h})")
-    if(h0 is not None and w0 is not None):
+    print(f"loaded input mask of size ({w}, {h})")
+    if h0 is not None and w0 is not None:
         h, w = h0, w0
-
+
     w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64
-    print(f"New image size ({w}, {h})")
-    image = image.resize((64, 64), resample = Image.LANCZOS)
+    print(f"New mask size ({w}, {h})")
+    image = image.resize((newW, newH), resample=Image.LANCZOS)
+    # image = image.resize((64, 64), resample=Image.LANCZOS)
     image = np.array(image)
 
     if invert:
         print("inverted")
-        where_0, where_1 = np.where(image == 0),np.where(image == 255)
+        where_0, where_1 = np.where(image == 0), np.where(image == 255)
         image[where_0], image[where_1] = 255, 0
 
-    image = image.astype(np.float32)/255.0
+    image = image.astype(np.float32) / 255.0
     image = image[None].transpose(0, 3, 1, 2)
     image = torch.from_numpy(image)
     return image
 
 
-config = "optimizedSD/v1-inference.yaml"
-ckpt = "models/ldm/stable-diffusion-v1/model.ckpt"
-sd = load_model_from_config(f"{ckpt}")
-li, lo = [], []
-for key, v_ in sd.items():
-    sp = key.split(".")
-    if (sp[0]) == "model":
-        if "input_blocks" in sp:
-            li.append(key)
-        elif "middle_block" in sp:
-            li.append(key)
-        elif "time_embed" in sp:
-            li.append(key)
-        else:
-            lo.append(key)
-for key in li:
-    sd["model1." + key[6:]] = sd.pop(key)
-for key in lo:
-    sd["model2." + key[6:]] = sd.pop(key)
-
-config = OmegaConf.load(f"{config}")
-
-model = instantiate_from_config(config.modelUNet)
-_, _ = model.load_state_dict(sd, strict=False)
-model.eval()
-
-modelCS = instantiate_from_config(config.modelCondStage)
-_, _ = modelCS.load_state_dict(sd, strict=False)
-modelCS.eval()
-
-modelFS = instantiate_from_config(config.modelFirstStage)
-_, _ = modelFS.load_state_dict(sd, strict=False)
-modelFS.eval()
-del sd
-
 
 def generate(
-    image,
-    prompt,
-    strength,
-    ddim_steps,
-    n_iter,
-    batch_size,
-    Height,
-    Width,
-    scale,
-    ddim_eta,
-    unet_bs,
-    device,
-    seed,
-    outdir,
-    img_format,
-    turbo,
-    full_precision,
+        image,
+        prompt,
+        strength,
+        ddim_steps,
+        n_iter,
+        batch_size,
+        Height,
+        Width,
+        scale,
+        ddim_eta,
+        unet_bs,
+        device,
+        seed,
+        outdir,
+        img_format,
+        turbo,
+        full_precision,
 ):
-
     if seed == "":
         seed = randint(0, 1000000)
     seed = int(seed)
@@ -146,10 +109,9 @@ def generate(
     sampler = "ddim"
 
     # Logging
-    logger(locals(), log_csv = "logs/inpaint_gradio_logs.csv")
+    logger(locals(), log_csv="logs/inpaint_gradio_logs.csv")
 
     init_image = load_img(image['image'], Height, Width).to(device)
-    mask = load_mask(image['mask'], Height, Width, True).to(device)
 
     model.unet_bs = unet_bs
     model.turbo = turbo
@@ -161,10 +123,7 @@ def generate(
         modelCS.half()
         modelFS.half()
         init_image = init_image.half()
-        mask.half()
-
-    mask = mask[0][0].unsqueeze(0).repeat(4,1,1).unsqueeze(0)
-    mask = repeat(mask, '1 ... -> b ...', b=batch_size)
+        # mask.half()
 
     tic = time.time()
     os.makedirs(outdir, exist_ok=True)
@@ -182,6 +141,10 @@ def generate(
     init_latent = modelFS.get_first_stage_encoding(modelFS.encode_first_stage(init_image)) # move to latent space
     init_latent = repeat(init_latent, "1 ... -> b ...", b=batch_size)
 
+    mask = load_mask(image['mask'], Height, Width, init_latent.shape[2], init_latent.shape[3], True).to(device)
+    mask = mask[0][0].unsqueeze(0).repeat(4, 1, 1).unsqueeze(0)
+    mask = repeat(mask, '1 ... -> b ...', b=batch_size)
+
     if device != "cpu":
         mem = torch.cuda.memory_allocated() / 1e6
         modelFS.to("cpu")
@@ -237,7 +200,7 @@ def generate(
 
                 z_enc = model.stochastic_encode(
                     init_latent, torch.tensor([t_enc] * batch_size).to(device), seed, ddim_eta, ddim_steps)
-
+
                 # decode it
                 samples_ddim = model.sample(
                     t_enc,
@@ -245,15 +208,14 @@ def generate(
                     c,
                     z_enc,
                     unconditional_guidance_scale=scale,
                     unconditional_conditioning=uc,
-                    mask = mask,
-                    x_T = init_latent,
-                    sampler = sampler,
+                    mask=mask,
+                    x_T=init_latent,
+                    sampler=sampler,
                 )
                 modelFS.to(device)
                 print("saving images")
                 for i in range(batch_size):
-
                     x_samples_ddim = modelFS.decode_first_stage(samples_ddim[i].unsqueeze(0))
                     x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
                     all_samples.append(x_sample.to("cpu"))
@@ -284,37 +246,77 @@ def generate(
     grid = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy()
 
     txt = (
-        "Samples finished in "
-        + str(round(time_taken, 3))
-        + " minutes and exported to \n"
-        + sample_path
-        + "\nSeeds used = "
-        + seeds[:-1]
+            "Samples finished in "
+            + str(round(time_taken, 3))
+            + " minutes and exported to \n"
+            + sample_path
+            + "\nSeeds used = "
+            + seeds[:-1]
     )
+    return Image.fromarray(grid.astype(np.uint8)), image['mask'], txt
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='txt2img using gradio')
+    parser.add_argument('--config_path', default="optimizedSD/v1-inference.yaml", type=str, help='config path')
+    parser.add_argument('--ckpt_path', default="models/ldm/stable-diffusion-v1/model.ckpt", type=str, help='ckpt path')
+    args = parser.parse_args()
+    config = args.config_path
+    ckpt = args.ckpt_path
+    sd = load_model_from_config(f"{ckpt}")
+    li, lo = [], []
+    for key, v_ in sd.items():
+        sp = key.split(".")
+        if (sp[0]) == "model":
+            if "input_blocks" in sp:
+                li.append(key)
+            elif "middle_block" in sp:
+                li.append(key)
+            elif "time_embed" in sp:
+                li.append(key)
+            else:
+                lo.append(key)
+    for key in li:
+        sd["model1." + key[6:]] = sd.pop(key)
+    for key in lo:
+        sd["model2." + key[6:]] = sd.pop(key)
+
+    config = OmegaConf.load(f"{config}")
+
+    model = instantiate_from_config(config.modelUNet)
+    _, _ = model.load_state_dict(sd, strict=False)
+    model.eval()
+
+    modelCS = instantiate_from_config(config.modelCondStage)
+    _, _ = modelCS.load_state_dict(sd, strict=False)
+    modelCS.eval()
+
+    modelFS = instantiate_from_config(config.modelFirstStage)
+    _, _ = modelFS.load_state_dict(sd, strict=False)
+    modelFS.eval()
+    del sd
+
+    demo = gr.Interface(
+        fn=generate,
+        inputs=[
+            gr.Image(tool="sketch", type="pil"),
+            "text",
+            gr.Slider(0, 0.99, value=0.99, step=0.01),
+            gr.Slider(1, 1000, value=50),
+            gr.Slider(1, 100, step=1),
+            gr.Slider(1, 100, step=1),
+            gr.Slider(64, 4096, value=512, step=64),
+            gr.Slider(64, 4096, value=512, step=64),
+            gr.Slider(0, 50, value=7.5, step=0.1),
+            gr.Slider(0, 1, step=0.01),
+            gr.Slider(1, 2, value=1, step=1),
+            gr.Text(value="cuda"),
+            "text",
+            gr.Text(value="outputs/inpaint-samples"),
+            gr.Radio(["png", "jpg"], value='png'),
+            "checkbox",
+            "checkbox",
+        ],
+        outputs=["image", "image", "text"],
+    )
-    return Image.fromarray(grid.astype(np.uint8)), image['mask'],txt
-
-
-demo = gr.Interface(
-    fn=generate,
-    inputs=[
-        gr.Image(tool="sketch", type="pil"),
-        "text",
-        gr.Slider(0, 0.99, value=0.99, step = 0.01),
-        gr.Slider(1, 1000, value=50),
-        gr.Slider(1, 100, step=1),
-        gr.Slider(1, 100, step=1),
-        gr.Slider(64, 4096, value=512, step=64),
-        gr.Slider(64, 4096, value=512, step=64),
-        gr.Slider(0, 50, value=7.5, step=0.1),
-        gr.Slider(0, 1, step=0.01),
-        gr.Slider(1, 2, value=1, step=1),
-        gr.Text(value="cuda"),
-        "text",
-        gr.Text(value="outputs/inpaint-samples"),
-        gr.Radio(["png", "jpg"], value='png'),
-        "checkbox",
-        "checkbox",
-    ],
-    outputs=["image", "image", "text"],
-)
-demo.launch()
\ No newline at end of file
+    demo.launch()
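
Reviewer note: the core of the fix is that the inpainting mask must match the latent grid that model.sample() blends, not a fixed 64x64. Below is a minimal, self-contained sketch of that idea (PIL + numpy + torch only). The helper name mask_to_latent_mask is illustrative, not from this repo, and the 1/8 downsampling factor and 4 latent channels are assumed stable-diffusion defaults rather than values the patch reads anywhere.

# Minimal sketch: align an inpainting mask with the latent grid.
# Hypothetical helper; assumes an SD-style VAE (8x downsampling, 4 channels).
import numpy as np
import torch
from PIL import Image


def mask_to_latent_mask(mask_img, latent_h, latent_w, batch_size, invert=True):
    """Resize a PIL mask to the latent resolution and expand to (B, 4, h, w)."""
    mask = mask_img.convert("L").resize((latent_w, latent_h), Image.LANCZOS)
    arr = np.asarray(mask, dtype=np.float32) / 255.0   # (h, w), values in [0, 1]
    if invert:
        arr = 1.0 - arr  # like the patch's invert=True: white -> 0, black -> 1
    t = torch.from_numpy(arr)             # (h, w)
    t = t[None, None]                     # (1, 1, h, w)
    return t.repeat(batch_size, 4, 1, 1)  # (B, 4, h, w), one copy per latent channel


if __name__ == "__main__":
    # A 512x512 image encodes to a 64x64 latent (512 / 8), which is why the
    # old hard-coded 64x64 resize only happened to work at that one resolution.
    h, w = 512, 768
    latent_h, latent_w = h // 8, w // 8  # 64, 96
    demo_mask = Image.new("L", (w, h), color=255)
    print(mask_to_latent_mask(demo_mask, latent_h, latent_w, batch_size=2).shape)
    # -> torch.Size([2, 4, 64, 96])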
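For comparison, load_mask in the patch reaches the same (B, 4, h, w) shape via mask[0][0].unsqueeze(0).repeat(4, 1, 1).unsqueeze(0) followed by repeat(mask, '1 ... -> b ...', b=batch_size). Passing init_latent.shape[2] and init_latent.shape[3] instead of recomputing Height // 8 and Width // 8 is the more robust choice: the mask stays correct even if the first-stage encoder's downsampling factor ever differs.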