From 0031d916f0fa035d5d48a25fcabadc149bfbb639 Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Fri, 25 Oct 2024 23:20:38 +0900 Subject: [PATCH] add latent scaling/shifting --- sd3_train_network.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sd3_train_network.py b/sd3_train_network.py index 0f4ca93ef..ecacf16cc 100644 --- a/sd3_train_network.py +++ b/sd3_train_network.py @@ -6,7 +6,7 @@ import torch from accelerate import Accelerator -from library import strategy_sd3, utils +from library import sd3_models, strategy_sd3, utils from library.device_utils import init_ipex, clean_memory_on_device init_ipex() @@ -25,7 +25,6 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer): def __init__(self): super().__init__() self.sample_prompts_te_outputs = None - self.is_schnell: Optional[bool] = None def assert_extra_args(self, args, train_dataset_group): super().assert_extra_args(args, train_dataset_group) @@ -268,7 +267,7 @@ def encode_images_to_latents(self, args, accelerator, vae, images): return vae.encode(images) def shift_scale_latents(self, args, latents): - return latents + return sd3_models.SDVAE.process_in(latents) def get_noise_pred_and_target( self,