Skip to content

Commit

Permalink
fix batch size (#61)
Browse files Browse the repository at this point in the history
* fix batch size

* fix

---------

Co-authored-by: ljleb <set>
  • Loading branch information
ljleb authored Jan 28, 2024
1 parent f795be8 commit 1cd847a
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions lib_neutral_prompt/cfg_denoiser_hijack.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def get_webui_denoised(

for batch_i, (prompt, cond_indices) in enumerate(zip(global_state.prompt_exprs, batch_cond_indices)):
args = CombineDenoiseArgs(x_out, uncond[batch_i], cond_indices)
sliced_x_out, sliced_cond_indices = prompt.accept(GatherWebuiCondsVisitor(), args, 0)
sliced_x_out, sliced_cond_indices = prompt.accept(GatherWebuiCondsVisitor(), args, 0, len(sliced_batch_x_out))
if sliced_cond_indices:
sliced_batch_cond_indices.append(sliced_cond_indices)
sliced_batch_x_out.extend(sliced_x_out)
Expand Down Expand Up @@ -76,18 +76,19 @@ def visit_composite_prompt(
that: neutral_prompt_parser.CompositePrompt,
args: CombineDenoiseArgs,
index_in: int,
index_out: int,
) -> Tuple[List[torch.Tensor], List[Tuple[int, float]]]:
sliced_x_out = []
sliced_cond_indices = []

for child in that.children:
if child.conciliation is None:
index_out = len(sliced_x_out)
index_offset = index_out + len(sliced_x_out)
child_x_out = child.accept(CondDeltaVisitor(), args, index_in)
child_x_out += child.accept(AuxCondDeltaVisitor(), args, child_x_out, index_in)
child_x_out += args.uncond
sliced_x_out.append(child_x_out)
sliced_cond_indices.append((index_out, child.weight))
sliced_cond_indices.append((index_offset, child.weight))

index_in += child.accept(neutral_prompt_parser.FlatSizeVisitor())

Expand Down

0 comments on commit 1cd847a

Please sign in to comment.