Commit db2b4d4

Add dropout rate arguments for CLIP-L, CLIP-G, and T5, fix Text Encoders LoRA not trained

kohya-ss committed Oct 27, 2024
1 parent b649bbf commit db2b4d4
Showing 5 changed files with 138 additions and 17 deletions.
18 changes: 18 additions & 0 deletions library/sd3_train_utils.py
@@ -214,6 +214,24 @@ def add_sd3_training_arguments(parser: argparse.ArgumentParser):
action="store_true",
help="apply attention mask (zero embs) to T5-XXL / T5-XXLにアテンションマスク(ゼロ埋め)を適用する",
)
parser.add_argument(
"--clip_l_dropout_rate",
type=float,
default=0.0,
help="Dropout rate for CLIP-L encoder, default is 0.0 / CLIP-Lエンコーダのドロップアウト率、デフォルトは0.0",
)
parser.add_argument(
"--clip_g_dropout_rate",
type=float,
default=0.0,
help="Dropout rate for CLIP-G encoder, default is 0.0 / CLIP-Gエンコーダのドロップアウト率、デフォルトは0.0",
)
parser.add_argument(
"--t5_dropout_rate",
type=float,
default=0.0,
help="Dropout rate for T5 encoder, default is 0.0 / T5エンコーダのドロップアウト率、デフォルトは0.0",
)

# copy from Diffusers
parser.add_argument(
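For context, a minimal sketch (not part of the commit) of how the three new rates reach the text encoding strategy; the keyword names match the Sd3TextEncodingStrategy.__init__ shown in the strategy_sd3.py diff below, while the numeric values and the import path are illustrative:

from library import strategy_sd3

strategy = strategy_sd3.Sd3TextEncodingStrategy(
    apply_lg_attn_mask=False,   # --apply_lg_attn_mask
    apply_t5_attn_mask=False,   # --apply_t5_attn_mask
    l_dropout_rate=0.05,        # --clip_l_dropout_rate
    g_dropout_rate=0.05,        # --clip_g_dropout_rate
    t5_dropout_rate=0.1,        # --t5_dropout_rate
)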
93 changes: 83 additions & 10 deletions library/strategy_sd3.py
@@ -1,5 +1,6 @@
import os
import glob
import random
from typing import Any, List, Optional, Tuple, Union
import torch
import numpy as np
@@ -48,13 +49,23 @@ def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:


class Sd3TextEncodingStrategy(TextEncodingStrategy):
def __init__(self, apply_lg_attn_mask: Optional[bool] = None, apply_t5_attn_mask: Optional[bool] = None) -> None:
def __init__(
self,
apply_lg_attn_mask: Optional[bool] = None,
apply_t5_attn_mask: Optional[bool] = None,
l_dropout_rate: float = 0.0,
g_dropout_rate: float = 0.0,
t5_dropout_rate: float = 0.0,
) -> None:
"""
Args:
apply_t5_attn_mask: Default value for apply_t5_attn_mask.
"""
self.apply_lg_attn_mask = apply_lg_attn_mask
self.apply_t5_attn_mask = apply_t5_attn_mask
self.l_dropout_rate = l_dropout_rate
self.g_dropout_rate = g_dropout_rate
self.t5_dropout_rate = t5_dropout_rate

def encode_tokens(
self,
@@ -63,6 +74,7 @@ def encode_tokens(
tokens: List[torch.Tensor],
apply_lg_attn_mask: Optional[bool] = False,
apply_t5_attn_mask: Optional[bool] = False,
enable_dropout: bool = True,
) -> List[torch.Tensor]:
"""
returned embeddings are not masked
@@ -91,37 +103,92 @@ def encode_tokens(
g_attn_mask = None
t5_attn_mask = None

# dropout: if enable_dropout is False, dropout is not applied. Here, dropout means zeroing out the embeddings (and attention masks)

if l_tokens is None or clip_l is None:
assert g_tokens is None, "g_tokens must be None if l_tokens is None"
lg_out = None
lg_pooled = None
else:
with torch.no_grad():
assert g_tokens is not None, "g_tokens must not be None if l_tokens is not None"
assert g_tokens is not None, "g_tokens must not be None if l_tokens is not None"

drop_l = enable_dropout and (self.l_dropout_rate > 0.0 and random.random() < self.l_dropout_rate)
if drop_l:
l_pooled = torch.zeros((l_tokens.shape[0], 768), device=l_tokens.device, dtype=l_tokens.dtype)
l_out = torch.zeros((l_tokens.shape[0], l_tokens.shape[1], 768), device=l_tokens.device, dtype=l_tokens.dtype)
if l_attn_mask is not None:
l_attn_mask = torch.zeros_like(l_attn_mask)
else:
l_attn_mask = l_attn_mask.to(clip_l.device) if l_attn_mask is not None else None
g_attn_mask = g_attn_mask.to(clip_g.device) if g_attn_mask is not None else None

prompt_embeds = clip_l(l_tokens.to(clip_l.device), l_attn_mask, output_hidden_states=True)
l_pooled = prompt_embeds[0]
l_out = prompt_embeds.hidden_states[-2]

drop_g = enable_dropout and (self.g_dropout_rate > 0.0 and random.random() < self.g_dropout_rate)
if drop_g:
g_pooled = torch.zeros((g_tokens.shape[0], 1280), device=g_tokens.device, dtype=g_tokens.dtype)
g_out = torch.zeros((g_tokens.shape[0], g_tokens.shape[1], 1280), device=g_tokens.device, dtype=g_tokens.dtype)
if g_attn_mask is not None:
g_attn_mask = torch.zeros_like(g_attn_mask)
else:
g_attn_mask = g_attn_mask.to(clip_g.device) if g_attn_mask is not None else None
prompt_embeds = clip_g(g_tokens.to(clip_g.device), g_attn_mask, output_hidden_states=True)
g_pooled = prompt_embeds[0]
g_out = prompt_embeds.hidden_states[-2]

lg_pooled = torch.cat((l_pooled, g_pooled), dim=-1) if l_tokens is not None else None
lg_out = torch.cat([l_out, g_out], dim=-1)
lg_pooled = torch.cat((l_pooled, g_pooled), dim=-1) if l_tokens is not None else None
lg_out = torch.cat([l_out, g_out], dim=-1)

if t5xxl is None or t5_tokens is None:
t5_out = None
else:
t5_attn_mask = t5_attn_mask.to(t5xxl.device) if t5_attn_mask is not None else None
with torch.no_grad():
drop_t5 = enable_dropout and (self.t5_dropout_rate > 0.0 and random.random() < self.t5_dropout_rate)
if drop_t5:
t5_out = torch.zeros((t5_tokens.shape[0], t5_tokens.shape[1], 4096), device=t5_tokens.device, dtype=t5_tokens.dtype)
if t5_attn_mask is not None:
t5_attn_mask = torch.zeros_like(t5_attn_mask)
else:
t5_attn_mask = t5_attn_mask.to(t5xxl.device) if t5_attn_mask is not None else None
t5_out, _ = t5xxl(t5_tokens.to(t5xxl.device), t5_attn_mask, return_dict=False, output_hidden_states=True)

# masks are used for attention masking in transformer
return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask]
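The dropout added here is conditioning dropout, not element-wise nn.Dropout: with probability equal to the configured rate, the encoder is skipped entirely and its output (and attention mask) is replaced by zeros for the whole batch, since random.random() is called once per encode_tokens call. A stand-alone sketch of the pattern, with maybe_zero a hypothetical helper:

import random
import torch

def maybe_zero(emb: torch.Tensor, rate: float) -> torch.Tensor:
    # with probability `rate`, replace the whole tensor with zeros; otherwise pass it through
    if rate > 0.0 and random.random() < rate:
        return torch.zeros_like(emb)
    return emb

# e.g. drop a (batch, seq_len, dim) CLIP-L embedding 5% of the time
emb = maybe_zero(torch.randn(4, 77, 768), rate=0.05)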

def drop_cached_text_encoder_outputs(
self,
lg_out: torch.Tensor,
t5_out: torch.Tensor,
lg_pooled: torch.Tensor,
l_attn_mask: torch.Tensor,
g_attn_mask: torch.Tensor,
t5_attn_mask: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
# dropout for cached outputs: zero out each sample's embeddings (and attention masks) with the configured rates
if lg_out is not None:
for i in range(lg_out.shape[0]):
drop_l = self.l_dropout_rate > 0.0 and random.random() < self.l_dropout_rate
if drop_l:
lg_out[i, :, :768] = torch.zeros_like(lg_out[i, :, :768])
lg_pooled[i, :768] = torch.zeros_like(lg_pooled[i, :768])
if l_attn_mask is not None:
l_attn_mask[i] = torch.zeros_like(l_attn_mask[i])
drop_g = self.g_dropout_rate > 0.0 and random.random() < self.g_dropout_rate
if drop_g:
lg_out[i, :, 768:] = torch.zeros_like(lg_out[i, :, 768:])
lg_pooled[i, 768:] = torch.zeros_like(lg_pooled[i, 768:])
if g_attn_mask is not None:
g_attn_mask[i] = torch.zeros_like(g_attn_mask[i])

if t5_out is not None:
for i in range(t5_out.shape[0]):
drop_t5 = self.t5_dropout_rate > 0.0 and random.random() < self.t5_dropout_rate
if drop_t5:
t5_out[i] = torch.zeros_like(t5_out[i])
if t5_attn_mask is not None:
t5_attn_mask[i] = torch.zeros_like(t5_attn_mask[i])

return lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask
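Unlike encode_tokens, which makes one drop decision for the whole batch, drop_cached_text_encoder_outputs flips a coin per sample (the loop over i), so items in a cached batch keep or lose their conditioning independently. A hedged usage sketch; `strategy` and `batch` stand for the current Sd3TextEncodingStrategy and a training batch as used in the scripts below:

outputs = batch.get("text_encoder_outputs_list", None)
if outputs is not None:
    # outputs is [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask]
    batch["text_encoder_outputs_list"] = strategy.drop_cached_text_encoder_outputs(*outputs)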

def concat_encodings(
self, lg_out: torch.Tensor, t5_out: Optional[torch.Tensor], lg_pooled: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -207,8 +274,14 @@ def cache_batch_outputs(

tokens_and_masks = tokenize_strategy.tokenize(captions)
with torch.no_grad():
# always disable dropout during caching
lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask = sd3_text_encoding_strategy.encode_tokens(
tokenize_strategy, models, tokens_and_masks, self.apply_lg_attn_mask, self.apply_t5_attn_mask
tokenize_strategy,
models,
tokens_and_masks,
apply_lg_attn_mask=self.apply_lg_attn_mask,
apply_t5_attn_mask=self.apply_t5_attn_mask,
enable_dropout=False,
)

if lg_out.dtype == torch.bfloat16:
15 changes: 12 additions & 3 deletions sd3_train.py
@@ -69,6 +69,11 @@ def train(args):
# assert (
# not args.train_text_encoder or not args.cache_text_encoder_outputs
# ), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません"
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
logger.warning(
"cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります"
)
args.cache_text_encoder_outputs = True

assert not args.train_text_encoder or (args.use_t5xxl_cache_only or not args.cache_text_encoder_outputs), (
"when training text encoder, text encoder outputs must not be cached (except for T5XXL)"
@@ -232,7 +237,9 @@ def train(args):
assert clip_l is not None and clip_g is not None and t5xxl is not None, "clip_l, clip_g, t5xxl must be specified"

# prepare text encoding strategy
text_encoding_strategy = strategy_sd3.Sd3TextEncodingStrategy(args.apply_lg_attn_mask, args.apply_t5_attn_mask)
text_encoding_strategy = strategy_sd3.Sd3TextEncodingStrategy(
args.apply_lg_attn_mask, args.apply_t5_attn_mask, args.clip_l_dropout_rate, args.clip_g_dropout_rate, args.t5_dropout_rate
)
strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)

# 学習を準備する:モデルを適切な状態にする
@@ -311,6 +318,7 @@ def train(args):
tokens_and_masks,
args.apply_lg_attn_mask,
args.apply_t5_attn_mask,
enable_dropout=False,
)

accelerator.wait_for_everyone()
@@ -863,6 +871,7 @@ def optimizer_hook(parameter: torch.Tensor):

text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
if text_encoder_outputs_list is not None:
text_encoder_outputs_list = text_encoding_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list)
lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask = text_encoder_outputs_list
if args.use_t5xxl_cache_only:
lg_out = None
@@ -878,7 +887,7 @@ def optimizer_hook(parameter: torch.Tensor):
if lg_out is None:
# not cached or training, so get from text encoders
input_ids_clip_l, input_ids_clip_g, _, l_attn_mask, g_attn_mask, _ = batch["input_ids_list"]
with torch.set_grad_enabled(args.train_text_encoder):
with torch.set_grad_enabled(train_clip):
# TODO support weighted captions
# text models in sd3_models require "cpu" for input_ids
input_ids_clip_l = input_ids_clip_l.to("cpu")
@@ -891,7 +900,7 @@ def optimizer_hook(parameter: torch.Tensor):

if t5_out is None:
_, _, input_ids_t5xxl, _, _, t5_attn_mask = batch["input_ids_list"]
with torch.no_grad():
with torch.set_grad_enabled(train_t5xxl):
input_ids_t5xxl = input_ids_t5xxl.to("cpu") if t5_out is None else None
_, t5_out, _, _, _, t5_attn_mask = text_encoding_strategy.encode_tokens(
sd3_tokenize_strategy, [None, None, t5xxl], [None, None, input_ids_t5xxl, None, None, t5_attn_mask]
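The "Text Encoders LoRA not trained" part of the fix is the two gradient-context changes above: the CLIP branch now keys off train_clip instead of the blanket args.train_text_encoder, and the T5 branch replaces torch.no_grad() with torch.set_grad_enabled(train_t5xxl). A small stand-alone reminder of the semantics (illustrative, not from the repository):

import torch

x = torch.randn(2, 4, requires_grad=True)

with torch.set_grad_enabled(False):  # behaves like torch.no_grad()
    y = (x * 2.0).sum()
print(y.requires_grad)  # False -> no gradients reach the encoder (or its LoRA weights)

with torch.set_grad_enabled(True):
    y = (x * 2.0).sum()
print(y.requires_grad)  # True -> gradients can flow when the encoder is being trained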
16 changes: 15 additions & 1 deletion sd3_train_network.py
@@ -120,7 +120,13 @@ def get_latents_caching_strategy(self, args):
return latents_caching_strategy

def get_text_encoding_strategy(self, args):
return strategy_sd3.Sd3TextEncodingStrategy(args.apply_lg_attn_mask, args.apply_t5_attn_mask)
return strategy_sd3.Sd3TextEncodingStrategy(
args.apply_lg_attn_mask,
args.apply_t5_attn_mask,
args.clip_l_dropout_rate,
args.clip_g_dropout_rate,
args.t5_dropout_rate,
)

def post_process_network(self, args, accelerator, network, text_encoders, unet):
# check t5xxl is trained or not
@@ -408,6 +414,14 @@ def forward(hidden_states):
text_encoder.to(te_weight_dtype) # fp8
prepare_fp8(text_encoder, weight_dtype)

def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
# drop cached text encoder outputs
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
if text_encoder_outputs_list is not None:
text_encoding_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
text_encoder_outputs_list = text_encoding_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list)
batch["text_encoder_outputs_list"] = text_encoder_outputs_list
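This override plugs into the no-op on_step_start hook added to train_network.py (the last file in this commit): the base training loop calls self.on_step_start(...) once per batch, and the SD3 subclass uses that moment to apply per-sample conditioning dropout to the cached outputs (the real code fetches the strategy via strategy_base.TextEncodingStrategy.get_strategy()). A condensed, illustrative sketch of the pattern, with class names chosen for illustration:

class NetworkTrainer:
    def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
        pass  # base class: nothing to do per step

class Sd3NetworkTrainer(NetworkTrainer):
    def __init__(self, strategy):
        self.strategy = strategy  # e.g. the Sd3TextEncodingStrategy configured earlier

    def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
        # apply per-sample conditioning dropout to the cached text encoder outputs
        outputs = batch.get("text_encoder_outputs_list", None)
        if outputs is not None:
            batch["text_encoder_outputs_list"] = self.strategy.drop_cached_text_encoder_outputs(*outputs)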


def setup_parser() -> argparse.ArgumentParser:
parser = train_network.setup_parser()
13 changes: 10 additions & 3 deletions train_network.py
@@ -272,6 +272,9 @@ def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
text_encoder.text_model.embeddings.to(dtype=weight_dtype)

def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
pass

# endregion

def train(self, args):
@@ -1030,9 +1033,9 @@ def load_model_hook(models, input_dir):

# callback for step start
if hasattr(accelerator.unwrap_model(network), "on_step_start"):
on_step_start = accelerator.unwrap_model(network).on_step_start
on_step_start_for_network = accelerator.unwrap_model(network).on_step_start
else:
on_step_start = lambda *args, **kwargs: None
on_step_start_for_network = lambda *args, **kwargs: None

# function for saving/removing
def save_model(ckpt_name, unwrapped_nw, steps, epoch_no, force_sync_upload=False):
@@ -1113,7 +1116,10 @@ def remove_model(old_ckpt_name):
continue

with accelerator.accumulate(training_model):
on_step_start(text_encoder, unet)
on_step_start_for_network(text_encoder, unet)

# temporary, for batch processing
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)

if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
@@ -1146,6 +1152,7 @@ def remove_model(old_ckpt_name):
if text_encoder_outputs_list is not None:
text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs
if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder:
# TODO this does not work if 'some text_encoders are trained' and 'some are not and not cached'
with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
# Get the text embedding for conditioning
if args.weighted_captions:
