Skip to content

Commit

Permalink
nest 2 AND_* keywords or more inside [...] to group (#57)
Browse files Browse the repository at this point in the history
* nest 2 to group

* fix some corner cases with accidental double brackets

* refactor; fix cfg rescale mean

* tests

* more test cases

* more test cases

---------

Co-authored-by: ljleb <set>
  • Loading branch information
ljleb authored Jan 26, 2024
1 parent ccc2aed commit e395a11
Show file tree
Hide file tree
Showing 4 changed files with 149 additions and 142 deletions.
112 changes: 45 additions & 67 deletions lib_neutral_prompt/cfg_denoiser_hijack.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ def combine_denoised_hijack(

for batch_i, (prompt, cond_indices) in enumerate(zip(global_state.prompt_exprs, batch_cond_indices)):
args = CombineDenoiseArgs(x_out, uncond[batch_i], cond_indices)
cond_delta = prompt.accept(CondDeltaChildVisitor(), args, 0)
aux_cond_delta = prompt.accept(AuxCondDeltaChildVisitor(), args, cond_delta, 0)
cond_delta = prompt.accept(CondDeltaVisitor(), args, 0)
aux_cond_delta = prompt.accept(AuxCondDeltaVisitor(), args, cond_delta, 0)
cfg_cond = denoised[batch_i] + aux_cond_delta * cond_scale
denoised[batch_i] = cfg_cond * get_cfg_rescale_factor(cfg_cond, uncond[batch_i] + cond_delta + aux_cond_delta)
denoised[batch_i] = cfg_rescale(cfg_cond, uncond[batch_i] + cond_delta + aux_cond_delta)

return denoised

Expand All @@ -41,22 +41,27 @@ def get_webui_denoised(
uncond = x_out[-text_uncond.shape[0]:]
sliced_batch_x_out = []
sliced_batch_cond_indices = []
index_in = 0

for batch_i, (prompt, cond_indices) in enumerate(zip(global_state.prompt_exprs, batch_cond_indices)):
args = CombineDenoiseArgs(x_out, uncond[batch_i], cond_indices)
sliced_x_out, sliced_cond_indices = prompt.accept(GatherWebuiCondsVisitor(), args, len(sliced_batch_x_out))
sliced_batch_cond_indices.append(sliced_cond_indices)
sliced_x_out, sliced_cond_indices = prompt.accept(GatherWebuiCondsVisitor(), args, index_in, len(sliced_batch_x_out))
if sliced_cond_indices:
sliced_batch_cond_indices.append(sliced_cond_indices)
sliced_batch_x_out.extend(sliced_x_out)
index_in += prompt.accept(neutral_prompt_parser.FlatSizeVisitor())

sliced_batch_x_out += list(uncond)
sliced_batch_x_out = torch.stack(sliced_batch_x_out, dim=0)
sliced_batch_cond_indices = [il for il in sliced_batch_cond_indices if il]
return original_function(sliced_batch_x_out, sliced_batch_cond_indices, text_uncond, cond_scale)


def get_cfg_rescale_factor(cfg_cond, cond):
def cfg_rescale(cfg_cond, cond):
global_state.apply_and_clear_cfg_rescale_override()
return global_state.cfg_rescale * (torch.std(cond) / torch.std(cfg_cond) - 1) + 1
cfg_cond_mean = cfg_cond.mean()
cfg_resacle_mean = (1 - global_state.cfg_rescale) * cfg_cond_mean + global_state.cfg_rescale * cond.mean()
cfg_rescale_factor = global_state.cfg_rescale * (cond.std() / cfg_cond.std() - 1) + 1
return cfg_resacle_mean + (cfg_cond - cfg_cond_mean) * cfg_rescale_factor


@dataclasses.dataclass
Expand All @@ -68,67 +73,36 @@ class CombineDenoiseArgs:

@dataclasses.dataclass
class GatherWebuiCondsVisitor:
def visit_leaf_prompt(self, *args, **kwargs) -> Tuple[List[torch.Tensor], List[Tuple[int, float]]]:
return [], []
def visit_leaf_prompt(
self,
that: neutral_prompt_parser.CompositePrompt,
args: CombineDenoiseArgs,
index_in: int,
index_out: int,
) -> Tuple[List[torch.Tensor], List[Tuple[int, float]]]:
return [args.x_out[args.cond_indices[index_in][0]]], [(index_out, args.cond_indices[index_in][1])]

def visit_composite_prompt(
self,
that: neutral_prompt_parser.CompositePrompt,
args: CombineDenoiseArgs,
index_offset: int,
index_in: int,
index_out: int,
) -> Tuple[List[torch.Tensor], List[Tuple[int, float]]]:
sliced_x_out = []
sliced_cond_indices = []

index_in = 0
for child in that.children:
index_out = index_offset + len(sliced_x_out)
child_x_out, child_cond_indices = child.accept(GatherWebuiCondsVisitor.SingleCondVisitor(), args.x_out, args.cond_indices[index_in], index_out)
sliced_x_out.extend(child_x_out)
sliced_cond_indices.extend(child_cond_indices)
if child.conciliation is None:
index_offset = index_out + len(sliced_x_out)
child_x_out, child_cond_indices = child.accept(GatherWebuiCondsVisitor(), args, index_in, index_offset)
sliced_x_out.extend(child_x_out)
sliced_cond_indices.extend(child_cond_indices)

index_in += child.accept(neutral_prompt_parser.FlatSizeVisitor())

return sliced_x_out, sliced_cond_indices

@dataclasses.dataclass
class SingleCondVisitor:
def visit_leaf_prompt(
self,
that: neutral_prompt_parser.LeafPrompt,
x_out: torch.Tensor,
cond_info: Tuple[int, float],
index: int,
) -> Tuple[List[torch.Tensor], List[Tuple[int, float]]]:
return [x_out[cond_info[0]]], [(index, cond_info[1])]

def visit_composite_prompt(self, *args, **kwargs) -> Tuple[List[torch.Tensor], List[Tuple[int, float]]]:
return [], []


@dataclasses.dataclass
class CondDeltaChildVisitor:
def visit_leaf_prompt(
self,
that: neutral_prompt_parser.LeafPrompt,
args: CombineDenoiseArgs,
index: int,
) -> torch.Tensor:
return torch.zeros_like(args.x_out[0])

def visit_composite_prompt(
self,
that: neutral_prompt_parser.CompositePrompt,
args: CombineDenoiseArgs,
index: int,
) -> torch.Tensor:
cond_delta = torch.zeros_like(args.x_out[0])

for child in that.children:
cond_delta += child.weight * child.accept(CondDeltaVisitor(), args, index)
index += child.accept(neutral_prompt_parser.FlatSizeVisitor())

return cond_delta


@dataclasses.dataclass
class CondDeltaVisitor:
Expand All @@ -143,8 +117,9 @@ def visit_leaf_prompt(
console_warn(f'''
An unexpected noise weight was encountered at prompt #{index}
Expected :{that.weight}, but got :{cond_info[1]}
This is likely due to another extension also monkey patching the webui noise blending function
Please open a github issue so that the conflict can be resolved
This is likely due to another extension also monkey patching the webui `combine_denoised` function
Please open a bug report here so that the conflict can be resolved:
https://github.com/ljleb/sd-webui-neutral-prompt/issues
''')

return args.x_out[cond_info[0]] - args.uncond
Expand All @@ -157,17 +132,19 @@ def visit_composite_prompt(
) -> torch.Tensor:
cond_delta = torch.zeros_like(args.x_out[0])

if that.conciliation is None:
for child in that.children:
child_cond_delta = child.accept(CondDeltaChildVisitor(), args, index)
child_cond_delta += child.accept(AuxCondDeltaChildVisitor(), args, child_cond_delta, index)
for child in that.children:
if child.conciliation is None:
child_cond_delta = child.accept(CondDeltaVisitor(), args, index)
child_cond_delta += child.accept(AuxCondDeltaVisitor(), args, child_cond_delta, index)
cond_delta += child.weight * child_cond_delta

index += child.accept(neutral_prompt_parser.FlatSizeVisitor())

return cond_delta


@dataclasses.dataclass
class AuxCondDeltaChildVisitor:
class AuxCondDeltaVisitor:
def visit_leaf_prompt(
self,
that: neutral_prompt_parser.LeafPrompt,
Expand All @@ -188,9 +165,10 @@ def visit_composite_prompt(
salient_cond_deltas = []

for child in that.children:
child_cond_delta = child.accept(CondDeltaChildVisitor(), args, index)
child_cond_delta += child.accept(self, args, child_cond_delta, index)
if isinstance(child, neutral_prompt_parser.CompositePrompt):
if child.conciliation is not None:
child_cond_delta = child.accept(CondDeltaVisitor(), args, index)
child_cond_delta += child.accept(AuxCondDeltaVisitor(), args, child_cond_delta, index)

if child.conciliation == neutral_prompt_parser.ConciliationStrategy.PERPENDICULAR:
aux_cond_delta += child.weight * get_perpendicular_component(cond_delta, child_cond_delta)
elif child.conciliation == neutral_prompt_parser.ConciliationStrategy.SALIENCE_MASK:
Expand Down Expand Up @@ -221,12 +199,12 @@ def salient_blend(normal: torch.Tensor, vectors: List[Tuple[torch.Tensor, float]
The blended result combines `normal` and vector information in salient regions.
"""

salience_maps = [get_salience(normal)] + [get_salience(vector) for vector, weight in vectors]
salience_maps = [get_salience(normal)] + [get_salience(vector) for vector, _ in vectors]
mask = torch.argmax(torch.stack(salience_maps, dim=0), dim=0)

result = torch.zeros_like(normal)
for mask_i, (vector, weight) in enumerate(vectors, start=1):
vector_mask = ((mask == mask_i).float())
vector_mask = (mask == mask_i).float()
result += weight * vector_mask * (vector - normal)

return result
Expand Down
101 changes: 47 additions & 54 deletions lib_neutral_prompt/neutral_prompt_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,6 @@
from typing import List, Tuple, Any, Optional


@dataclasses.dataclass
class PromptExpr(abc.ABC):
weight: float

@abc.abstractmethod
def accept(self, visitor, *args, **kwargs) -> Any:
pass


@dataclasses.dataclass
class LeafPrompt(PromptExpr):
prompt: str

def accept(self, visitor, *args, **kwargs):
return visitor.visit_leaf_prompt(self, *args, **kwargs)


class PromptKeyword(Enum):
AND = 'AND'
AND_PERP = 'AND_PERP'
Expand All @@ -41,10 +24,27 @@ class ConciliationStrategy(Enum):
conciliation_strategies = [e.value for e in ConciliationStrategy]


@dataclasses.dataclass
class PromptExpr(abc.ABC):
weight: float
conciliation: Optional[ConciliationStrategy]

@abc.abstractmethod
def accept(self, visitor, *args, **kwargs) -> Any:
pass


@dataclasses.dataclass
class LeafPrompt(PromptExpr):
prompt: str

def accept(self, visitor, *args, **kwargs):
return visitor.visit_leaf_prompt(self, *args, **kwargs)


@dataclasses.dataclass
class CompositePrompt(PromptExpr):
children: List[PromptExpr]
conciliation: Optional[ConciliationStrategy]

def accept(self, visitor, *args, **kwargs):
return visitor.visit_composite_prompt(self, *args, **kwargs)
Expand All @@ -61,57 +61,53 @@ def visit_composite_prompt(self, that: CompositePrompt) -> int:
def parse_root(string: str) -> CompositePrompt:
tokens = tokenize(string)
prompts = parse_prompts(tokens)
return CompositePrompt(1., prompts, None)
return CompositePrompt(1., None, prompts)


def parse_prompts(tokens: List[str]) -> List[PromptExpr]:
prompts = [parse_prompt(tokens, first=True)]
def parse_prompts(tokens: List[str], *, nested: bool = False) -> List[PromptExpr]:
prompts = [parse_prompt(tokens, first=True, nested=nested)]
while tokens:
if tokens[0] in [']']:
if nested and tokens[0] in [']']:
break

prompts.append(parse_prompt(tokens, first=False))
prompts.append(parse_prompt(tokens, first=False, nested=nested))

return prompts


def parse_prompt(tokens: List[str], *, first: bool) -> PromptExpr:
if first:
prompt_type = PromptKeyword.AND.value
else:
assert tokens[0] in prompt_keywords
def parse_prompt(tokens: List[str], *, first: bool, nested: bool = False) -> PromptExpr:
if not first and tokens[0] in prompt_keywords:
prompt_type = tokens.pop(0)
else:
prompt_type = PromptKeyword.AND.value

tokens_copy = tokens.copy()
if tokens_copy and tokens_copy[0] == '[':
tokens_copy.pop(0)
prompts = parse_prompts(tokens_copy)
if tokens_copy:
assert tokens_copy.pop(0) == ']'
if not tokens_copy or tokens_copy[0] in prompt_keywords + [']']:
tokens[:] = tokens_copy
weight = parse_weight(tokens)
conciliation = ConciliationStrategy(prompt_type) if prompt_type in conciliation_strategies else None
return CompositePrompt(weight, prompts, conciliation)

prompt_text, weight = parse_prompt_text(tokens)
prompt = LeafPrompt(weight, prompt_text)
if prompt_type in conciliation_strategies:
prompt.weight = 1.
prompt = CompositePrompt(weight, [prompt], ConciliationStrategy(prompt_type))
tokens_copy = tokens.copy()
if tokens_copy and tokens_copy[0] == '[':
tokens_copy.pop(0)
prompts = parse_prompts(tokens_copy, nested=True)
if tokens_copy:
assert tokens_copy.pop(0) == ']'
if len(prompts) > 1:
tokens[:] = tokens_copy
weight = parse_weight(tokens)
conciliation = ConciliationStrategy(prompt_type) if prompt_type in conciliation_strategies else None
return CompositePrompt(weight, conciliation, prompts)

return prompt
prompt_text, weight = parse_prompt_text(tokens, nested=nested)
return LeafPrompt(weight, ConciliationStrategy(prompt_type) if prompt_type in conciliation_strategies else None, prompt_text)


def parse_prompt_text(tokens: List[str]) -> Tuple[str, float]:
def parse_prompt_text(tokens: List[str], *, nested: bool = False) -> Tuple[str, float]:
text = ''
depth = 0
weight = 1.
while tokens:
if tokens[0] == ']':
if depth == 0:
break
depth -= 1
if nested:
break
else:
depth -= 1
elif tokens[0] == '[':
depth += 1
elif tokens[0] == ':':
Expand All @@ -130,12 +126,9 @@ def parse_prompt_text(tokens: List[str]) -> Tuple[str, float]:

def parse_weight(tokens: List[str]) -> float:
weight = 1.
if tokens and tokens[0] == ':':
if len(tokens) >= 2 and tokens[0] == ':' and is_float(tokens[1]):
tokens.pop(0)
if tokens:
weight_str = tokens.pop(0)
if is_float(weight_str):
weight = float(weight_str)
weight = float(tokens.pop(0))
return weight


Expand Down
Loading

0 comments on commit e395a11

Please sign in to comment.