diff --git a/examples/community/regional_prompting_stable_diffusion.py b/examples/community/regional_prompting_stable_diffusion.py
index 24a50d9b9011..525e75bc68b9 100644
--- a/examples/community/regional_prompting_stable_diffusion.py
+++ b/examples/community/regional_prompting_stable_diffusion.py
@@ -1,33 +1,37 @@
-import torchvision.transforms.functional as FF
-import torch
-import torchvision
+import math
 from typing import Dict, Optional
+
+import torch
+import torchvision.transforms.functional as FF
+from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+
 from diffusers import StableDiffusionPipeline
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
-from diffusers.utils import USE_PEFT_BACKEND
 from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from diffusers.schedulers import KarrasDiffusionSchedulers
-from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+from diffusers.utils import USE_PEFT_BACKEND
+
 
 try:
     from compel import Compel
-except:
+except ImportError:
     Compel = None
 
 KCOMM = "ADDCOMM"
 KBRK = "BREAK"
 
+
 class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
     r"""
     Args for Regional Prompting Pipeline:
     rp_args:dict
-    Required
+    Required
         rp_args["mode"]: cols, rows, prompt, prompt-ex
         for cols, rows mode
         rp_args["div"]: ex) 1;1;1(Divide into 3 regions)
-        for prompt, prompt-ex mode
+        for prompt, prompt-ex mode
         rp_args["th"]: ex) 0.5,0.5,0.6 (threshold for prompt mode)
-
+
     Optional
         rp_args["save_mask"]: True/False (save masks in prompt mode)
@@ -56,6 +60,7 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
         feature_extractor ([`CLIPImageProcessor`]):
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
""" + def __init__( self, vae: AutoencoderKL, @@ -67,7 +72,9 @@ def __init__( feature_extractor: CLIPFeatureExtractor, requires_safety_checker: bool = True, ): - super().__init__(vae,text_encoder,tokenizer,unet,scheduler,safety_checker,feature_extractor,requires_safety_checker) + super().__init__( + vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker + ) self.register_modules( vae=vae, text_encoder=text_encoder, @@ -93,50 +100,56 @@ def __call__( latents: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, - rp_args:Dict[str,str] = None, + rp_args: Dict[str, str] = None, ): - - active = KBRK in prompt[0] if type(prompt) == list else KBRK in prompt - if negative_prompt is None: negative_prompt = "" if type(prompt) == str else [""] * len(prompt) + active = KBRK in prompt[0] if type(prompt) == list else KBRK in prompt # noqa: E721 + if negative_prompt is None: + negative_prompt = "" if type(prompt) == str else [""] * len(prompt) # noqa: E721 device = self._execution_device regions = 0 - + self.power = int(rp_args["power"]) if "power" in rp_args else 1 - prompts = prompt if type(prompt) == list else [prompt] - n_prompts = negative_prompt if type(negative_prompt) == list else [negative_prompt] + prompts = prompt if type(prompt) == list else [prompt] # noqa: E721 + n_prompts = negative_prompt if type(negative_prompt) == list else [negative_prompt] # noqa: E721 self.batch = batch = num_images_per_prompt * len(prompts) - all_prompts_cn, all_prompts_p = promptsmaker(prompts,num_images_per_prompt) - all_n_prompts_cn, _ = promptsmaker(n_prompts,num_images_per_prompt) + all_prompts_cn, all_prompts_p = promptsmaker(prompts, num_images_per_prompt) + all_n_prompts_cn, _ = promptsmaker(n_prompts, num_images_per_prompt) cn = len(all_prompts_cn) == len(all_n_prompts_cn) if Compel: compel = Compel(tokenizer=self.tokenizer, text_encoder=self.text_encoder) + def getcompelembs(prps): embl = [] for prp in prps: embl.append(compel.build_conditioning_tensor(prp)) return torch.cat(embl) + conds = getcompelembs(all_prompts_cn) - unconds = getcompelembs(all_n_prompts_cn) if cn else getcompelembs(n_prompts) + unconds = getcompelembs(all_n_prompts_cn) if cn else getcompelembs(n_prompts) embs = getcompelembs(prompts) n_embs = getcompelembs(n_prompts) prompt = negative_prompt = None else: conds = self.encode_prompt(prompts, device, 1, True)[0] - unconds = self.encode_prompt(n_prompts, device, 1, True)[0] if cn else self.encode_prompt(all_n_prompts_cn, device, 1, True)[0] + unconds = ( + self.encode_prompt(n_prompts, device, 1, True)[0] + if cn + else self.encode_prompt(all_n_prompts_cn, device, 1, True)[0] + ) embs = n_embs = None if not active: pcallback = None mode = None else: - if any(x in rp_args["mode"].upper() for x in ["COL","ROW"]): - mode = "COL" if "COL" in rp_args["mode"].upper() else "ROW" - ocells,icells,regions = make_cells(rp_args["div"]) - + if any(x in rp_args["mode"].upper() for x in ["COL", "ROW"]): + mode = "COL" if "COL" in rp_args["mode"].upper() else "ROW" + ocells, icells, regions = make_cells(rp_args["div"]) + elif "PRO" in rp_args["mode"].upper(): regions = len(all_prompts_p[0]) mode = "PROMPT" @@ -144,14 +157,14 @@ def getcompelembs(prps): self.ex = "EX" in rp_args["mode"].upper() self.target_tokens = target_tokens = tokendealer(self, all_prompts_p) thresholds = [float(x) for x in rp_args["th"].split(",")] - - orig_hw = (height,width) + + orig_hw = (height, width) revers = True - def 
pcallback(s_self, step: int, timestep: int, latents: torch.FloatTensor,selfs=None): + def pcallback(s_self, step: int, timestep: int, latents: torch.FloatTensor, selfs=None): if "PRO" in mode: # in Prompt mode, make masks from sum of attension maps self.step = step - + if len(self.attnmaps_sizes) > 3: self.history[step] = self.attnmaps.copy() for hw in self.attnmaps_sizes: @@ -167,7 +180,7 @@ def pcallback(s_self, step: int, timestep: int, latents: torch.FloatTensor,selfs allmasks[b::batch] = [torch.where(x > 0, 1, 0) for x in allmasks[b::batch]] allmasks.append(mask) basemasks[b] = mask if basemasks[b] is None else basemasks[b] + mask - basemasks = [1 -mask for mask in basemasks] + basemasks = [1 - mask for mask in basemasks] basemasks = [torch.where(x > 0, 1, 0) for x in basemasks] allmasks = basemasks + allmasks @@ -176,7 +189,7 @@ def pcallback(s_self, step: int, timestep: int, latents: torch.FloatTensor,selfs return latents def hook_forward(module): - #diffusers==0.23.2 + # diffusers==0.23.2 def forward( hidden_states: torch.FloatTensor, encoder_hidden_states: Optional[torch.FloatTensor] = None, @@ -184,22 +197,21 @@ def forward( temb: Optional[torch.FloatTensor] = None, scale: float = 1.0, ) -> torch.Tensor: - - attn = module + attn = module xshape = hidden_states.shape - self.hw = (h,w) = split_dims(xshape[1], *orig_hw) + self.hw = (h, w) = split_dims(xshape[1], *orig_hw) if revers: - nx,px = hidden_states.chunk(2) + nx, px = hidden_states.chunk(2) else: - px,nx = hidden_states.chunk(2) + px, nx = hidden_states.chunk(2) if cn: - hidden_states = torch.cat([px for i in range(regions)] + [nx for i in range(regions)],0) - encoder_hidden_states = torch.cat([conds]+[unconds]) + hidden_states = torch.cat([px for i in range(regions)] + [nx for i in range(regions)], 0) + encoder_hidden_states = torch.cat([conds] + [unconds]) else: - hidden_states = torch.cat([px for i in range(regions)] + [nx],0) - encoder_hidden_states = torch.cat([conds]+[unconds]) + hidden_states = torch.cat([px for i in range(regions)] + [nx], 0) + encoder_hidden_states = torch.cat([conds] + [unconds]) residual = hidden_states @@ -247,12 +259,19 @@ def forward( # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move to Torch 2.1 hidden_states = scaled_dot_product_attention( - self, query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, getattn = "PRO" in mode + self, + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + getattn="PRO" in mode, ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) - + # linear proj hidden_states = attn.to_out[0](hidden_states, *args) # dropout @@ -272,18 +291,38 @@ def forward( center = reshaped.shape[0] // 2 px = reshaped[0:center] if cn else reshaped[0:-batch] nx = reshaped[center:] if cn else reshaped[-batch:] - outs = [px,nx] if cn else [px] + outs = [px, nx] if cn else [px] for out in outs: c = 0 - for i,ocell in enumerate(ocells): + for i, ocell in enumerate(ocells): for icell in icells[i]: if "ROW" in mode: - out[0:batch,int(h*ocell[0]):int(h*ocell[1]),int(w*icell[0]):int(w*icell[1]),:] = out[c*batch:(c+1)*batch,int(h*ocell[0]):int(h*ocell[1]),int(w*icell[0]):int(w*icell[1]),:] + out[ + 0:batch, + int(h * ocell[0]) : int(h * ocell[1]), + int(w * icell[0]) : int(w * icell[1]), + :, + ] = out[ + c * batch : (c + 1) * batch, + int(h * ocell[0]) : int(h * ocell[1]), + int(w * icell[0]) : 
int(w * icell[1]), + :, + ] else: - out[0:batch,int(h*icell[0]):int(h*icell[1]),int(w*ocell[0]):int(w*ocell[1]),:] = out[c*batch:(c+1)*batch,int(h*icell[0]):int(h*icell[1]),int(w*ocell[0]):int(w*ocell[1]),:] + out[ + 0:batch, + int(h * icell[0]) : int(h * icell[1]), + int(w * ocell[0]) : int(w * ocell[1]), + :, + ] = out[ + c * batch : (c + 1) * batch, + int(h * icell[0]) : int(h * icell[1]), + int(w * ocell[0]) : int(w * ocell[1]), + :, + ] c += 1 px, nx = (px[0:batch], nx[0:batch]) if cn else (px[0:batch], nx) - hidden_states = torch.cat([nx,px],0) if revers else torch.cat([px,nx],0) + hidden_states = torch.cat([nx, px], 0) if revers else torch.cat([px, nx], 0) hidden_states = hidden_states.reshape(xshape) #### Regional Prompting Prompt mode @@ -291,17 +330,19 @@ def forward( center = reshaped.shape[0] // 2 px = reshaped[0:center] if cn else reshaped[0:-batch] nx = reshaped[center:] if cn else reshaped[-batch:] - - if (h,w) in self.attnmasks and self.maskready: + + if (h, w) in self.attnmasks and self.maskready: + def mask(input): - out = torch.multiply(input,self.attnmasks[(h,w)]) + out = torch.multiply(input, self.attnmasks[(h, w)]) for b in range(batch): for r in range(1, regions): out[b] = out[b] + out[r * batch + b] return out + px, nx = (mask(px), mask(nx)) if cn else (mask(px), nx) px, nx = (px[0:batch], nx[0:batch]) if cn else (px[0:batch], nx) - hidden_states = torch.cat([nx,px],0) if revers else torch.cat([px,nx],0) + hidden_states = torch.cat([nx, px], 0) if revers else torch.cat([px, nx], 0) return hidden_states return forward @@ -328,7 +369,7 @@ def hook_forwards(root_module: torch.nn.Module): latents=latents, output_type=output_type, return_dict=return_dict, - callback_on_step_end = pcallback + callback_on_step_end=pcallback, ) if "save_mask" in rp_args: @@ -336,13 +377,14 @@ def hook_forwards(root_module: torch.nn.Module): else: save_mask = False - if mode == "PROMPT" and save_mask: saveattnmaps(self, output, height, width, thresholds, num_inference_steps // 2, regions) + if mode == "PROMPT" and save_mask: + saveattnmaps(self, output, height, width, thresholds, num_inference_steps // 2, regions) return output ### Make prompt list for each regions -def promptsmaker(prompts,batch): +def promptsmaker(prompts, batch): out_p = [] plen = len(prompts) for prompt in prompts: @@ -352,24 +394,26 @@ def promptsmaker(prompts,batch): add = add + " " prompts = prompt.split(KBRK) out_p.append([add + p for p in prompts]) - out = [None]*batch*len(out_p[0]) * len(out_p) - for p, prs in enumerate(out_p): # inputs prompts - for r, pr in enumerate(prs): # prompts for regions + out = [None] * batch * len(out_p[0]) * len(out_p) + for p, prs in enumerate(out_p): # inputs prompts + for r, pr in enumerate(prs): # prompts for regions start = (p + r * plen) * batch - out[start : start + batch]= [pr] * batch #P1R1B1,P1R1B2...,P1R2B1,P1R2B2...,P2R1B1... + out[start : start + batch] = [pr] * batch # P1R1B1,P1R1B2...,P1R2B1,P1R2B2...,P2R1B1... 
     return out, out_p
 
+
 ### make regions from ratios
 ### ";" makes outercells, "," makes inner cells
 def make_cells(ratios):
-    if ";" not in ratios and "," in ratios:ratios = ratios.replace(",",";")
+    if ";" not in ratios and "," in ratios:
+        ratios = ratios.replace(",", ";")
     ratios = ratios.split(";")
     ratios = [inratios.split(",") for inratios in ratios]
     icells = []
     ocells = []
 
-    def startend(cells,array):
+    def startend(cells, array):
         current_start = 0
         array = [float(x) for x in array]
         for value in array:
@@ -377,72 +421,80 @@ def startend(cells,array):
             cells.append([current_start, end])
             current_start = end
 
-    startend(ocells,[r[0] for r in ratios])
+    startend(ocells, [r[0] for r in ratios])
     for inratios in ratios:
         if 2 > len(inratios):
-            icells.append([[0,1]])
+            icells.append([[0, 1]])
         else:
             add = []
-            startend(add,inratios[1:])
+            startend(add, inratios[1:])
             icells.append(add)
     return ocells, icells, sum(len(cell) for cell in icells)
 
+
 def make_emblist(self, prompts):
     with torch.no_grad():
-        tokens = self.tokenizer(prompts, max_length=self.tokenizer.model_max_length, padding=True, truncation=True, return_tensors='pt').input_ids.to(self.device)
-        embs = self.text_encoder(tokens, output_hidden_states=True).last_hidden_state.to(self.device, dtype = self.dtype)
+        tokens = self.tokenizer(
+            prompts, max_length=self.tokenizer.model_max_length, padding=True, truncation=True, return_tensors="pt"
+        ).input_ids.to(self.device)
+        embs = self.text_encoder(tokens, output_hidden_states=True).last_hidden_state.to(self.device, dtype=self.dtype)
     return embs
 
 
-import math
+
 def split_dims(xs, height, width):
     xs = xs
-    def repeat_div(x,y):
+
+    def repeat_div(x, y):
         while y > 0:
             x = math.ceil(x / 2)
             y = y - 1
         return x
+
     scale = math.ceil(math.log2(math.sqrt(height * width / xs)))
-    dsh = repeat_div(height,scale)
-    dsw = repeat_div(width,scale)
-    return dsh,dsw
+    dsh = repeat_div(height, scale)
+    dsw = repeat_div(width, scale)
+    return dsh, dsw
+
 
 ##### for prompt mode
-def get_attn_maps(self,attn):
-    height,width = self.hw
+def get_attn_maps(self, attn):
+    height, width = self.hw
     target_tokens = self.target_tokens
-    if (height,width) not in self.attnmaps_sizes:
-        self.attnmaps_sizes.append((height,width))
-
+    if (height, width) not in self.attnmaps_sizes:
+        self.attnmaps_sizes.append((height, width))
+
     for b in range(self.batch):
         for t in target_tokens:
             power = self.power
-            add = attn[b,:,:,t[0]:t[0]+len(t)]**(power)*(self.attnmaps_sizes.index((height,width)) + 1)
-            add = torch.sum(add,dim = 2)
-            key = f"{t}-{b}"
+            add = attn[b, :, :, t[0] : t[0] + len(t)] ** (power) * (self.attnmaps_sizes.index((height, width)) + 1)
+            add = torch.sum(add, dim=2)
+            key = f"{t}-{b}"
             if key not in self.attnmaps:
                 self.attnmaps[key] = add
             else:
                 if self.attnmaps[key].shape[1] != add.shape[1]:
-                    add = add.view(8,height,width)
-                    add = FF.resize(add,self.attnmaps_sizes[0],antialias=None)
+                    add = add.view(8, height, width)
+                    add = FF.resize(add, self.attnmaps_sizes[0], antialias=None)
                     add = add.reshape_as(self.attnmaps[key])
                 self.attnmaps[key] = self.attnmaps[key] + add
 
-def reset_attnmaps(self): # init parameters in every batch
+
+def reset_attnmaps(self):  # init parameters in every batch
     self.step = 0
-    self.attnmaps = {} #maked from attention maps
-    self.attnmaps_sizes =[] #height,width set of u-net blocks
-    self.attnmasks = {} #maked from attnmaps for regions
+    self.attnmaps = {}  # maked from attention maps
+    self.attnmaps_sizes = []  # height,width set of u-net blocks
+    self.attnmasks = {}  # maked from attnmaps for regions
     self.maskready = False
     self.history = {}
 
-def saveattnmaps(self,output,h,w,th,step,regions):
+
+def saveattnmaps(self, output, h, w, th, step, regions):
     masks = []
     for i, mask in enumerate(self.history[step].values()):
-        img, _ , mask = makepmask(self, mask, h, w, th[i % len(th)], step)
+        img, _, mask = makepmask(self, mask, h, w, th[i % len(th)], step)
         if self.ex:
             masks = [x - mask for x in masks]
         masks.append(mask)
@@ -452,46 +504,71 @@ def saveattnmaps(self,output,h,w,th,step,regions):
     else:
         output.images.append(img)
 
-def makepmask(self, mask, h, w, th, step): # make masks from attention cache return [for preview, for attention, for Latent]
+
+def makepmask(
+    self, mask, h, w, th, step
+):  # make masks from attention cache return [for preview, for attention, for Latent]
     th = th - step * 0.005
-    if 0.05 >= th: th = 0.05
+    if 0.05 >= th:
+        th = 0.05
-    mask = torch.mean(mask,dim=0)
+    mask = torch.mean(mask, dim=0)
     mask = mask / mask.max().item()
-    mask = torch.where(mask > th ,1,0)
+    mask = torch.where(mask > th, 1, 0)
     mask = mask.float()
-    mask = mask.view(1,*self.attnmaps_sizes[0])
+    mask = mask.view(1, *self.attnmaps_sizes[0])
     img = FF.to_pil_image(mask)
-    img = img.resize((w,h))
-    mask = FF.resize(mask,(h,w),interpolation=FF.InterpolationMode.NEAREST,antialias=None)
+    img = img.resize((w, h))
+    mask = FF.resize(mask, (h, w), interpolation=FF.InterpolationMode.NEAREST, antialias=None)
     lmask = mask
-    mask = mask.reshape(h*w)
-    mask = torch.where(mask > 0.1 ,1,0)
+    mask = mask.reshape(h * w)
+    mask = torch.where(mask > 0.1, 1, 0)
     return img, mask, lmask
 
+
 def tokendealer(self, all_prompts):
     for prompts in all_prompts:
-        targets =[p.split(",")[-1] for p in prompts[1:]]
+        targets = [p.split(",")[-1] for p in prompts[1:]]
         tt = []
 
         for target in targets:
-            ptokens = (self.tokenizer(prompts, max_length=self.tokenizer.model_max_length, padding=True, truncation=True, return_tensors='pt').input_ids)[0]
-            ttokens = (self.tokenizer(target, max_length=self.tokenizer.model_max_length, padding=True, truncation=True, return_tensors='pt').input_ids)[0]
+            ptokens = (
+                self.tokenizer(
+                    prompts,
+                    max_length=self.tokenizer.model_max_length,
+                    padding=True,
+                    truncation=True,
+                    return_tensors="pt",
+                ).input_ids
+            )[0]
+            ttokens = (
+                self.tokenizer(
+                    target,
+                    max_length=self.tokenizer.model_max_length,
+                    padding=True,
+                    truncation=True,
+                    return_tensors="pt",
+                ).input_ids
+            )[0]
             tlist = []
 
-            for t in range(ttokens.shape[0] -2):
+            for t in range(ttokens.shape[0] - 2):
                 for p in range(ptokens.shape[0]):
                     if ttokens[t + 1] == ptokens[p]:
                         tlist.append(p)
-            if tlist != [] : tt.append(tlist)
+            if tlist != []:
+                tt.append(tlist)
 
     return tt
 
-def scaled_dot_product_attention(self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, getattn = False) -> torch.Tensor:
+
+def scaled_dot_product_attention(
+    self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, getattn=False
+) -> torch.Tensor:
     # Efficient implementation equivalent to the following:
     L, S = query.size(-2), key.size(-2)
     scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
-    attn_bias = torch.zeros(L, S, dtype=query.dtype,device=self.device)
+    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=self.device)
     if is_causal:
         assert attn_mask is None
         temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
@@ -506,6 +583,7 @@ def scaled_dot_product_attention(self, query, key, value, attn_mask=None, dropou
     attn_weight = query @ key.transpose(-2, -1) * scale_factor
     attn_weight += attn_bias
     attn_weight = torch.softmax(attn_weight, dim=-1)
-    if getattn: get_attn_maps(self,attn_weight)
+    if getattn:
+        get_attn_maps(self, attn_weight)
     attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
-    return attn_weight @ value
\ No newline at end of file
+    return attn_weight @ value
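
For reviewers, a minimal usage sketch of the column/row mode described in the class docstring above (rp_args["mode"], rp_args["div"], and the BREAK keyword). This is not part of the patch: the checkpoint id, loading path via custom_pipeline, and the generation settings are illustrative assumptions, not confirmed by the diff.

import torch

from diffusers import DiffusionPipeline

# Hypothetical checkpoint; any Stable Diffusion 1.x weights should work with this community pipeline.
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    custom_pipeline="regional_prompting_stable_diffusion",
    torch_dtype=torch.float16,
).to("cuda")

# "rows" with div="1;1;1" splits the canvas into three equal horizontal bands, and the
# BREAK keyword (KBRK) assigns one sub-prompt per band via promptsmaker().
prompt = "green hair twintail BREAK red blouse BREAK blue skirt"

image = pipe(
    prompt=prompt,
    negative_prompt="low quality",
    height=768,
    width=512,
    num_inference_steps=20,
    guidance_scale=7.5,
    rp_args={"mode": "rows", "div": "1;1;1"},
).images[0]
image.save("regional_rows.png")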
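A prompt-mode sketch, inferred from tokendealer() (which picks the last comma-separated word of each region chunk after the first) and makepmask() (which thresholds that word's attention map with rp_args["th"]); the prompt wording, threshold values, and reuse of the pipe object from the previous sketch are assumptions.

# First BREAK chunk: the base prompt. Each later chunk ends with a comma-separated
# target word ("shirt", "tie", "skirt") whose attention map defines that region;
# rp_args["th"] supplies one threshold per region, and save_mask appears to append
# mask previews to the returned images (see saveattnmaps()).
prompt = """
a lady with red shirt, green tie and blue skirt BREAK
red, shirt BREAK
green, tie BREAK
blue, skirt
"""

image = pipe(
    prompt=prompt,
    num_inference_steps=20,
    rp_args={"mode": "prompt", "th": "0.4,0.6,0.6", "save_mask": True},
).images[0]
image.save("regional_prompt_mode.png")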