Skip to content

Commit

Permalink
dont use delta-space when possible (#62)
Browse files Browse the repository at this point in the history
* dont use delta-space when possible

* type

---------

Co-authored-by: ljleb <set>
  • Loading branch information
ljleb authored Jan 28, 2024
1 parent 1cd847a commit 185638f
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 25 deletions.
49 changes: 25 additions & 24 deletions lib_neutral_prompt/cfg_denoiser_hijack.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def get_webui_denoised(

for batch_i, (prompt, cond_indices) in enumerate(zip(global_state.prompt_exprs, batch_cond_indices)):
args = CombineDenoiseArgs(x_out, uncond[batch_i], cond_indices)
sliced_x_out, sliced_cond_indices = prompt.accept(GatherWebuiCondsVisitor(), args, 0, len(sliced_batch_x_out))
sliced_x_out, sliced_cond_indices = gather_webui_conds(prompt, args, 0, len(sliced_batch_x_out))
if sliced_cond_indices:
sliced_batch_cond_indices.append(sliced_cond_indices)
sliced_batch_x_out.extend(sliced_x_out)
Expand All @@ -55,11 +55,14 @@ def get_webui_denoised(


def cfg_rescale(cfg_cond, cond):
if global_state.cfg_rescale == 0:
return cfg_cond

global_state.apply_and_clear_cfg_rescale_override()
cfg_cond_mean = cfg_cond.mean()
cfg_resacle_mean = (1 - global_state.cfg_rescale) * cfg_cond_mean + global_state.cfg_rescale * cond.mean()
cfg_rescale_mean = (1 - global_state.cfg_rescale) * cfg_cond_mean + global_state.cfg_rescale * cond.mean()
cfg_rescale_factor = global_state.cfg_rescale * (cond.std() / cfg_cond.std() - 1) + 1
return cfg_resacle_mean + (cfg_cond - cfg_cond_mean) * cfg_rescale_factor
return cfg_rescale_mean + (cfg_cond - cfg_cond_mean) * cfg_rescale_factor


@dataclasses.dataclass
Expand All @@ -69,33 +72,32 @@ class CombineDenoiseArgs:
cond_indices: List[Tuple[int, float]]


@dataclasses.dataclass
class GatherWebuiCondsVisitor:
def visit_composite_prompt(
self,
that: neutral_prompt_parser.CompositePrompt,
args: CombineDenoiseArgs,
index_in: int,
index_out: int,
) -> Tuple[List[torch.Tensor], List[Tuple[int, float]]]:
sliced_x_out = []
sliced_cond_indices = []

for child in that.children:
if child.conciliation is None:
index_offset = index_out + len(sliced_x_out)
def gather_webui_conds(
prompt: neutral_prompt_parser.CompositePrompt,
args: CombineDenoiseArgs,
index_in: int,
index_out: int,
) -> Tuple[List[torch.Tensor], List[Tuple[int, float]]]:
sliced_x_out = []
sliced_cond_indices = []

for child in prompt.children:
if child.conciliation is None:
if isinstance(child, neutral_prompt_parser.LeafPrompt):
child_x_out = args.x_out[index_in]
else:
child_x_out = child.accept(CondDeltaVisitor(), args, index_in)
child_x_out += child.accept(AuxCondDeltaVisitor(), args, child_x_out, index_in)
child_x_out += args.uncond
sliced_x_out.append(child_x_out)
sliced_cond_indices.append((index_offset, child.weight))
index_offset = index_out + len(sliced_x_out)
sliced_x_out.append(child_x_out)
sliced_cond_indices.append((index_offset, child.weight))

index_in += child.accept(neutral_prompt_parser.FlatSizeVisitor())
index_in += child.accept(neutral_prompt_parser.FlatSizeVisitor())

return sliced_x_out, sliced_cond_indices
return sliced_x_out, sliced_cond_indices


@dataclasses.dataclass
class CondDeltaVisitor:
def visit_leaf_prompt(
self,
Expand Down Expand Up @@ -134,7 +136,6 @@ def visit_composite_prompt(
return cond_delta


@dataclasses.dataclass
class AuxCondDeltaVisitor:
def visit_leaf_prompt(
self,
Expand Down
2 changes: 1 addition & 1 deletion lib_neutral_prompt/prompt_parser_hijack.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def get_multicond_prompt_list_hijack(prompts, original_function):
return original_function(webui_prompts)


def parse_prompts(prompts: List[str]) -> neutral_prompt_parser.PromptExpr:
def parse_prompts(prompts: List[str]) -> List[neutral_prompt_parser.PromptExpr]:
exprs = []
for prompt in prompts:
expr = neutral_prompt_parser.parse_root(prompt)
Expand Down

0 comments on commit 185638f

Please sign in to comment.