Skip to content

Commit

Permalink
add latent scaling/shifting
Browse files Browse the repository at this point in the history
  • Loading branch information
kohya-ss committed Oct 25, 2024
1 parent d2c549d commit 0031d91
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions sd3_train_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import torch
from accelerate import Accelerator
from library import strategy_sd3, utils
from library import sd3_models, strategy_sd3, utils
from library.device_utils import init_ipex, clean_memory_on_device

init_ipex()
Expand All @@ -25,7 +25,6 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
def __init__(self):
super().__init__()
self.sample_prompts_te_outputs = None
self.is_schnell: Optional[bool] = None

def assert_extra_args(self, args, train_dataset_group):
super().assert_extra_args(args, train_dataset_group)
Expand Down Expand Up @@ -268,7 +267,7 @@ def encode_images_to_latents(self, args, accelerator, vae, images):
return vae.encode(images)

def shift_scale_latents(self, args, latents):
return latents
return sd3_models.SDVAE.process_in(latents)

def get_noise_pred_and_target(
self,
Expand Down

0 comments on commit 0031d91

Please sign in to comment.