Skip to content

Commit

Permalink
Merge branch 'sd3' into faster-block-swap
Browse files Browse the repository at this point in the history
  • Loading branch information
kohya-ss committed Nov 7, 2024
2 parents aab943c + 123474d commit b8d3fec
Show file tree
Hide file tree
Showing 8 changed files with 64 additions and 28 deletions.
4 changes: 4 additions & 0 deletions README-ja.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@ Stable Diffusionの学習、画像生成、その他のスクリプトを入れ

[README in English](./README.md) ←更新情報はこちらにあります

開発中のバージョンはdevブランチにあります。最新の変更点はdevブランチをご確認ください。

FLUX.1およびSD3/SD3.5対応はsd3ブランチで行っています。それらの学習を行う場合はsd3ブランチをご利用ください。

GUIやPowerShellスクリプトなど、より使いやすくする機能が[bmaltais氏のリポジトリ](https://github.com/bmaltais/kohya_ss)で提供されています(英語です)のであわせてご覧ください。bmaltais氏に感謝します。

以下のスクリプトがあります。
Expand Down
23 changes: 21 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,15 @@ The command to install PyTorch is as follows:

### Recent Updates

Nov 7, 2024:

- The distribution of timesteps during SD3/3.5 training has been adjusted. This applies to both fine-tuning and LoRA training. PR [#1768](https://github.com/kohya-ss/sd-scripts/pull/1768) Thanks to Dango233!
- Previously, the side closer to noise was more sampled, but now it is uniform by default. This may improve the problem of difficulty in learning details.
- Specifically, the problem of double shifting has been fixed. The default for `--weighting_scheme` has been changed to `uniform` (the previous default was `logit_normal`).
- A new option `--training_shift` has been added. The default is 1.0, and all timesteps are sampled uniformly. If less than 1.0, the side closer to the image is more sampled (training more on image details), and if more than 1.0, the side closer to noise is more sampled (training more on overall structure).
- The effect of a shift in uniform distribution is shown in the figure below.
- ![Figure_1](https://github.com/user-attachments/assets/99a72c67-adfb-4440-81d4-a718985ff350)

Oct 31, 2024:

- Added support for SD3.5L/M training. See [SD3 training](#sd3-training) for details.
Expand Down Expand Up @@ -641,6 +650,7 @@ Here are the arguments. The arguments and sample settings are still experimental
- `--clip_l_dropout_rate`, `--clip_g_dropout_rate` and `--t5_dropout_rate` are the dropout rates for the embeddings of CLIP-L, CLIP-G, and T5XXL, described in [SAI research papre](http://arxiv.org/pdf/2403.03206). The default is 0.0. For LoRA training, it is seems to be better to set 0.0.
- `--pos_emb_random_crop_rate` is the rate of random cropping of positional embeddings, described in [SD3.5M model card](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium). The default is 0. It is seems to be better to set 0.0 for LoRA training.
- `--enable_scaled_pos_embed` is to enable the scaled positional embeddings. The default is False. This option is an experimental feature for SD3.5M. Details are described below.
- `--training_shift` is the shift value for the training distribution of timesteps. The default is 1.0 (uniform distribution, no shift). If less than 1.0, the side closer to the image is more sampled, and if more than 1.0, the side closer to noise is more sampled.
Other options are described below.
Expand Down Expand Up @@ -681,8 +691,12 @@ Other options are described below.
- Same as FLUX.1 for data preparation.
- If you train with multiple resolutions, you can enable the scaled positional embeddings with `--enable_scaled_pos_embed`. The default is False. __This option is an experimental feature.__
6. Weighting scheme and training shift:
- The weighting scheme is described in the section 3.1 of the [SD3 paper](https://arxiv.org/abs/2403.03206v1).
- The uniform distribution is the default. If you want to change the distribution, see `--help` for options.
- `--training_shift` is the shift value for the training distribution of timesteps.
- The effect of a shift in uniform distribution is shown in the figure below.
- ![Figure_1](https://github.com/user-attachments/assets/99a72c67-adfb-4440-81d4-a718985ff350)
Technical details of multi-resolution training for SD3.5M:
Expand Down Expand Up @@ -776,6 +790,11 @@ Not available yet.
[日本語版READMEはこちら](./README-ja.md)
The development version is in the `dev` branch. Please check the dev branch for the latest changes.
FLUX.1 and SD3/SD3.5 support is done in the `sd3` branch. If you want to train them, please use the sd3 branch.
For easier use (GUI and PowerShell scripts etc...), please visit [the repository maintained by bmaltais](https://github.com/bmaltais/kohya_ss). Thanks to @bmaltais!
This repository contains the scripts for:
Expand Down
2 changes: 1 addition & 1 deletion library/config_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
secondary_separator: {subset.secondary_separator}
enable_wildcard: {subset.enable_wildcard}
caption_dropout_rate: {subset.caption_dropout_rate}
caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs}
caption_dropout_every_n_epochs: {subset.caption_dropout_every_n_epochs}
caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
caption_prefix: {subset.caption_prefix}
caption_suffix: {subset.caption_suffix}
Expand Down
2 changes: 1 addition & 1 deletion library/sd3_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -871,7 +871,7 @@ def enable_scaled_pos_embed(self, use_scaled_pos_embed: bool, latent_sizes: Opti
# remove pos_embed to free up memory up to 0.4 GB
self.pos_embed = None

# remove duplcates and sort latent sizes in ascending order
# remove duplicates and sort latent sizes in ascending order
latent_sizes = list(set(latent_sizes))
latent_sizes = sorted(latent_sizes)

Expand Down
38 changes: 25 additions & 13 deletions library/sd3_train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,12 +253,12 @@ def add_sd3_training_arguments(parser: argparse.ArgumentParser):
" / 複数解像度学習時に解像度ごとに位置埋め込みをスケーリングする。SD3.5M以外では予期しない動作になります",
)

# copy from Diffusers
# Dependencies of Diffusers noise sampler has been removed for clarity.
parser.add_argument(
"--weighting_scheme",
type=str,
default="logit_normal",
choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"],
default="uniform",
choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "uniform"],
help="weighting scheme for timestep distribution and loss / タイムステップ分布と損失のための重み付けスキーム",
)
parser.add_argument(
Expand All @@ -279,6 +279,12 @@ def add_sd3_training_arguments(parser: argparse.ArgumentParser):
default=1.29,
help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`. / モード重み付けスキームのスケール。`'mode'`を`weighting_scheme`として使用する場合のみ有効",
)
parser.add_argument(
"--training_shift",
type=float,
default=1.0,
help="Discrete flow shift for training timestep distribution adjustment, applied in addition to the weighting scheme, default is 1.0. /タイムステップ分布のための離散フローシフト、重み付けスキームの上に適用される、デフォルトは1.0。",
)


def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True):
Expand Down Expand Up @@ -951,9 +957,10 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
return weighting


def get_noisy_model_input_and_timesteps(
args, noise_scheduler, latents, noise, device, dtype
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# endregion


def get_noisy_model_input_and_timesteps(args, latents, noise, device, dtype) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
bsz = latents.shape[0]

# Sample a random timestep for each image
Expand All @@ -965,14 +972,19 @@ def get_noisy_model_input_and_timesteps(
logit_std=args.logit_std,
mode_scale=args.mode_scale,
)
indices = (u * noise_scheduler.config.num_train_timesteps).long()
timesteps = noise_scheduler.timesteps[indices].to(device=device)
t_min = args.min_timestep if args.min_timestep is not None else 0
t_max = args.max_timestep if args.max_timestep is not None else 1000
shift = args.training_shift

# Add noise according to flow matching.
sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)
noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents
# weighting shift, value >1 will shift distribution to noisy side (focus more on overall structure), value <1 will shift towards less-noisy side (focus more on details)
u = (u * shift) / (1 + (shift - 1) * u)

return noisy_model_input, timesteps, sigmas
indices = (u * (t_max - t_min) + t_min).long()
timesteps = indices.to(device=device, dtype=dtype)

# sigmas according to flowmatching
sigmas = timesteps / 1000
sigmas = sigmas.view(-1, 1, 1, 1)
noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents

# endregion
return noisy_model_input, timesteps, sigmas
2 changes: 1 addition & 1 deletion networks/lora_sd3.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,7 +601,7 @@ def merge_to(self, text_encoders, mmdit, weights_sd, dtype=None, device=None):
or key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER_T5)
):
apply_text_encoder = True
elif key.startswith(LoRANetwork.LORA_PREFIX_MMDIT):
elif key.startswith(LoRANetwork.LORA_PREFIX_SD3):
apply_unet = True

if apply_text_encoder:
Expand Down
14 changes: 8 additions & 6 deletions sd3_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -811,8 +811,8 @@ def optimizer_hook(parameter: torch.Tensor):
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
global_step = 0

noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0)
noise_scheduler_copy = copy.deepcopy(noise_scheduler)
# noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0)
# noise_scheduler_copy = copy.deepcopy(noise_scheduler)

if accelerator.is_main_process:
init_kwargs = {}
Expand Down Expand Up @@ -885,7 +885,9 @@ def optimizer_hook(parameter: torch.Tensor):
else:
with torch.no_grad():
# encode images to latents. images are [-1, 1]
latents = vae.encode(batch["images"])
latents = vae.encode(batch["images"].to(vae.device, dtype=vae.dtype)).to(
accelerator.device, dtype=weight_dtype
)

# NaNが含まれていれば警告を表示し0に置き換える
if torch.any(torch.isnan(latents)):
Expand Down Expand Up @@ -927,7 +929,7 @@ def optimizer_hook(parameter: torch.Tensor):
if t5_out is None:
_, _, input_ids_t5xxl, _, _, t5_attn_mask = batch["input_ids_list"]
with torch.set_grad_enabled(train_t5xxl):
input_ids_t5xxl = input_ids_t5xxl.to("cpu") if t5_out is None else None
input_ids_t5xxl = input_ids_t5xxl.to("cpu")
_, t5_out, _, _, _, t5_attn_mask = text_encoding_strategy.encode_tokens(
sd3_tokenize_strategy, [None, None, t5xxl], [None, None, input_ids_t5xxl, None, None, t5_attn_mask]
)
Expand All @@ -938,11 +940,11 @@ def optimizer_hook(parameter: torch.Tensor):

# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# bsz = latents.shape[0]

# get noisy model input and timesteps
noisy_model_input, timesteps, sigmas = sd3_train_utils.get_noisy_model_input_and_timesteps(
args, noise_scheduler_copy, latents, noise, accelerator.device, weight_dtype
args, latents, noise, accelerator.device, weight_dtype
)

# debug: NaN check for all inputs
Expand Down
7 changes: 3 additions & 4 deletions sd3_train_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,9 +275,8 @@ def sample_images(self, accelerator, args, epoch, global_step, device, vae, toke
)

def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
# shift 3.0 is the default value
noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0)
self.noise_scheduler_copy = copy.deepcopy(noise_scheduler)
# this scheduler is not used in training, but used to get num_train_timesteps etc.
noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.training_shift)
return noise_scheduler

def encode_images_to_latents(self, args, accelerator, vae, images):
Expand All @@ -304,7 +303,7 @@ def get_noise_pred_and_target(

# get noisy model input and timesteps
noisy_model_input, timesteps, sigmas = sd3_train_utils.get_noisy_model_input_and_timesteps(
args, self.noise_scheduler_copy, latents, noise, accelerator.device, weight_dtype
args, latents, noise, accelerator.device, weight_dtype
)

# ensure the hidden state will require grad
Expand Down

0 comments on commit b8d3fec

Please sign in to comment.