From 69c83d6eed53ef22cde930247c1693ac26d602a4 Mon Sep 17 00:00:00 2001 From: cjkangme Date: Thu, 28 Nov 2024 20:24:23 +0900 Subject: [PATCH] [Community Pipeline] Add some feature for regional prompting pipeline (#9874) * [Fix] fix bugs of regional_prompting pipeline * [Feat] add base prompt feature * [Fix] fix __init__ pipeline error * [Fix] delete unused args * [Fix] improve string handling * [Docs] docs to use_base in regional_prompting * make style --------- Co-authored-by: Sayak Paul --- examples/community/README.md | 15 ++++ .../regional_prompting_stable_diffusion.py | 79 +++++++++++++++---- 2 files changed, 78 insertions(+), 16 deletions(-) diff --git a/examples/community/README.md b/examples/community/README.md index 3eb5fc424b1d..ac8a13d40a97 100755 --- a/examples/community/README.md +++ b/examples/community/README.md @@ -3379,6 +3379,20 @@ best quality, 3persons in garden, a boy blue shirt BREAK best quality, 3persons in garden, an old man red suit ``` +### Use base prompt + +You can use a base prompt to apply the prompt to all areas. You can set a base prompt by adding `ADDBASE` at the end. Base prompts can also be combined with common prompts, but the base prompt must be specified first. + +``` +2d animation style ADDBASE +masterpiece, high quality ADDCOMM +(blue sky)++ BREAK +green hair twintail BREAK +book shelf BREAK +messy desk BREAK +orange++ dress and sofa +``` + ### Negative prompt Negative prompts are equally effective across all regions, but it is possible to set region-specific prompts for negative prompts as well. The number of BREAKs must be the same as the number of prompts. If the number of prompts does not match, the negative prompts will be used without being divided into regions. @@ -3409,6 +3423,7 @@ pipe(prompt=prompt, rp_args=rp_args) ### Optional Parameters - `save_mask`: In `Prompt` mode, choose whether to output the generated mask along with the image. The default is `False`. +- `base_ratio`: Used with `ADDBASE`. Sets the ratio of the base prompt; if base ratio is set to 0.2, then resulting images will consist of `20%*BASE_PROMPT + 80%*REGION_PROMPT` The Pipeline supports `compel` syntax. Input prompts using the `compel` structure will be automatically applied and processed. diff --git a/examples/community/regional_prompting_stable_diffusion.py b/examples/community/regional_prompting_stable_diffusion.py index 8a022987ba9d..95f6cebb0190 100644 --- a/examples/community/regional_prompting_stable_diffusion.py +++ b/examples/community/regional_prompting_stable_diffusion.py @@ -3,13 +3,12 @@ import torch import torchvision.transforms.functional as FF -from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection from diffusers import StableDiffusionPipeline from diffusers.models import AutoencoderKL, UNet2DConditionModel from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from diffusers.schedulers import KarrasDiffusionSchedulers -from diffusers.utils import USE_PEFT_BACKEND try: @@ -17,6 +16,7 @@ except ImportError: Compel = None +KBASE = "ADDBASE" KCOMM = "ADDCOMM" KBRK = "BREAK" @@ -34,6 +34,11 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline): Optional rp_args["save_mask"]: True/False (save masks in prompt mode) + rp_args["power"]: int (power for attention maps in prompt mode) + rp_args["base_ratio"]: + float (Sets the ratio of the base prompt) + ex) 0.2 (20%*BASE_PROMPT + 80%*REGION_PROMPT) + [Use base prompt](https://github.com/hako-mikan/sd-webui-regional-prompter?tab=readme-ov-file#use-base-prompt) Pipeline for text-to-image generation using Stable Diffusion. @@ -70,6 +75,7 @@ def __init__( scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection = None, requires_safety_checker: bool = True, ): super().__init__( @@ -80,6 +86,7 @@ def __init__( scheduler, safety_checker, feature_extractor, + image_encoder, requires_safety_checker, ) self.register_modules( @@ -90,6 +97,7 @@ def __init__( scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, + image_encoder=image_encoder, ) @torch.no_grad() @@ -110,17 +118,40 @@ def __call__( rp_args: Dict[str, str] = None, ): active = KBRK in prompt[0] if isinstance(prompt, list) else KBRK in prompt + use_base = KBASE in prompt[0] if isinstance(prompt, list) else KBASE in prompt if negative_prompt is None: negative_prompt = "" if isinstance(prompt, str) else [""] * len(prompt) device = self._execution_device regions = 0 + self.base_ratio = float(rp_args["base_ratio"]) if "base_ratio" in rp_args else 0.0 self.power = int(rp_args["power"]) if "power" in rp_args else 1 prompts = prompt if isinstance(prompt, list) else [prompt] - n_prompts = negative_prompt if isinstance(prompt, str) else [negative_prompt] + n_prompts = negative_prompt if isinstance(prompt, list) else [negative_prompt] self.batch = batch = num_images_per_prompt * len(prompts) + + if use_base: + bases = prompts.copy() + n_bases = n_prompts.copy() + + for i, prompt in enumerate(prompts): + parts = prompt.split(KBASE) + if len(parts) == 2: + bases[i], prompts[i] = parts + elif len(parts) > 2: + raise ValueError(f"Multiple instances of {KBASE} found in prompt: {prompt}") + for i, prompt in enumerate(n_prompts): + n_parts = prompt.split(KBASE) + if len(n_parts) == 2: + n_bases[i], n_prompts[i] = n_parts + elif len(n_parts) > 2: + raise ValueError(f"Multiple instances of {KBASE} found in negative prompt: {prompt}") + + all_bases_cn, _ = promptsmaker(bases, num_images_per_prompt) + all_n_bases_cn, _ = promptsmaker(n_bases, num_images_per_prompt) + all_prompts_cn, all_prompts_p = promptsmaker(prompts, num_images_per_prompt) all_n_prompts_cn, _ = promptsmaker(n_prompts, num_images_per_prompt) @@ -137,8 +168,16 @@ def getcompelembs(prps): conds = getcompelembs(all_prompts_cn) unconds = getcompelembs(all_n_prompts_cn) - embs = getcompelembs(prompts) - n_embs = getcompelembs(n_prompts) + base_embs = getcompelembs(all_bases_cn) if use_base else None + base_n_embs = getcompelembs(all_n_bases_cn) if use_base else None + # When using base, it seems more reasonable to use base prompts as prompt_embeddings rather than regional prompts + embs = getcompelembs(prompts) if not use_base else base_embs + n_embs = getcompelembs(n_prompts) if not use_base else base_n_embs + + if use_base and self.base_ratio > 0: + conds = self.base_ratio * base_embs + (1 - self.base_ratio) * conds + unconds = self.base_ratio * base_n_embs + (1 - self.base_ratio) * unconds + prompt = negative_prompt = None else: conds = self.encode_prompt(prompts, device, 1, True)[0] @@ -147,6 +186,18 @@ def getcompelembs(prps): if equal else self.encode_prompt(all_n_prompts_cn, device, 1, True)[0] ) + + if use_base and self.base_ratio > 0: + base_embs = self.encode_prompt(bases, device, 1, True)[0] + base_n_embs = ( + self.encode_prompt(n_bases, device, 1, True)[0] + if equal + else self.encode_prompt(all_n_bases_cn, device, 1, True)[0] + ) + + conds = self.base_ratio * base_embs + (1 - self.base_ratio) * conds + unconds = self.base_ratio * base_n_embs + (1 - self.base_ratio) * unconds + embs = n_embs = None if not active: @@ -225,8 +276,6 @@ def forward( residual = hidden_states - args = () if USE_PEFT_BACKEND else (scale,) - if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) @@ -247,16 +296,15 @@ def forward( if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - args = () if USE_PEFT_BACKEND else (scale,) - query = attn.to_q(hidden_states, *args) + query = attn.to_q(hidden_states) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - key = attn.to_k(encoder_hidden_states, *args) - value = attn.to_v(encoder_hidden_states, *args) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads @@ -283,7 +331,7 @@ def forward( hidden_states = hidden_states.to(query.dtype) # linear proj - hidden_states = attn.to_out[0](hidden_states, *args) + hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) @@ -410,9 +458,9 @@ def promptsmaker(prompts, batch): add = "" if KCOMM in prompt: add, prompt = prompt.split(KCOMM) - add = add + " " - prompts = prompt.split(KBRK) - out_p.append([add + p for p in prompts]) + add = add.strip() + " " + prompts = [p.strip() for p in prompt.split(KBRK)] + out_p.append([add + p for i, p in enumerate(prompts)]) out = [None] * batch * len(out_p[0]) * len(out_p) for p, prs in enumerate(out_p): # inputs prompts for r, pr in enumerate(prs): # prompts for regions @@ -449,7 +497,6 @@ def startend(cells, array): add = [] startend(add, inratios[1:]) icells.append(add) - return ocells, icells, sum(len(cell) for cell in icells)