Skip to content

Commit

Permalink
fix get_webui_denoised (#60)
Browse files Browse the repository at this point in the history
* fix batch size

* fix get_webui_denoised

---------

Co-authored-by: ljleb <set>
  • Loading branch information
ljleb authored Jan 28, 2024
1 parent e395a11 commit f795be8
Showing 1 changed file with 7 additions and 17 deletions.
24 changes: 7 additions & 17 deletions lib_neutral_prompt/cfg_denoiser_hijack.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,13 @@ def get_webui_denoised(
uncond = x_out[-text_uncond.shape[0]:]
sliced_batch_x_out = []
sliced_batch_cond_indices = []
index_in = 0

for batch_i, (prompt, cond_indices) in enumerate(zip(global_state.prompt_exprs, batch_cond_indices)):
args = CombineDenoiseArgs(x_out, uncond[batch_i], cond_indices)
sliced_x_out, sliced_cond_indices = prompt.accept(GatherWebuiCondsVisitor(), args, index_in, len(sliced_batch_x_out))
sliced_x_out, sliced_cond_indices = prompt.accept(GatherWebuiCondsVisitor(), args, 0)
if sliced_cond_indices:
sliced_batch_cond_indices.append(sliced_cond_indices)
sliced_batch_x_out.extend(sliced_x_out)
index_in += prompt.accept(neutral_prompt_parser.FlatSizeVisitor())

sliced_batch_x_out += list(uncond)
sliced_batch_x_out = torch.stack(sliced_batch_x_out, dim=0)
Expand All @@ -73,31 +71,23 @@ class CombineDenoiseArgs:

@dataclasses.dataclass
class GatherWebuiCondsVisitor:
def visit_leaf_prompt(
self,
that: neutral_prompt_parser.CompositePrompt,
args: CombineDenoiseArgs,
index_in: int,
index_out: int,
) -> Tuple[List[torch.Tensor], List[Tuple[int, float]]]:
return [args.x_out[args.cond_indices[index_in][0]]], [(index_out, args.cond_indices[index_in][1])]

def visit_composite_prompt(
self,
that: neutral_prompt_parser.CompositePrompt,
args: CombineDenoiseArgs,
index_in: int,
index_out: int,
) -> Tuple[List[torch.Tensor], List[Tuple[int, float]]]:
sliced_x_out = []
sliced_cond_indices = []

for child in that.children:
if child.conciliation is None:
index_offset = index_out + len(sliced_x_out)
child_x_out, child_cond_indices = child.accept(GatherWebuiCondsVisitor(), args, index_in, index_offset)
sliced_x_out.extend(child_x_out)
sliced_cond_indices.extend(child_cond_indices)
index_out = len(sliced_x_out)
child_x_out = child.accept(CondDeltaVisitor(), args, index_in)
child_x_out += child.accept(AuxCondDeltaVisitor(), args, child_x_out, index_in)
child_x_out += args.uncond
sliced_x_out.append(child_x_out)
sliced_cond_indices.append((index_out, child.weight))

index_in += child.accept(neutral_prompt_parser.FlatSizeVisitor())

Expand Down

0 comments on commit f795be8

Please sign in to comment.