From 43849030cf35a7c854311e0bee9cb8a92b77dd83 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 6 Nov 2024 21:33:28 +0900 Subject: [PATCH] Fix to work without latent cache #1758 --- sd3_train.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sd3_train.py b/sd3_train.py index e03d1708b..b8a0d04fa 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -885,7 +885,9 @@ def optimizer_hook(parameter: torch.Tensor): else: with torch.no_grad(): # encode images to latents. images are [-1, 1] - latents = vae.encode(batch["images"]) + latents = vae.encode(batch["images"].to(vae.device, dtype=vae.dtype)).to( + accelerator.device, dtype=weight_dtype + ) # NaNが含まれていれば警告を表示し0に置き換える if torch.any(torch.isnan(latents)): @@ -927,7 +929,7 @@ def optimizer_hook(parameter: torch.Tensor): if t5_out is None: _, _, input_ids_t5xxl, _, _, t5_attn_mask = batch["input_ids_list"] with torch.set_grad_enabled(train_t5xxl): - input_ids_t5xxl = input_ids_t5xxl.to("cpu") if t5_out is None else None + input_ids_t5xxl = input_ids_t5xxl.to("cpu") _, t5_out, _, _, _, t5_attn_mask = text_encoding_strategy.encode_tokens( sd3_tokenize_strategy, [None, None, t5xxl], [None, None, input_ids_t5xxl, None, None, t5_attn_mask] )