diff --git a/lib_neutral_prompt/cfg_denoiser_hijack.py b/lib_neutral_prompt/cfg_denoiser_hijack.py index 4256654..d032530 100644 --- a/lib_neutral_prompt/cfg_denoiser_hijack.py +++ b/lib_neutral_prompt/cfg_denoiser_hijack.py @@ -44,7 +44,7 @@ def get_webui_denoised( for batch_i, (prompt, cond_indices) in enumerate(zip(global_state.prompt_exprs, batch_cond_indices)): args = CombineDenoiseArgs(x_out, uncond[batch_i], cond_indices) - sliced_x_out, sliced_cond_indices = prompt.accept(GatherWebuiCondsVisitor(), args, 0, len(sliced_batch_x_out)) + sliced_x_out, sliced_cond_indices = gather_webui_conds(prompt, args, 0, len(sliced_batch_x_out)) if sliced_cond_indices: sliced_batch_cond_indices.append(sliced_cond_indices) sliced_batch_x_out.extend(sliced_x_out) @@ -55,11 +55,14 @@ def get_webui_denoised( def cfg_rescale(cfg_cond, cond): + if global_state.cfg_rescale == 0: + return cfg_cond + global_state.apply_and_clear_cfg_rescale_override() cfg_cond_mean = cfg_cond.mean() - cfg_resacle_mean = (1 - global_state.cfg_rescale) * cfg_cond_mean + global_state.cfg_rescale * cond.mean() + cfg_rescale_mean = (1 - global_state.cfg_rescale) * cfg_cond_mean + global_state.cfg_rescale * cond.mean() cfg_rescale_factor = global_state.cfg_rescale * (cond.std() / cfg_cond.std() - 1) + 1 - return cfg_resacle_mean + (cfg_cond - cfg_cond_mean) * cfg_rescale_factor + return cfg_rescale_mean + (cfg_cond - cfg_cond_mean) * cfg_rescale_factor @dataclasses.dataclass @@ -69,33 +72,32 @@ class CombineDenoiseArgs: cond_indices: List[Tuple[int, float]] -@dataclasses.dataclass -class GatherWebuiCondsVisitor: - def visit_composite_prompt( - self, - that: neutral_prompt_parser.CompositePrompt, - args: CombineDenoiseArgs, - index_in: int, - index_out: int, - ) -> Tuple[List[torch.Tensor], List[Tuple[int, float]]]: - sliced_x_out = [] - sliced_cond_indices = [] - - for child in that.children: - if child.conciliation is None: - index_offset = index_out + len(sliced_x_out) +def gather_webui_conds( + prompt: neutral_prompt_parser.CompositePrompt, + args: CombineDenoiseArgs, + index_in: int, + index_out: int, +) -> Tuple[List[torch.Tensor], List[Tuple[int, float]]]: + sliced_x_out = [] + sliced_cond_indices = [] + + for child in prompt.children: + if child.conciliation is None: + if isinstance(child, neutral_prompt_parser.LeafPrompt): + child_x_out = args.x_out[index_in] + else: child_x_out = child.accept(CondDeltaVisitor(), args, index_in) child_x_out += child.accept(AuxCondDeltaVisitor(), args, child_x_out, index_in) child_x_out += args.uncond - sliced_x_out.append(child_x_out) - sliced_cond_indices.append((index_offset, child.weight)) + index_offset = index_out + len(sliced_x_out) + sliced_x_out.append(child_x_out) + sliced_cond_indices.append((index_offset, child.weight)) - index_in += child.accept(neutral_prompt_parser.FlatSizeVisitor()) + index_in += child.accept(neutral_prompt_parser.FlatSizeVisitor()) - return sliced_x_out, sliced_cond_indices + return sliced_x_out, sliced_cond_indices -@dataclasses.dataclass class CondDeltaVisitor: def visit_leaf_prompt( self, @@ -134,7 +136,6 @@ def visit_composite_prompt( return cond_delta -@dataclasses.dataclass class AuxCondDeltaVisitor: def visit_leaf_prompt( self, diff --git a/lib_neutral_prompt/prompt_parser_hijack.py b/lib_neutral_prompt/prompt_parser_hijack.py index 423c233..fff7748 100644 --- a/lib_neutral_prompt/prompt_parser_hijack.py +++ b/lib_neutral_prompt/prompt_parser_hijack.py @@ -24,7 +24,7 @@ def get_multicond_prompt_list_hijack(prompts, original_function): return original_function(webui_prompts) -def parse_prompts(prompts: List[str]) -> neutral_prompt_parser.PromptExpr: +def parse_prompts(prompts: List[str]) -> List[neutral_prompt_parser.PromptExpr]: exprs = [] for prompt in prompts: expr = neutral_prompt_parser.parse_root(prompt)