Merge pull request basujindal#117 from neonsecret/basujindal_attn
Memory-efficient attention and gradio mask fixed
basujindal authored Sep 5, 2022
2 parents 1857272 + f1bb248 commit c56b493
Showing 2 changed files with 135 additions and 129 deletions.
18 changes: 11 additions & 7 deletions ldm/modules/attention.py
@@ -174,23 +174,27 @@ def forward(self, x, context=None, mask=None):
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)
+       del context, x

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

-       sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+       sim = einsum('b i d, b j d -> b i j', q, k) * self.scale # (8, 4096, 40)
+       del q, k

        if exists(mask):
            mask = rearrange(mask, 'b ... -> b (...)')
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = repeat(mask, 'b j -> (b h) () j', h=h)
            sim.masked_fill_(~mask, max_neg_value)
+           del mask

-       # attention, what we cannot get enough of
-       attn = sim.softmax(dim=-1)
+       # attention, what we cannot get enough of, by halves
+       sim[4:] = sim[4:].softmax(dim=-1)
+       sim[:4] = sim[:4].softmax(dim=-1)

-       out = einsum('b i j, b j d -> b i d', attn, v)
-       out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
-       return self.to_out(out)
+       sim = einsum('b i j, b j d -> b i d', sim, v)
+       sim = rearrange(sim, '(b h) n d -> b n (h d)', h=h)
+       return self.to_out(sim)


class BasicTransformerBlock(nn.Module):
@@ -258,4 +262,4 @@ def forward(self, x, context=None):
            x = block(x, context=context)
        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
        x = self.proj_out(x)
-       return x + x_in
+       return x + x_in
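
The rewritten forward pass above keeps the attention computation inside a single buffer: context, x, q, k and mask are deleted as soon as they have been folded into sim, and the softmax result is written back into sim in two halves (sim[4:], then sim[:4]) instead of allocating a second full (batch*heads, n, n) tensor via attn = sim.softmax(dim=-1). The hard-coded split at index 4 assumes sim.shape[0] == 8, matching the (8, 4096, 40) shape noted in the diff comment. Below is a minimal, self-contained sketch of the same idea with the split generalised to a chunks parameter; chunked_softmax_attention and the demo shapes are illustrative, not code from the repository.

import torch
from torch import einsum


def chunked_softmax_attention(q, k, v, scale, chunks=2):
    # q, k, v: (batch*heads, tokens, dim_head)
    sim = einsum('b i d, b j d -> b i j', q, k) * scale
    del q, k  # drop the projections as soon as they are folded into sim

    step = max(sim.shape[0] // chunks, 1)
    for start in range(0, sim.shape[0], step):
        # softmax one slice at a time; the temporary created by softmax covers
        # only 1/chunks of the rows instead of duplicating the whole matrix
        sim[start:start + step] = sim[start:start + step].softmax(dim=-1)

    return einsum('b i j, b j d -> b i d', sim, v)


if __name__ == "__main__":
    bh, n, d = 8, 1024, 40               # the diff reports (8, 4096, 40) at 512x512
    q, k, v = (torch.randn(bh, n, d) for _ in range(3))
    out = chunked_softmax_attention(q, k, v, scale=d ** -0.5)
    print(out.shape)                     # torch.Size([8, 1024, 40])

With chunks=2 the largest softmax temporary is half the size of sim; raising chunks lowers the peak further at a small speed cost. The fixed sim[:4]/sim[4:] split in the commit gives the same halving only when batch*heads happens to be 8.
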
246 changes: 124 additions & 122 deletions optimizedSD/inpaint_gradio.py
@@ -1,29 +1,29 @@
import argparse
import os
import re
import time
from contextlib import nullcontext
from itertools import islice
from random import randint

import gradio as gr
import numpy as np
import torch
from torchvision.utils import make_grid
import os, re
from PIL import Image
import torch
import numpy as np
from random import randint
from einops import rearrange, repeat
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm, trange
from itertools import islice
from einops import rearrange
from torchvision.utils import make_grid
import time
from pytorch_lightning import seed_everything
from torch import autocast
from einops import rearrange, repeat
from contextlib import nullcontext
from ldm.util import instantiate_from_config
from torchvision.utils import make_grid
from tqdm import tqdm, trange
from transformers import logging
import pandas as pd

from ldm.util import instantiate_from_config
from optimUtils import split_weighted_subprompts, logger

logging.set_verbosity_error()
import mimetypes

mimetypes.init()
mimetypes.add_type("application/javascript", ".js")

@@ -43,7 +43,6 @@ def load_model_from_config(ckpt, verbose=False):


def load_img(image, h0, w0):

image = image.convert("RGB")
w, h = image.size
print(f"loaded input image of size ({w}, {h})")
@@ -60,96 +59,59 @@ def load_img(image, h0, w0):
return 2.0 * image - 1.0


def load_mask(mask, h0, w0, invert=False):

def load_mask(mask, h0, w0, newH, newW, invert=False):
image = mask.convert("RGB")
w, h = image.size
print(f"loaded input image of size ({w}, {h})")
if(h0 is not None and w0 is not None):
print(f"loaded input mask of size ({w}, {h})")
if h0 is not None and w0 is not None:
h, w = h0, w0

w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64

print(f"New image size ({w}, {h})")
image = image.resize((64, 64), resample = Image.LANCZOS)
print(f"New mask size ({w}, {h})")
image = image.resize((newW, newH), resample=Image.LANCZOS)
# image = image.resize((64, 64), resample=Image.LANCZOS)
image = np.array(image)

if invert:
print("inverted")
where_0, where_1 = np.where(image == 0),np.where(image == 255)
where_0, where_1 = np.where(image == 0), np.where(image == 255)
image[where_0], image[where_1] = 255, 0
image = image.astype(np.float32)/255.0
image = image.astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
return image


config = "optimizedSD/v1-inference.yaml"
ckpt = "models/ldm/stable-diffusion-v1/model.ckpt"
sd = load_model_from_config(f"{ckpt}")
li, lo = [], []
for key, v_ in sd.items():
sp = key.split(".")
if (sp[0]) == "model":
if "input_blocks" in sp:
li.append(key)
elif "middle_block" in sp:
li.append(key)
elif "time_embed" in sp:
li.append(key)
else:
lo.append(key)
for key in li:
sd["model1." + key[6:]] = sd.pop(key)
for key in lo:
sd["model2." + key[6:]] = sd.pop(key)

config = OmegaConf.load(f"{config}")

model = instantiate_from_config(config.modelUNet)
_, _ = model.load_state_dict(sd, strict=False)
model.eval()

modelCS = instantiate_from_config(config.modelCondStage)
_, _ = modelCS.load_state_dict(sd, strict=False)
modelCS.eval()

modelFS = instantiate_from_config(config.modelFirstStage)
_, _ = modelFS.load_state_dict(sd, strict=False)
modelFS.eval()
del sd

def generate(
image,
prompt,
strength,
ddim_steps,
n_iter,
batch_size,
Height,
Width,
scale,
ddim_eta,
unet_bs,
device,
seed,
outdir,
img_format,
turbo,
full_precision,
image,
prompt,
strength,
ddim_steps,
n_iter,
batch_size,
Height,
Width,
scale,
ddim_eta,
unet_bs,
device,
seed,
outdir,
img_format,
turbo,
full_precision,
):

if seed == "":
seed = randint(0, 1000000)
seed = int(seed)
seed_everything(seed)
sampler = "ddim"

# Logging
logger(locals(), log_csv = "logs/inpaint_gradio_logs.csv")
logger(locals(), log_csv="logs/inpaint_gradio_logs.csv")

init_image = load_img(image['image'], Height, Width).to(device)
mask = load_mask(image['mask'], Height, Width, True).to(device)

model.unet_bs = unet_bs
model.turbo = turbo
@@ -161,10 +123,7 @@ def generate(
modelCS.half()
modelFS.half()
init_image = init_image.half()
mask.half()

mask = mask[0][0].unsqueeze(0).repeat(4,1,1).unsqueeze(0)
mask = repeat(mask, '1 ... -> b ...', b=batch_size)
# mask.half()

tic = time.time()
os.makedirs(outdir, exist_ok=True)
@@ -182,6 +141,10 @@ def generate(
init_latent = modelFS.get_first_stage_encoding(modelFS.encode_first_stage(init_image)) # move to latent space
init_latent = repeat(init_latent, "1 ... -> b ...", b=batch_size)

mask = load_mask(image['mask'], Height, Width, init_latent.shape[2], init_latent.shape[3], True).to(device)
mask = mask[0][0].unsqueeze(0).repeat(4, 1, 1).unsqueeze(0)
mask = repeat(mask, '1 ... -> b ...', b=batch_size)

if device != "cpu":
mem = torch.cuda.memory_allocated() / 1e6
modelFS.to("cpu")
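
With the mask fix, load_mask no longer resizes the mask to a hard-coded 64x64: it now receives the latent height and width (init_latent.shape[2] and [3]), so the mask is built only after the init latent exists and matches the latent grid at any requested image size, and the unassigned mask.half() call is commented out. The sketch below shows the resulting preparation under illustrative names (mask_to_latent, mask_img and latent are not from the script); the real load_mask also supports inverting the mask so that painted (white) pixels become 0 and are the ones regenerated.

import numpy as np
import torch
from PIL import Image
from einops import repeat


def mask_to_latent(mask_img, latent, batch_size):
    # latent: (1, 4, H/8, W/8) tensor from the first-stage encoder
    new_h, new_w = latent.shape[2], latent.shape[3]
    m = mask_img.convert("L").resize((new_w, new_h), resample=Image.LANCZOS)
    m = torch.from_numpy(np.array(m).astype(np.float32) / 255.0)  # (H/8, W/8)
    m = m.unsqueeze(0).repeat(4, 1, 1).unsqueeze(0)               # (1, 4, H/8, W/8)
    return repeat(m, '1 ... -> b ...', b=batch_size)              # (b, 4, H/8, W/8)
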
@@ -237,23 +200,22 @@ def generate(
z_enc = model.stochastic_encode(
init_latent, torch.tensor([t_enc] * batch_size).to(device),
seed, ddim_eta, ddim_steps)

# decode it
samples_ddim = model.sample(
t_enc,
c,
z_enc,
unconditional_guidance_scale=scale,
unconditional_conditioning=uc,
mask = mask,
x_T = init_latent,
sampler = sampler,
mask=mask,
x_T=init_latent,
sampler=sampler,
)

modelFS.to(device)
print("saving images")
for i in range(batch_size):

x_samples_ddim = modelFS.decode_first_stage(samples_ddim[i].unsqueeze(0))
x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
all_samples.append(x_sample.to("cpu"))
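
The sampling call follows the usual img2img inpainting recipe: stochastic_encode noises the init latent up to t_enc, and model.sample decodes it back while mask and x_T keep the unpainted region tied to the original latent. The sampler body is not part of this diff, so the loop below is only an assumed sketch of how such a masked decode is commonly structured; masked_ddim_decode, denoise_step and q_sample are hypothetical helpers, and with invert=True in load_mask a mask value of 1 marks pixels to keep.

def masked_ddim_decode(z_enc, x_T, mask, timesteps, denoise_step, q_sample):
    # z_enc: init latent noised to t_enc; x_T: clean init latent; mask: 1 = keep
    x = z_enc
    for t in reversed(timesteps):
        x = denoise_step(x, t)                            # one reverse-diffusion step
        x = q_sample(x_T, t) * mask + (1.0 - mask) * x    # re-pin the kept region
    return x
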
@@ -284,37 +246,77 @@ def generate(
grid = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy()

txt = (
"Samples finished in "
+ str(round(time_taken, 3))
+ " minutes and exported to \n"
+ sample_path
+ "\nSeeds used = "
+ seeds[:-1]
"Samples finished in "
+ str(round(time_taken, 3))
+ " minutes and exported to \n"
+ sample_path
+ "\nSeeds used = "
+ seeds[:-1]
)
return Image.fromarray(grid.astype(np.uint8)), image['mask'], txt


if __name__ == '__main__':
parser = argparse.ArgumentParser(description='txt2img using gradio')
parser.add_argument('--config_path', default="optimizedSD/v1-inference.yaml", type=str, help='config path')
parser.add_argument('--ckpt_path', default="models/ldm/stable-diffusion-v1/model.ckpt", type=str, help='ckpt path')
args = parser.parse_args()
config = args.config_path
ckpt = args.ckpt_path
sd = load_model_from_config(f"{ckpt}")
li, lo = [], []
for key, v_ in sd.items():
sp = key.split(".")
if (sp[0]) == "model":
if "input_blocks" in sp:
li.append(key)
elif "middle_block" in sp:
li.append(key)
elif "time_embed" in sp:
li.append(key)
else:
lo.append(key)
for key in li:
sd["model1." + key[6:]] = sd.pop(key)
for key in lo:
sd["model2." + key[6:]] = sd.pop(key)

config = OmegaConf.load(f"{config}")

model = instantiate_from_config(config.modelUNet)
_, _ = model.load_state_dict(sd, strict=False)
model.eval()

modelCS = instantiate_from_config(config.modelCondStage)
_, _ = modelCS.load_state_dict(sd, strict=False)
modelCS.eval()

modelFS = instantiate_from_config(config.modelFirstStage)
_, _ = modelFS.load_state_dict(sd, strict=False)
modelFS.eval()
del sd
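
Model loading now happens under __main__, with --config_path and --ckpt_path exposed as command-line arguments instead of hard-coded module-level paths. The checkpoint is still re-keyed the optimizedSD way: keys whose path contains input_blocks, middle_block or time_embed move under the model1 prefix and the rest under model2, apparently so the two UNet halves declared in the config can each pick up their share of the weights via load_state_dict(strict=False). A small illustration of the re-keying (the checkpoint key below is hypothetical):

key = "model.diffusion_model.input_blocks.0.0.weight"   # hypothetical checkpoint key
sp = key.split(".")
first_half = "input_blocks" in sp or "middle_block" in sp or "time_embed" in sp
new_key = ("model1." if first_half else "model2.") + key[6:]   # key[6:] strips "model."
print(new_key)   # model1.diffusion_model.input_blocks.0.0.weight
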

demo = gr.Interface(
fn=generate,
inputs=[
gr.Image(tool="sketch", type="pil"),
"text",
gr.Slider(0, 0.99, value=0.99, step=0.01),
gr.Slider(1, 1000, value=50),
gr.Slider(1, 100, step=1),
gr.Slider(1, 100, step=1),
gr.Slider(64, 4096, value=512, step=64),
gr.Slider(64, 4096, value=512, step=64),
gr.Slider(0, 50, value=7.5, step=0.1),
gr.Slider(0, 1, step=0.01),
gr.Slider(1, 2, value=1, step=1),
gr.Text(value="cuda"),
"text",
gr.Text(value="outputs/inpaint-samples"),
gr.Radio(["png", "jpg"], value='png'),
"checkbox",
"checkbox",
],
outputs=["image", "image", "text"],
)
return Image.fromarray(grid.astype(np.uint8)), image['mask'],txt


demo = gr.Interface(
fn=generate,
inputs=[
gr.Image(tool="sketch", type="pil"),
"text",
gr.Slider(0, 0.99, value=0.99, step = 0.01),
gr.Slider(1, 1000, value=50),
gr.Slider(1, 100, step=1),
gr.Slider(1, 100, step=1),
gr.Slider(64, 4096, value=512, step=64),
gr.Slider(64, 4096, value=512, step=64),
gr.Slider(0, 50, value=7.5, step=0.1),
gr.Slider(0, 1, step=0.01),
gr.Slider(1, 2, value=1, step=1),
gr.Text(value="cuda"),
"text",
gr.Text(value="outputs/inpaint-samples"),
gr.Radio(["png", "jpg"], value='png'),
"checkbox",
"checkbox",
],
outputs=["image", "image", "text"],
)
demo.launch()
demo.launch()
