Commit f1bb248

inpaint gradio mask mode fixed
(cherry picked from commit ddde264)
neonsecret committed Sep 4, 2022
1 parent 47f8784 commit f1bb248
Showing 1 changed file with 124 additions and 122 deletions.
246 changes: 124 additions & 122 deletions optimizedSD/inpaint_gradio.py
@@ -1,29 +1,29 @@
+import argparse
+import os
+import re
+import time
+from contextlib import nullcontext
+from itertools import islice
+from random import randint
+
 import gradio as gr
+import numpy as np
+import torch
+from torchvision.utils import make_grid
-import os, re
-from PIL import Image
-import torch
-import numpy as np
-from random import randint
+from einops import rearrange, repeat
 from omegaconf import OmegaConf
 from PIL import Image
 from tqdm import tqdm, trange
-from itertools import islice
-from einops import rearrange
-from torchvision.utils import make_grid
-import time
 from pytorch_lightning import seed_everything
 from torch import autocast
-from einops import rearrange, repeat
-from contextlib import nullcontext
-from ldm.util import instantiate_from_config
-from torchvision.utils import make_grid
-from tqdm import tqdm, trange
 from transformers import logging
 import pandas as pd
+
+from ldm.util import instantiate_from_config
 from optimUtils import split_weighted_subprompts, logger
 
 logging.set_verbosity_error()
 import mimetypes
 
 mimetypes.init()
 mimetypes.add_type("application/javascript", ".js")

@@ -43,7 +43,6 @@ def load_model_from_config(ckpt, verbose=False):
 
 
 def load_img(image, h0, w0):
-
     image = image.convert("RGB")
     w, h = image.size
     print(f"loaded input image of size ({w}, {h})")
@@ -60,96 +59,59 @@ def load_img(image, h0, w0):
     return 2.0 * image - 1.0
 
 
-def load_mask(mask, h0, w0, invert=False):
-
+def load_mask(mask, h0, w0, newH, newW, invert=False):
     image = mask.convert("RGB")
     w, h = image.size
-    print(f"loaded input image of size ({w}, {h})")
-    if(h0 is not None and w0 is not None):
+    print(f"loaded input mask of size ({w}, {h})")
+    if h0 is not None and w0 is not None:
         h, w = h0, w0
 
     w, h = map(lambda x: x - x % 64, (w, h))  # resize to integer multiple of 64
 
-    print(f"New image size ({w}, {h})")
-    image = image.resize((64, 64), resample = Image.LANCZOS)
+    print(f"New mask size ({w}, {h})")
+    image = image.resize((newW, newH), resample=Image.LANCZOS)
+    # image = image.resize((64, 64), resample=Image.LANCZOS)
     image = np.array(image)
 
     if invert:
         print("inverted")
-        where_0, where_1 = np.where(image == 0),np.where(image == 255)
+        where_0, where_1 = np.where(image == 0), np.where(image == 255)
         image[where_0], image[where_1] = 255, 0
-    image = image.astype(np.float32)/255.0
+    image = image.astype(np.float32) / 255.0
     image = image[None].transpose(0, 3, 1, 2)
     image = torch.from_numpy(image)
     return image
 
 
-config = "optimizedSD/v1-inference.yaml"
-ckpt = "models/ldm/stable-diffusion-v1/model.ckpt"
-sd = load_model_from_config(f"{ckpt}")
-li, lo = [], []
-for key, v_ in sd.items():
-    sp = key.split(".")
-    if (sp[0]) == "model":
-        if "input_blocks" in sp:
-            li.append(key)
-        elif "middle_block" in sp:
-            li.append(key)
-        elif "time_embed" in sp:
-            li.append(key)
-        else:
-            lo.append(key)
-for key in li:
-    sd["model1." + key[6:]] = sd.pop(key)
-for key in lo:
-    sd["model2." + key[6:]] = sd.pop(key)
-
-config = OmegaConf.load(f"{config}")
-
-model = instantiate_from_config(config.modelUNet)
-_, _ = model.load_state_dict(sd, strict=False)
-model.eval()
-
-modelCS = instantiate_from_config(config.modelCondStage)
-_, _ = modelCS.load_state_dict(sd, strict=False)
-modelCS.eval()
-
-modelFS = instantiate_from_config(config.modelFirstStage)
-_, _ = modelFS.load_state_dict(sd, strict=False)
-modelFS.eval()
-del sd
-
 def generate(
-    image,
-    prompt,
-    strength,
-    ddim_steps,
-    n_iter,
-    batch_size,
-    Height,
-    Width,
-    scale,
-    ddim_eta,
-    unet_bs,
-    device,
-    seed,
-    outdir,
-    img_format,
-    turbo,
-    full_precision,
+        image,
+        prompt,
+        strength,
+        ddim_steps,
+        n_iter,
+        batch_size,
+        Height,
+        Width,
+        scale,
+        ddim_eta,
+        unet_bs,
+        device,
+        seed,
+        outdir,
+        img_format,
+        turbo,
+        full_precision,
 ):
-
     if seed == "":
         seed = randint(0, 1000000)
     seed = int(seed)
     seed_everything(seed)
     sampler = "ddim"
 
     # Logging
-    logger(locals(), log_csv = "logs/inpaint_gradio_logs.csv")
+    logger(locals(), log_csv="logs/inpaint_gradio_logs.csv")
 
     init_image = load_img(image['image'], Height, Width).to(device)
-    mask = load_mask(image['mask'], Height, Width, True).to(device)
 
     model.unet_bs = unet_bs
     model.turbo = turbo
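This hunk carries the heart of the fix. The old `load_mask` resized every mask to a hard-coded 64×64 grid, which only lines up with the latent of a 512×512 image; the new signature accepts the latent height and width explicitly, and the call site moves into `generate` so those can be read off `init_latent` (see the hunks below). A minimal sketch of the relationship, assuming the stock Stable Diffusion autoencoder's downsampling factor of 8:

    # Why newH/newW must track the requested size (sketch, f=8 assumed):
    Height, Width = 512, 768                # user-requested output size
    newH, newW = Height // 8, Width // 8    # 64, 96: matches init_latent.shape[2:]
    # The old code always produced a 64x64 mask, so any request other
    # than 512x512 left the mask misaligned with the latent grid.

The hunk also moves the module-level model loading out of import scope; it reappears under `if __name__ == '__main__':` at the bottom of the file.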
@@ -161,10 +123,7 @@ def generate(
         modelCS.half()
         modelFS.half()
         init_image = init_image.half()
-        mask.half()
-
-    mask = mask[0][0].unsqueeze(0).repeat(4,1,1).unsqueeze(0)
-    mask = repeat(mask, '1 ... -> b ...', b=batch_size)
+        # mask.half()
 
     tic = time.time()
     os.makedirs(outdir, exist_ok=True)
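A side note on the commented-out `mask.half()`: the original line never did anything, because `Tensor.half()` returns a new half-precision tensor rather than casting in place, so its result was silently discarded. An effective cast would have needed the assignment spelled out:

    mask = mask.half()  # what an in-place-style cast would look like (sketch)

Removing the call, and rebuilding the mask later once `init_latent` exists, sidesteps the question entirely.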
@@ -182,6 +141,10 @@ def generate(
     init_latent = modelFS.get_first_stage_encoding(modelFS.encode_first_stage(init_image))  # move to latent space
     init_latent = repeat(init_latent, "1 ... -> b ...", b=batch_size)
 
+    mask = load_mask(image['mask'], Height, Width, init_latent.shape[2], init_latent.shape[3], True).to(device)
+    mask = mask[0][0].unsqueeze(0).repeat(4, 1, 1).unsqueeze(0)
+    mask = repeat(mask, '1 ... -> b ...', b=batch_size)
+
     if device != "cpu":
         mem = torch.cuda.memory_allocated() / 1e6
         modelFS.to("cpu")
@@ -237,23 +200,22 @@ def generate(
                     z_enc = model.stochastic_encode(
                         init_latent, torch.tensor([t_enc] * batch_size).to(device),
                         seed, ddim_eta, ddim_steps)
 
                     # decode it
                     samples_ddim = model.sample(
                         t_enc,
                         c,
                         z_enc,
                         unconditional_guidance_scale=scale,
                         unconditional_conditioning=uc,
-                        mask = mask,
-                        x_T = init_latent,
-                        sampler = sampler,
+                        mask=mask,
+                        x_T=init_latent,
+                        sampler=sampler,
                     )
 
                     modelFS.to(device)
                     print("saving images")
                     for i in range(batch_size):
-
                         x_samples_ddim = modelFS.decode_first_stage(samples_ddim[i].unsqueeze(0))
                         x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
                         all_samples.append(x_sample.to("cpu"))
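Passing `mask` and `x_T=init_latent` to `model.sample` is what turns the plain img2img loop into inpainting. The sampler itself lives elsewhere in this fork, but the usual mechanism is mask blending: after each denoising update, the protected region is overwritten with the original latent re-noised to the current timestep, so only the painted region evolves. A conceptual sketch, not the repo's code; `renoise_to_step` is a hypothetical helper standing in for the scheduler's forward-noising:

    # Conceptual mask blending inside one sampling step (illustrative only):
    x = step_output                             # latent proposed by the DDIM update
    x_known = renoise_to_step(init_latent, t)   # hypothetical: init latent noised to step t
    x = mask * x_known + (1.0 - mask) * x       # keep the known region, generate the rest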
@@ -284,37 +246,77 @@ def generate(
     grid = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy()
 
     txt = (
-        "Samples finished in "
-        + str(round(time_taken, 3))
-        + " minutes and exported to \n"
-        + sample_path
-        + "\nSeeds used = "
-        + seeds[:-1]
+            "Samples finished in "
+            + str(round(time_taken, 3))
+            + " minutes and exported to \n"
+            + sample_path
+            + "\nSeeds used = "
+            + seeds[:-1]
     )
+    return Image.fromarray(grid.astype(np.uint8)), image['mask'], txt
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='txt2img using gradio')
+    parser.add_argument('--config_path', default="optimizedSD/v1-inference.yaml", type=str, help='config path')
+    parser.add_argument('--ckpt_path', default="models/ldm/stable-diffusion-v1/model.ckpt", type=str, help='ckpt path')
+    args = parser.parse_args()
+    config = args.config_path
+    ckpt = args.ckpt_path
+    sd = load_model_from_config(f"{ckpt}")
+    li, lo = [], []
+    for key, v_ in sd.items():
+        sp = key.split(".")
+        if (sp[0]) == "model":
+            if "input_blocks" in sp:
+                li.append(key)
+            elif "middle_block" in sp:
+                li.append(key)
+            elif "time_embed" in sp:
+                li.append(key)
+            else:
+                lo.append(key)
+    for key in li:
+        sd["model1." + key[6:]] = sd.pop(key)
+    for key in lo:
+        sd["model2." + key[6:]] = sd.pop(key)
+
+    config = OmegaConf.load(f"{config}")
+
+    model = instantiate_from_config(config.modelUNet)
+    _, _ = model.load_state_dict(sd, strict=False)
+    model.eval()
+
+    modelCS = instantiate_from_config(config.modelCondStage)
+    _, _ = modelCS.load_state_dict(sd, strict=False)
+    modelCS.eval()
+
+    modelFS = instantiate_from_config(config.modelFirstStage)
+    _, _ = modelFS.load_state_dict(sd, strict=False)
+    modelFS.eval()
+    del sd
+
+    demo = gr.Interface(
+        fn=generate,
+        inputs=[
+            gr.Image(tool="sketch", type="pil"),
+            "text",
+            gr.Slider(0, 0.99, value=0.99, step=0.01),
+            gr.Slider(1, 1000, value=50),
+            gr.Slider(1, 100, step=1),
+            gr.Slider(1, 100, step=1),
+            gr.Slider(64, 4096, value=512, step=64),
+            gr.Slider(64, 4096, value=512, step=64),
+            gr.Slider(0, 50, value=7.5, step=0.1),
+            gr.Slider(0, 1, step=0.01),
+            gr.Slider(1, 2, value=1, step=1),
+            gr.Text(value="cuda"),
+            "text",
+            gr.Text(value="outputs/inpaint-samples"),
+            gr.Radio(["png", "jpg"], value='png'),
+            "checkbox",
+            "checkbox",
+        ],
+        outputs=["image", "image", "text"],
+    )
-    return Image.fromarray(grid.astype(np.uint8)), image['mask'],txt
-
-
-demo = gr.Interface(
-    fn=generate,
-    inputs=[
-        gr.Image(tool="sketch", type="pil"),
-        "text",
-        gr.Slider(0, 0.99, value=0.99, step = 0.01),
-        gr.Slider(1, 1000, value=50),
-        gr.Slider(1, 100, step=1),
-        gr.Slider(1, 100, step=1),
-        gr.Slider(64, 4096, value=512, step=64),
-        gr.Slider(64, 4096, value=512, step=64),
-        gr.Slider(0, 50, value=7.5, step=0.1),
-        gr.Slider(0, 1, step=0.01),
-        gr.Slider(1, 2, value=1, step=1),
-        gr.Text(value="cuda"),
-        "text",
-        gr.Text(value="outputs/inpaint-samples"),
-        gr.Radio(["png", "jpg"], value='png'),
-        "checkbox",
-        "checkbox",
-    ],
-    outputs=["image", "image", "text"],
-)
-demo.launch()
+    demo.launch()
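Two notes on the relocated setup code. First, the key-remapping loop reflects how this fork splits the UNet into two sub-models to reduce VRAM: state-dict keys beginning with `model.` are reassigned to `model1.` (input blocks, middle block, time embedding) or `model2.` (everything else), with `key[6:]` stripping the `model.` prefix. For example:

    # Sketch of the renaming (illustrative key names):
    # "model.diffusion_model.input_blocks.0.0.weight" -> "model1.diffusion_model.input_blocks.0.0.weight"
    # "model.diffusion_model.out.2.weight"            -> "model2.diffusion_model.out.2.weight"

Second, moving everything under `if __name__ == '__main__':` means the checkpoint is no longer loaded at import time, and the new flags make the paths configurable. With the defaults shown above, launching should look something like:

    python optimizedSD/inpaint_gradio.py --config_path optimizedSD/v1-inference.yaml --ckpt_path models/ldm/stable-diffusion-v1/model.ckpt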
