diff --git a/sd3_train_network.py b/sd3_train_network.py index 0f4ca93ef..ecacf16cc 100644 --- a/sd3_train_network.py +++ b/sd3_train_network.py @@ -6,7 +6,7 @@ import torch from accelerate import Accelerator -from library import strategy_sd3, utils +from library import sd3_models, strategy_sd3, utils from library.device_utils import init_ipex, clean_memory_on_device init_ipex() @@ -25,7 +25,6 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer): def __init__(self): super().__init__() self.sample_prompts_te_outputs = None - self.is_schnell: Optional[bool] = None def assert_extra_args(self, args, train_dataset_group): super().assert_extra_args(args, train_dataset_group) @@ -268,7 +267,7 @@ def encode_images_to_latents(self, args, accelerator, vae, images): return vae.encode(images) def shift_scale_latents(self, args, latents): - return latents + return sd3_models.SDVAE.process_in(latents) def get_noise_pred_and_target( self,