Skip to content

Commit

Permalink
Fix to work without latent cache #1758
Browse files (browse the repository at this point in the history)
  • Loading branch information
kohya-ss committed Nov 6, 2024
1 parent 5e32ee2 commit 4384903
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions sd3_train.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -885,7 +885,9 @@ def optimizer_hook(parameter: torch.Tensor):
else:
with torch.no_grad():
# encode images to latents. images are [-1, 1]
- latents = vae.encode(batch["images"])
+ latents = vae.encode(batch["images"].to(vae.device, dtype=vae.dtype)).to(
+ accelerator.device, dtype=weight_dtype
+ )

# If the latents contain NaN, show a warning and replace them with 0
if torch.any(torch.isnan(latents)):
Expand Down Expand Up @@ -927,7 +929,7 @@ def optimizer_hook(parameter: torch.Tensor):
if t5_out is None:
_, _, input_ids_t5xxl, _, _, t5_attn_mask = batch["input_ids_list"]
with torch.set_grad_enabled(train_t5xxl):
- input_ids_t5xxl = input_ids_t5xxl.to("cpu") if t5_out is None else None
+ input_ids_t5xxl = input_ids_t5xxl.to("cpu")
_, t5_out, _, _, _, t5_attn_mask = text_encoding_strategy.encode_tokens(
sd3_tokenize_strategy, [None, None, t5xxl], [None, None, input_ids_t5xxl, None, None, t5_attn_mask]
)
Expand Down

0 comments on commit 4384903

Please sign in to comment.