From f795be898b583451849e0fcb497ea7a80a37643b Mon Sep 17 00:00:00 2001 From: ljleb Date: Sun, 28 Jan 2024 01:37:34 -0500 Subject: [PATCH] fix `get_webui_denoised` (#60) * fix batch size * fix get_webui_denoised --------- Co-authored-by: ljleb --- lib_neutral_prompt/cfg_denoiser_hijack.py | 24 +++++++---------------- 1 file changed, 7 insertions(+), 17 deletions(-) diff --git a/lib_neutral_prompt/cfg_denoiser_hijack.py b/lib_neutral_prompt/cfg_denoiser_hijack.py index 696a060..5a859e0 100644 --- a/lib_neutral_prompt/cfg_denoiser_hijack.py +++ b/lib_neutral_prompt/cfg_denoiser_hijack.py @@ -41,15 +41,13 @@ def get_webui_denoised( uncond = x_out[-text_uncond.shape[0]:] sliced_batch_x_out = [] sliced_batch_cond_indices = [] - index_in = 0 for batch_i, (prompt, cond_indices) in enumerate(zip(global_state.prompt_exprs, batch_cond_indices)): args = CombineDenoiseArgs(x_out, uncond[batch_i], cond_indices) - sliced_x_out, sliced_cond_indices = prompt.accept(GatherWebuiCondsVisitor(), args, index_in, len(sliced_batch_x_out)) + sliced_x_out, sliced_cond_indices = prompt.accept(GatherWebuiCondsVisitor(), args, 0) if sliced_cond_indices: sliced_batch_cond_indices.append(sliced_cond_indices) sliced_batch_x_out.extend(sliced_x_out) - index_in += prompt.accept(neutral_prompt_parser.FlatSizeVisitor()) sliced_batch_x_out += list(uncond) sliced_batch_x_out = torch.stack(sliced_batch_x_out, dim=0) @@ -73,31 +71,23 @@ class CombineDenoiseArgs: @dataclasses.dataclass class GatherWebuiCondsVisitor: - def visit_leaf_prompt( - self, - that: neutral_prompt_parser.CompositePrompt, - args: CombineDenoiseArgs, - index_in: int, - index_out: int, - ) -> Tuple[List[torch.Tensor], List[Tuple[int, float]]]: - return [args.x_out[args.cond_indices[index_in][0]]], [(index_out, args.cond_indices[index_in][1])] - def visit_composite_prompt( self, that: neutral_prompt_parser.CompositePrompt, args: CombineDenoiseArgs, index_in: int, - index_out: int, ) -> Tuple[List[torch.Tensor], List[Tuple[int, float]]]: sliced_x_out = [] sliced_cond_indices = [] for child in that.children: if child.conciliation is None: - index_offset = index_out + len(sliced_x_out) - child_x_out, child_cond_indices = child.accept(GatherWebuiCondsVisitor(), args, index_in, index_offset) - sliced_x_out.extend(child_x_out) - sliced_cond_indices.extend(child_cond_indices) + index_out = len(sliced_x_out) + child_x_out = child.accept(CondDeltaVisitor(), args, index_in) + child_x_out += child.accept(AuxCondDeltaVisitor(), args, child_x_out, index_in) + child_x_out += args.uncond + sliced_x_out.append(child_x_out) + sliced_cond_indices.append((index_out, child.weight)) index_in += child.accept(neutral_prompt_parser.FlatSizeVisitor())