From 43849030cf35a7c854311e0bee9cb8a92b77dd83 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 6 Nov 2024 21:33:28 +0900 Subject: [PATCH 1/8] Fix to work without latent cache #1758 --- sd3_train.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sd3_train.py b/sd3_train.py index e03d1708b..b8a0d04fa 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -885,7 +885,9 @@ def optimizer_hook(parameter: torch.Tensor): else: with torch.no_grad(): # encode images to latents. images are [-1, 1] - latents = vae.encode(batch["images"]) + latents = vae.encode(batch["images"].to(vae.device, dtype=vae.dtype)).to( + accelerator.device, dtype=weight_dtype + ) # NaNが含まれていれば警告を表示し0に置き換える if torch.any(torch.isnan(latents)): @@ -927,7 +929,7 @@ def optimizer_hook(parameter: torch.Tensor): if t5_out is None: _, _, input_ids_t5xxl, _, _, t5_attn_mask = batch["input_ids_list"] with torch.set_grad_enabled(train_t5xxl): - input_ids_t5xxl = input_ids_t5xxl.to("cpu") if t5_out is None else None + input_ids_t5xxl = input_ids_t5xxl.to("cpu") _, t5_out, _, _, _, t5_attn_mask = text_encoding_strategy.encode_tokens( sd3_tokenize_strategy, [None, None, t5xxl], [None, None, input_ids_t5xxl, None, None, t5_attn_mask] ) From 40ed54bfc0ca666c45a4a5d4b7a3064612371005 Mon Sep 17 00:00:00 2001 From: Dango233 Date: Thu, 7 Nov 2024 09:53:54 +0000 Subject: [PATCH 2/8] Simplify Timestep weighting * Remove diffusers dependency in ts & sigma calc * support Shift setting * Add uniform distribution * Default to Uniform distribution and shift 1 --- library/sd3_train_utils.py | 33 ++++++++++++++++++++++----------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index 69878750e..bfe752d5e 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -253,12 +253,12 @@ def add_sd3_training_arguments(parser: argparse.ArgumentParser): " / 複数解像度学習時に解像度ごとに位置埋め込みをスケーリングする。SD3.5M以外では予期しない動作になります", ) - # copy from Diffusers + # Dependencies of Diffusers noise sampler has been removed for clearity. parser.add_argument( "--weighting_scheme", type=str, - default="logit_normal", - choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"], + default="uniform", + choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "uniform"], help="weighting scheme for timestep distribution and loss / タイムステップ分布と損失のための重み付けスキーム", ) parser.add_argument( @@ -279,8 +279,13 @@ def add_sd3_training_arguments(parser: argparse.ArgumentParser): default=1.29, help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`. / モード重み付けスキームのスケール。`'mode'`を`weighting_scheme`として使用する場合のみ有効", ) - - + parser.add_argument( + "--training_shift", + type=float, + default=1.0, + help="Discrete flow shift for training timestep distribution adjustment, applied in addition to the weighting scheme, default is 1.0. /タイムステップ分布のための離散フローシフト、重み付けスキームの上に適用される、デフォルトは1.0。", + ) + def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True): assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません" if args.v_parameterization: @@ -965,14 +970,20 @@ def get_noisy_model_input_and_timesteps( logit_std=args.logit_std, mode_scale=args.mode_scale, ) - indices = (u * noise_scheduler.config.num_train_timesteps).long() - timesteps = noise_scheduler.timesteps[indices].to(device=device) + t_min = args.min_timestep if args.min_timestep is not None else 0 + t_max = args.max_timestep if args.max_timestep is not None else 1000 + shift = args.training_shift + + # weighting shift, value >1 will shift distribution to noisy side (focus more on overall structure), value <1 will shift towards less-noisy side (focus more on details) + u = (u * shift) / (1 + (shift - 1) * u) - # Add noise according to flow matching. - sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype) + indices = (u * (t_max-t_min) + t_min).long() + timesteps = indices.to(device=device, dtype=dtype) + + # sigmas according to dlowmatching + sigmas = timesteps / 1000 + sigmas = sigmas.view(-1,1,1,1) noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents return noisy_model_input, timesteps, sigmas - -# endregion From e54462a4a9cb3d01c5635f8c191d28cbccfba6e0 Mon Sep 17 00:00:00 2001 From: Dango233 Date: Thu, 7 Nov 2024 09:54:12 +0000 Subject: [PATCH 3/8] Fix SD3 trained lora loading and merging --- networks/lora_sd3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/networks/lora_sd3.py b/networks/lora_sd3.py index efe202451..ce6d1a16f 100644 --- a/networks/lora_sd3.py +++ b/networks/lora_sd3.py @@ -601,7 +601,7 @@ def merge_to(self, text_encoders, mmdit, weights_sd, dtype=None, device=None): or key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER_T5) ): apply_text_encoder = True - elif key.startswith(LoRANetwork.LORA_PREFIX_MMDIT): + elif key.startswith(LoRANetwork.LORA_PREFIX_SD3): apply_unet = True if apply_text_encoder: From bafd10d558bf318ccd7059c2b4dce2775b5758da Mon Sep 17 00:00:00 2001 From: Dango233 Date: Thu, 7 Nov 2024 18:21:04 +0800 Subject: [PATCH 4/8] Fix typo --- library/sd3_train_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index bfe752d5e..afbe34cf5 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -980,7 +980,7 @@ def get_noisy_model_input_and_timesteps( indices = (u * (t_max-t_min) + t_min).long() timesteps = indices.to(device=device, dtype=dtype) - # sigmas according to dlowmatching + # sigmas according to flowmatching sigmas = timesteps / 1000 sigmas = sigmas.view(-1,1,1,1) noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents From 5e86323f12178605c0b99bc914b4bd970900ce75 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 7 Nov 2024 21:27:12 +0900 Subject: [PATCH 5/8] Update README and clean-up the code for SD3 timesteps --- README.md | 13 ++++++++++++- library/config_util.py | 2 +- library/sd3_models.py | 2 +- library/sd3_train_utils.py | 17 +++++++++-------- sd3_train.py | 8 ++++---- sd3_train_network.py | 7 +++---- 6 files changed, 30 insertions(+), 19 deletions(-) diff --git a/README.md b/README.md index fb087c234..dba76a3c5 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,13 @@ The command to install PyTorch is as follows: ### Recent Updates +Nov 7, 2024: + +- The distribution of timesteps during SD3/3.5 training has been adjusted. This applies to both fine-tuning and LoRA training. PR [#1768](https://github.com/kohya-ss/sd-scripts/pull/1768) Thanks to Dango233! + - Previously, the side closer to noise was more sampled, but now it is uniform by default. This may improve the problem of difficulty in learning details. + - Specifically, the problem of double shifting has been fixed. The default for `--weighting_scheme` has been changed to `uniform` (the previous default was `logit_normal`). + - A new option `--training_shift` has been added. The default is 1.0, and all timesteps are sampled uniformly. If less than 1.0, the side closer to the image is more sampled, and if more than 1.0, the side closer to noise is more sampled. + Oct 31, 2024: - Added support for SD3.5L/M training. See [SD3 training](#sd3-training) for details. @@ -641,6 +648,7 @@ Here are the arguments. The arguments and sample settings are still experimental - `--clip_l_dropout_rate`, `--clip_g_dropout_rate` and `--t5_dropout_rate` are the dropout rates for the embeddings of CLIP-L, CLIP-G, and T5XXL, described in [SAI research papre](http://arxiv.org/pdf/2403.03206). The default is 0.0. For LoRA training, it is seems to be better to set 0.0. - `--pos_emb_random_crop_rate` is the rate of random cropping of positional embeddings, described in [SD3.5M model card](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium). The default is 0. It is seems to be better to set 0.0 for LoRA training. - `--enable_scaled_pos_embed` is to enable the scaled positional embeddings. The default is False. This option is an experimental feature for SD3.5M. Details are described below. +- `--training_shift` is the shift value for the training distribution of timesteps. The default is 1.0 (uniform distribution, no shift). If less than 1.0, the side closer to the image is more sampled, and if more than 1.0, the side closer to noise is more sampled. Other options are described below. @@ -681,7 +689,10 @@ Other options are described below. - Same as FLUX.1 for data preparation. - If you train with multiple resolutions, you can enable the scaled positional embeddings with `--enable_scaled_pos_embed`. The default is False. __This option is an experimental feature.__ - +6. Weighting scheme and training shift: + - The weighting scheme is described in the section 3.1 of the [SD3 paper](https://arxiv.org/abs/2403.03206v1). + - The uniform distribution is the default. If you want to change the distribution, see `--help` for options. + - `--training_shift` is the shift value for the training distribution of timesteps. Technical details of multi-resolution training for SD3.5M: diff --git a/library/config_util.py b/library/config_util.py index fc1fbf46d..12d0be173 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -526,7 +526,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu secondary_separator: {subset.secondary_separator} enable_wildcard: {subset.enable_wildcard} caption_dropout_rate: {subset.caption_dropout_rate} - caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs} + caption_dropout_every_n_epochs: {subset.caption_dropout_every_n_epochs} caption_tag_dropout_rate: {subset.caption_tag_dropout_rate} caption_prefix: {subset.caption_prefix} caption_suffix: {subset.caption_suffix} diff --git a/library/sd3_models.py b/library/sd3_models.py index b09a57dbd..89225fe4d 100644 --- a/library/sd3_models.py +++ b/library/sd3_models.py @@ -871,7 +871,7 @@ def enable_scaled_pos_embed(self, use_scaled_pos_embed: bool, latent_sizes: Opti # remove pos_embed to free up memory up to 0.4 GB self.pos_embed = None - # remove duplcates and sort latent sizes in ascending order + # remove duplicates and sort latent sizes in ascending order latent_sizes = list(set(latent_sizes)) latent_sizes = sorted(latent_sizes) diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index afbe34cf5..38f3c25f4 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -253,7 +253,7 @@ def add_sd3_training_arguments(parser: argparse.ArgumentParser): " / 複数解像度学習時に解像度ごとに位置埋め込みをスケーリングする。SD3.5M以外では予期しない動作になります", ) - # Dependencies of Diffusers noise sampler has been removed for clearity. + # Dependencies of Diffusers noise sampler has been removed for clarity. parser.add_argument( "--weighting_scheme", type=str, @@ -285,7 +285,8 @@ def add_sd3_training_arguments(parser: argparse.ArgumentParser): default=1.0, help="Discrete flow shift for training timestep distribution adjustment, applied in addition to the weighting scheme, default is 1.0. /タイムステップ分布のための離散フローシフト、重み付けスキームの上に適用される、デフォルトは1.0。", ) - + + def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True): assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません" if args.v_parameterization: @@ -956,9 +957,10 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): return weighting -def get_noisy_model_input_and_timesteps( - args, noise_scheduler, latents, noise, device, dtype -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +# endregion + + +def get_noisy_model_input_and_timesteps(args, latents, noise, device, dtype) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: bsz = latents.shape[0] # Sample a random timestep for each image @@ -977,13 +979,12 @@ def get_noisy_model_input_and_timesteps( # weighting shift, value >1 will shift distribution to noisy side (focus more on overall structure), value <1 will shift towards less-noisy side (focus more on details) u = (u * shift) / (1 + (shift - 1) * u) - indices = (u * (t_max-t_min) + t_min).long() + indices = (u * (t_max - t_min) + t_min).long() timesteps = indices.to(device=device, dtype=dtype) # sigmas according to flowmatching sigmas = timesteps / 1000 - sigmas = sigmas.view(-1,1,1,1) + sigmas = sigmas.view(-1, 1, 1, 1) noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents return noisy_model_input, timesteps, sigmas - diff --git a/sd3_train.py b/sd3_train.py index b8a0d04fa..24ecbfb7d 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -811,8 +811,8 @@ def optimizer_hook(parameter: torch.Tensor): progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") global_step = 0 - noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0) - noise_scheduler_copy = copy.deepcopy(noise_scheduler) + # noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0) + # noise_scheduler_copy = copy.deepcopy(noise_scheduler) if accelerator.is_main_process: init_kwargs = {} @@ -940,11 +940,11 @@ def optimizer_hook(parameter: torch.Tensor): # Sample noise that we'll add to the latents noise = torch.randn_like(latents) - bsz = latents.shape[0] + # bsz = latents.shape[0] # get noisy model input and timesteps noisy_model_input, timesteps, sigmas = sd3_train_utils.get_noisy_model_input_and_timesteps( - args, noise_scheduler_copy, latents, noise, accelerator.device, weight_dtype + args, latents, noise, accelerator.device, weight_dtype ) # debug: NaN check for all inputs diff --git a/sd3_train_network.py b/sd3_train_network.py index 0739e094d..bb02c7ac7 100644 --- a/sd3_train_network.py +++ b/sd3_train_network.py @@ -275,9 +275,8 @@ def sample_images(self, accelerator, args, epoch, global_step, device, vae, toke ) def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any: - # shift 3.0 is the default value - noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0) - self.noise_scheduler_copy = copy.deepcopy(noise_scheduler) + # this scheduler is not used in training, but used to get num_train_timesteps etc. + noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.training_shift) return noise_scheduler def encode_images_to_latents(self, args, accelerator, vae, images): @@ -304,7 +303,7 @@ def get_noise_pred_and_target( # get noisy model input and timesteps noisy_model_input, timesteps, sigmas = sd3_train_utils.get_noisy_model_input_and_timesteps( - args, self.noise_scheduler_copy, latents, noise, accelerator.device, weight_dtype + args, latents, noise, accelerator.device, weight_dtype ) # ensure the hidden state will require grad From f264f4091f734b4e4011257b8571ef97315a1343 Mon Sep 17 00:00:00 2001 From: "Kohya S." <52813779+kohya-ss@users.noreply.github.com> Date: Thu, 7 Nov 2024 21:30:31 +0900 Subject: [PATCH 6/8] Update README.md --- README.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index dba76a3c5..9273fc8fb 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,8 @@ Nov 7, 2024: - Previously, the side closer to noise was more sampled, but now it is uniform by default. This may improve the problem of difficulty in learning details. - Specifically, the problem of double shifting has been fixed. The default for `--weighting_scheme` has been changed to `uniform` (the previous default was `logit_normal`). - A new option `--training_shift` has been added. The default is 1.0, and all timesteps are sampled uniformly. If less than 1.0, the side closer to the image is more sampled, and if more than 1.0, the side closer to noise is more sampled. + - The effect of a shift in uniform distribution is shown in the figure below. + - ![Figure_1](https://github.com/user-attachments/assets/99a72c67-adfb-4440-81d4-a718985ff350) Oct 31, 2024: @@ -693,7 +695,8 @@ Other options are described below. - The weighting scheme is described in the section 3.1 of the [SD3 paper](https://arxiv.org/abs/2403.03206v1). - The uniform distribution is the default. If you want to change the distribution, see `--help` for options. - `--training_shift` is the shift value for the training distribution of timesteps. - + - The effect of a shift in uniform distribution is shown in the figure below. + - ![Figure_1](https://github.com/user-attachments/assets/99a72c67-adfb-4440-81d4-a718985ff350) Technical details of multi-resolution training for SD3.5M: From 5eb6d209d5b28d43bf611e0934297703eb041d07 Mon Sep 17 00:00:00 2001 From: Dango233 Date: Thu, 7 Nov 2024 20:33:31 +0800 Subject: [PATCH 7/8] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 9273fc8fb..fe7c506cb 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ Nov 7, 2024: - The distribution of timesteps during SD3/3.5 training has been adjusted. This applies to both fine-tuning and LoRA training. PR [#1768](https://github.com/kohya-ss/sd-scripts/pull/1768) Thanks to Dango233! - Previously, the side closer to noise was more sampled, but now it is uniform by default. This may improve the problem of difficulty in learning details. - Specifically, the problem of double shifting has been fixed. The default for `--weighting_scheme` has been changed to `uniform` (the previous default was `logit_normal`). - - A new option `--training_shift` has been added. The default is 1.0, and all timesteps are sampled uniformly. If less than 1.0, the side closer to the image is more sampled, and if more than 1.0, the side closer to noise is more sampled. + - A new option `--training_shift` has been added. The default is 1.0, and all timesteps are sampled uniformly. If less than 1.0, the side closer to the image is more sampled (training more on image details), and if more than 1.0, the side closer to noise is more sampled (training more on overall structure). - The effect of a shift in uniform distribution is shown in the figure below. - ![Figure_1](https://github.com/user-attachments/assets/99a72c67-adfb-4440-81d4-a718985ff350) From e5ac09574928ec02fba5fe78267764d26bb7faa6 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 7 Nov 2024 21:39:47 +0900 Subject: [PATCH 8/8] add about dev and sd3 branch to README --- README-ja.md | 4 ++++ README.md | 5 +++++ 2 files changed, 9 insertions(+) diff --git a/README-ja.md b/README-ja.md index 4ae6b2334..27cc56c34 100644 --- a/README-ja.md +++ b/README-ja.md @@ -3,6 +3,10 @@ Stable Diffusionの学習、画像生成、その他のスクリプトを入れ [README in English](./README.md) ←更新情報はこちらにあります +開発中のバージョンはdevブランチにあります。最新の変更点はdevブランチをご確認ください。 + +FLUX.1およびSD3/SD3.5対応はsd3ブランチで行っています。それらの学習を行う場合はsd3ブランチをご利用ください。 + GUIやPowerShellスクリプトなど、より使いやすくする機能が[bmaltais氏のリポジトリ](https://github.com/bmaltais/kohya_ss)で提供されています(英語です)のであわせてご覧ください。bmaltais氏に感謝します。 以下のスクリプトがあります。 diff --git a/README.md b/README.md index 1e7d49afe..51ac07a1f 100644 --- a/README.md +++ b/README.md @@ -5,6 +5,11 @@ This repository contains training, generation and utility scripts for Stable Dif [日本語版READMEはこちら](./README-ja.md) +The development version is in the `dev` branch. Please check the dev branch for the latest changes. + +FLUX.1 and SD3/SD3.5 support is done in the `sd3` branch. If you want to train them, please use the sd3 branch. + + For easier use (GUI and PowerShell scripts etc...), please visit [the repository maintained by bmaltais](https://github.com/bmaltais/kohya_ss). Thanks to @bmaltais! This repository contains the scripts for: