Skip to content

Commit

Permalink
Fix sample image gen to work with block swap
Browse files Browse the repository at this point in the history
  • Loading branch information
kohya-ss committed Oct 28, 2024
1 parent 1065dd1 commit af8e216
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions library/sd3_train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,7 @@ def do_sample(
x_c_nc = torch.cat([x, x], dim=0)
# print(x_c_nc.shape, timestep.shape, c_crossattn.shape, y.shape)

mmdit.prepare_block_swap_before_forward()
model_output = mmdit(x_c_nc, timestep, context=c_crossattn, y=y)
model_output = model_output.float()
batched = model_sampling.calculate_denoised(sigma_hat, model_output, x)
Expand All @@ -385,6 +386,7 @@ def do_sample(
x = x + d * dt
x = x.to(dtype)

mmdit.prepare_block_swap_before_forward()
return x


Expand Down

0 comments on commit af8e216

Please sign in to comment.