From ed3fcd81e5b902f28b74b63db5624cda46ff0435 Mon Sep 17 00:00:00 2001 From: hako-mikan <122196982+hako-mikan@users.noreply.github.com> Date: Wed, 20 Dec 2023 14:07:19 +0900 Subject: [PATCH] [Fix] Fix Regional Prompting Pipeline (#6188) * Update regional_prompting_stable_diffusion.py * reformat * reformat * reformat * reformat * reformat * reformat * reformat * regormat * reformat * reformat * reformat * reformat * Update regional_prompting_stable_diffusion.py --------- Co-authored-by: Sayak Paul --- .../regional_prompting_stable_diffusion.py | 75 +++++++++++++------ 1 file changed, 53 insertions(+), 22 deletions(-) diff --git a/examples/community/regional_prompting_stable_diffusion.py b/examples/community/regional_prompting_stable_diffusion.py index 525e75bc68b93..71f24a81bd15c 100644 --- a/examples/community/regional_prompting_stable_diffusion.py +++ b/examples/community/regional_prompting_stable_diffusion.py @@ -73,7 +73,14 @@ def __init__( requires_safety_checker: bool = True, ): super().__init__( - vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker + vae, + text_encoder, + tokenizer, + unet, + scheduler, + safety_checker, + feature_extractor, + requires_safety_checker, ) self.register_modules( vae=vae, @@ -102,22 +109,22 @@ def __call__( return_dict: bool = True, rp_args: Dict[str, str] = None, ): - active = KBRK in prompt[0] if type(prompt) == list else KBRK in prompt # noqa: E721 + active = KBRK in prompt[0] if isinstance(prompt, list) else KBRK in prompt if negative_prompt is None: - negative_prompt = "" if type(prompt) == str else [""] * len(prompt) # noqa: E721 + negative_prompt = "" if isinstance(prompt, str) else [""] * len(prompt) device = self._execution_device regions = 0 self.power = int(rp_args["power"]) if "power" in rp_args else 1 - prompts = prompt if type(prompt) == list else [prompt] # noqa: E721 - n_prompts = negative_prompt if type(negative_prompt) == list else [negative_prompt] # noqa: E721 + prompts = prompt if isinstance(prompt, list) else [prompt] + n_prompts = negative_prompt if isinstance(prompt, str) else [negative_prompt] self.batch = batch = num_images_per_prompt * len(prompts) all_prompts_cn, all_prompts_p = promptsmaker(prompts, num_images_per_prompt) all_n_prompts_cn, _ = promptsmaker(n_prompts, num_images_per_prompt) - cn = len(all_prompts_cn) == len(all_n_prompts_cn) + equal = len(all_prompts_cn) == len(all_n_prompts_cn) if Compel: compel = Compel(tokenizer=self.tokenizer, text_encoder=self.text_encoder) @@ -129,7 +136,7 @@ def getcompelembs(prps): return torch.cat(embl) conds = getcompelembs(all_prompts_cn) - unconds = getcompelembs(all_n_prompts_cn) if cn else getcompelembs(n_prompts) + unconds = getcompelembs(all_n_prompts_cn) embs = getcompelembs(prompts) n_embs = getcompelembs(n_prompts) prompt = negative_prompt = None @@ -137,7 +144,7 @@ def getcompelembs(prps): conds = self.encode_prompt(prompts, device, 1, True)[0] unconds = ( self.encode_prompt(n_prompts, device, 1, True)[0] - if cn + if equal else self.encode_prompt(all_n_prompts_cn, device, 1, True)[0] ) embs = n_embs = None @@ -206,8 +213,11 @@ def forward( else: px, nx = hidden_states.chunk(2) - if cn: - hidden_states = torch.cat([px for i in range(regions)] + [nx for i in range(regions)], 0) + if equal: + hidden_states = torch.cat( + [px for i in range(regions)] + [nx for i in range(regions)], + 0, + ) encoder_hidden_states = torch.cat([conds] + [unconds]) else: hidden_states = torch.cat([px for i in range(regions)] + [nx], 0) @@ -289,9 +299,9 @@ def forward( if any(x in mode for x in ["COL", "ROW"]): reshaped = hidden_states.reshape(hidden_states.size()[0], h, w, hidden_states.size()[2]) center = reshaped.shape[0] // 2 - px = reshaped[0:center] if cn else reshaped[0:-batch] - nx = reshaped[center:] if cn else reshaped[-batch:] - outs = [px, nx] if cn else [px] + px = reshaped[0:center] if equal else reshaped[0:-batch] + nx = reshaped[center:] if equal else reshaped[-batch:] + outs = [px, nx] if equal else [px] for out in outs: c = 0 for i, ocell in enumerate(ocells): @@ -321,15 +331,16 @@ def forward( :, ] c += 1 - px, nx = (px[0:batch], nx[0:batch]) if cn else (px[0:batch], nx) + px, nx = (px[0:batch], nx[0:batch]) if equal else (px[0:batch], nx) hidden_states = torch.cat([nx, px], 0) if revers else torch.cat([px, nx], 0) hidden_states = hidden_states.reshape(xshape) #### Regional Prompting Prompt mode elif "PRO" in mode: - center = reshaped.shape[0] // 2 - px = reshaped[0:center] if cn else reshaped[0:-batch] - nx = reshaped[center:] if cn else reshaped[-batch:] + px, nx = ( + torch.chunk(hidden_states) if equal else hidden_states[0:-batch], + hidden_states[-batch:], + ) if (h, w) in self.attnmasks and self.maskready: @@ -340,8 +351,8 @@ def mask(input): out[b] = out[b] + out[r * batch + b] return out - px, nx = (mask(px), mask(nx)) if cn else (mask(px), nx) - px, nx = (px[0:batch], nx[0:batch]) if cn else (px[0:batch], nx) + px, nx = (mask(px), mask(nx)) if equal else (mask(px), nx) + px, nx = (px[0:batch], nx[0:batch]) if equal else (px[0:batch], nx) hidden_states = torch.cat([nx, px], 0) if revers else torch.cat([px, nx], 0) return hidden_states @@ -378,7 +389,15 @@ def hook_forwards(root_module: torch.nn.Module): save_mask = False if mode == "PROMPT" and save_mask: - saveattnmaps(self, output, height, width, thresholds, num_inference_steps // 2, regions) + saveattnmaps( + self, + output, + height, + width, + thresholds, + num_inference_steps // 2, + regions, + ) return output @@ -437,7 +456,11 @@ def startend(cells, array): def make_emblist(self, prompts): with torch.no_grad(): tokens = self.tokenizer( - prompts, max_length=self.tokenizer.model_max_length, padding=True, truncation=True, return_tensors="pt" + prompts, + max_length=self.tokenizer.model_max_length, + padding=True, + truncation=True, + return_tensors="pt", ).input_ids.to(self.device) embs = self.text_encoder(tokens, output_hidden_states=True).last_hidden_state.to(self.device, dtype=self.dtype) return embs @@ -563,7 +586,15 @@ def tokendealer(self, all_prompts): def scaled_dot_product_attention( - self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, getattn=False + self, + query, + key, + value, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + scale=None, + getattn=False, ) -> torch.Tensor: # Efficient implementation equivalent to the following: L, S = query.size(-2), key.size(-2)