
Commit

Fix dropout_rate to work for TEs (text encoders)
kohya-ss committed Oct 27, 2024
1 parent d4f7849 commit 1065dd1
Showing 5 changed files with 108 additions and 71 deletions.
2 changes: 1 addition & 1 deletion flux_train_network.py
@@ -363,7 +363,7 @@ def get_noise_pred_and_target(
         if args.gradient_checkpointing:
             noisy_model_input.requires_grad_(True)
             for t in text_encoder_conds:
-                if t.dtype.is_floating_point:
+                if t is not None and t.dtype.is_floating_point:
                     t.requires_grad_(True)
             img_ids.requires_grad_(True)
             guidance_vec.requires_grad_(True)
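The added None check matters because, once text encoder outputs can be cached or partially supplied, some entries of the conditioning list may be absent, and integer tensors (token ids, masks) cannot take gradients. A minimal standalone sketch of the pattern; the helper name and sample tensors below are illustrative, not from this repository:

    import torch

    def enable_grads_for_checkpointing(text_encoder_conds):
        # hypothetical helper: the list may mix None entries, float tensors
        # (embeddings) and integer tensors (ids / attention masks); only real
        # floating-point tensors get requires_grad for gradient checkpointing
        for t in text_encoder_conds:
            if t is not None and t.dtype.is_floating_point:
                t.requires_grad_(True)

    conds = [torch.randn(2, 77, 768), None, torch.ones(2, 77, dtype=torch.long)]
    enable_grads_for_checkpointing(conds)
    print([None if t is None else t.requires_grad for t in conds])  # [True, None, False]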
1 change: 1 addition & 0 deletions library/strategy_flux.py
@@ -190,6 +190,7 @@ def cache_batch_outputs(
                     apply_t5_attn_mask=apply_t5_attn_mask_i,
                 )
             else:
+                # it's fine that attn mask is not None. it's overwritten before calling the model if necessary
                 info.text_encoder_outputs = (l_pooled_i, t5_out_i, txt_ids_i, t5_attn_mask_i)


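The new comment documents why the cached tuple may keep a non-None T5 attention mask: whether the mask is actually used is decided when the cached outputs are consumed, not when they are stored. An illustrative sketch of that split, with made-up container and names, not the repository's code:

    import torch

    def cache_outputs(info, l_pooled, t5_out, txt_ids, t5_attn_mask):
        # keeping the mask in the cache is harmless; it can be replaced with
        # None right before the model call when masking is disabled
        info["text_encoder_outputs"] = (l_pooled, t5_out, txt_ids, t5_attn_mask)

    def conds_for_model(info, apply_t5_attn_mask: bool):
        l_pooled, t5_out, txt_ids, t5_attn_mask = info["text_encoder_outputs"]
        return l_pooled, t5_out, txt_ids, (t5_attn_mask if apply_t5_attn_mask else None)

    info = {}
    cache_outputs(info, torch.randn(1, 768), torch.randn(1, 512, 4096),
                  torch.zeros(1, 512, 3), torch.ones(1, 512, dtype=torch.long))
    print(conds_for_model(info, apply_t5_attn_mask=False)[-1])  # None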
142 changes: 99 additions & 43 deletions library/strategy_sd3.py
@@ -89,67 +89,122 @@ def encode_tokens(
         if apply_t5_attn_mask is None:
             apply_t5_attn_mask = self.apply_t5_attn_mask
 
-        l_tokens, g_tokens, t5_tokens = tokens[:3]
-
-        if len(tokens) > 3:
-            l_attn_mask, g_attn_mask, t5_attn_mask = tokens[3:]
-            if not apply_lg_attn_mask:
-                l_attn_mask = None
-                g_attn_mask = None
-            if not apply_t5_attn_mask:
-                t5_attn_mask = None
-        else:
-            l_attn_mask = None
-            g_attn_mask = None
-            t5_attn_mask = None
+        l_tokens, g_tokens, t5_tokens, l_attn_mask, g_attn_mask, t5_attn_mask = tokens
 
         # dropout: if enable_dropout is False, dropout is not applied. dropout means zeroing out embeddings
 
         if l_tokens is None or clip_l is None:
             assert g_tokens is None, "g_tokens must be None if l_tokens is None"
             lg_out = None
             lg_pooled = None
+            l_attn_mask = None
+            g_attn_mask = None
         else:
             assert g_tokens is not None, "g_tokens must not be None if l_tokens is not None"
 
-            drop_l = enable_dropout and (self.l_dropout_rate > 0.0 and random.random() < self.l_dropout_rate)
-            if drop_l:
-                l_pooled = torch.zeros((l_tokens.shape[0], 768), device=clip_l.device, dtype=clip_l.dtype)
-                l_out = torch.zeros((l_tokens.shape[0], l_tokens.shape[1], 768), device=clip_l.device, dtype=clip_l.dtype)
-                if l_attn_mask is not None:
-                    l_attn_mask = torch.zeros_like(l_attn_mask, device=clip_l.device)
+            # drop some members of the batch: we do not call clip_l and clip_g for dropped members
+            batch_size, l_seq_len = l_tokens.shape
+            g_seq_len = g_tokens.shape[1]
+
+            non_drop_l_indices = []
+            non_drop_g_indices = []
+            for i in range(l_tokens.shape[0]):
+                drop_l = enable_dropout and (self.l_dropout_rate > 0.0 and random.random() < self.l_dropout_rate)
+                drop_g = enable_dropout and (self.g_dropout_rate > 0.0 and random.random() < self.g_dropout_rate)
+                if not drop_l:
+                    non_drop_l_indices.append(i)
+                if not drop_g:
+                    non_drop_g_indices.append(i)
+
+            # filter out dropped members
+            if len(non_drop_l_indices) > 0 and len(non_drop_l_indices) < batch_size:
+                l_tokens = l_tokens[non_drop_l_indices]
+                l_attn_mask = l_attn_mask[non_drop_l_indices]
+            if len(non_drop_g_indices) > 0 and len(non_drop_g_indices) < batch_size:
+                g_tokens = g_tokens[non_drop_g_indices]
+                g_attn_mask = g_attn_mask[non_drop_g_indices]
+
+            # call clip_l for non-dropped members
+            if len(non_drop_l_indices) > 0:
+                nd_l_attn_mask = l_attn_mask.to(clip_l.device)
+                prompt_embeds = clip_l(
+                    l_tokens.to(clip_l.device), nd_l_attn_mask if apply_lg_attn_mask else None, output_hidden_states=True
+                )
+                nd_l_pooled = prompt_embeds[0]
+                nd_l_out = prompt_embeds.hidden_states[-2]
+            if len(non_drop_g_indices) > 0:
+                nd_g_attn_mask = g_attn_mask.to(clip_g.device)
+                prompt_embeds = clip_g(
+                    g_tokens.to(clip_g.device), nd_g_attn_mask if apply_lg_attn_mask else None, output_hidden_states=True
+                )
+                nd_g_pooled = prompt_embeds[0]
+                nd_g_out = prompt_embeds.hidden_states[-2]
+
+            # fill in the dropped members
+            if len(non_drop_l_indices) == batch_size:
+                l_pooled = nd_l_pooled
+                l_out = nd_l_out
             else:
-                l_attn_mask = l_attn_mask.to(clip_l.device) if l_attn_mask is not None else None
-                prompt_embeds = clip_l(l_tokens.to(clip_l.device), l_attn_mask, output_hidden_states=True)
-                l_pooled = prompt_embeds[0]
-                l_out = prompt_embeds.hidden_states[-2]
-
-            drop_g = enable_dropout and (self.g_dropout_rate > 0.0 and random.random() < self.g_dropout_rate)
-            if drop_g:
-                g_pooled = torch.zeros((g_tokens.shape[0], 1280), device=clip_g.device, dtype=clip_g.dtype)
-                g_out = torch.zeros((g_tokens.shape[0], g_tokens.shape[1], 1280), device=clip_g.device, dtype=clip_g.dtype)
-                if g_attn_mask is not None:
-                    g_attn_mask = torch.zeros_like(g_attn_mask, device=clip_g.device)
+                # model output is always float32 because of the models are wrapped with Accelerator
+                l_pooled = torch.zeros((batch_size, 768), device=clip_l.device, dtype=torch.float32)
+                l_out = torch.zeros((batch_size, l_seq_len, 768), device=clip_l.device, dtype=torch.float32)
+                l_attn_mask = torch.zeros((batch_size, l_seq_len), device=clip_l.device, dtype=l_attn_mask.dtype)
+                if len(non_drop_l_indices) > 0:
+                    l_pooled[non_drop_l_indices] = nd_l_pooled
+                    l_out[non_drop_l_indices] = nd_l_out
+                    l_attn_mask[non_drop_l_indices] = nd_l_attn_mask
+
+            if len(non_drop_g_indices) == batch_size:
+                g_pooled = nd_g_pooled
+                g_out = nd_g_out
             else:
-                g_attn_mask = g_attn_mask.to(clip_g.device) if g_attn_mask is not None else None
-                prompt_embeds = clip_g(g_tokens.to(clip_g.device), g_attn_mask, output_hidden_states=True)
-                g_pooled = prompt_embeds[0]
-                g_out = prompt_embeds.hidden_states[-2]
-
-            lg_pooled = torch.cat((l_pooled, g_pooled), dim=-1) if l_tokens is not None else None
+                g_pooled = torch.zeros((batch_size, 1280), device=clip_g.device, dtype=torch.float32)
+                g_out = torch.zeros((batch_size, g_seq_len, 1280), device=clip_g.device, dtype=torch.float32)
+                g_attn_mask = torch.zeros((batch_size, g_seq_len), device=clip_g.device, dtype=g_attn_mask.dtype)
+                if len(non_drop_g_indices) > 0:
+                    g_pooled[non_drop_g_indices] = nd_g_pooled
+                    g_out[non_drop_g_indices] = nd_g_out
+                    g_attn_mask[non_drop_g_indices] = nd_g_attn_mask
+
+            lg_pooled = torch.cat((l_pooled, g_pooled), dim=-1)
             lg_out = torch.cat([l_out, g_out], dim=-1)
 
         if t5xxl is None or t5_tokens is None:
             t5_out = None
             t5_attn_mask = None
         else:
-            drop_t5 = enable_dropout and (self.t5_dropout_rate > 0.0 and random.random() < self.t5_dropout_rate)
-            if drop_t5:
-                t5_out = torch.zeros((t5_tokens.shape[0], t5_tokens.shape[1], 4096), device=t5xxl.device, dtype=t5xxl.dtype)
-                if t5_attn_mask is not None:
-                    t5_attn_mask = torch.zeros_like(t5_attn_mask, device=t5xxl.device)
+            # drop some members of the batch: we do not call t5xxl for dropped members
+            batch_size, t5_seq_len = t5_tokens.shape
+            non_drop_t5_indices = []
+            for i in range(t5_tokens.shape[0]):
+                drop_t5 = enable_dropout and (self.t5_dropout_rate > 0.0 and random.random() < self.t5_dropout_rate)
+                if not drop_t5:
+                    non_drop_t5_indices.append(i)
+
+            # filter out dropped members
+            if len(non_drop_t5_indices) > 0 and len(non_drop_t5_indices) < batch_size:
+                t5_tokens = t5_tokens[non_drop_t5_indices]
+                t5_attn_mask = t5_attn_mask[non_drop_t5_indices]
+
+            # call t5xxl for non-dropped members
+            if len(non_drop_t5_indices) > 0:
+                nd_t5_attn_mask = t5_attn_mask.to(t5xxl.device)
+                nd_t5_out, _ = t5xxl(
+                    t5_tokens.to(t5xxl.device),
+                    nd_t5_attn_mask if apply_t5_attn_mask else None,
+                    return_dict=False,
+                    output_hidden_states=True,
+                )
+
+            # fill in the dropped members
+            if len(non_drop_t5_indices) == batch_size:
+                t5_out = nd_t5_out
             else:
-                t5_attn_mask = t5_attn_mask.to(t5xxl.device) if t5_attn_mask is not None else None
-                t5_out, _ = t5xxl(t5_tokens.to(t5xxl.device), t5_attn_mask, return_dict=False, output_hidden_states=True)
+                t5_out = torch.zeros((batch_size, t5_seq_len, 4096), device=t5xxl.device, dtype=torch.float32)
+                t5_attn_mask = torch.zeros((batch_size, t5_seq_len), device=t5xxl.device, dtype=t5_attn_mask.dtype)
+                if len(non_drop_t5_indices) > 0:
+                    t5_out[non_drop_t5_indices] = nd_t5_out
+                    t5_attn_mask[non_drop_t5_indices] = nd_t5_attn_mask
 
+        # masks are used for attention masking in transformer
         return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask]
@@ -322,6 +377,7 @@ def cache_batch_outputs(
                     apply_t5_attn_mask=apply_t5_attn_mask,
                 )
             else:
+                # it's fine that attn mask is not None. it's overwritten before calling the model if necessary
                 info.text_encoder_outputs = (lg_out_i, t5_out_i, lg_pooled_i, l_attn_mask_i, g_attn_mask_i, t5_attn_mask_i)


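The encode_tokens change above replaces whole-batch dropout with per-sample dropout: members that survive are actually encoded, and the results are scattered into zero-filled tensors (with zeroed attention masks) for the dropped members, so batch shapes stay constant. A simplified standalone sketch of that scheme with a single toy encoder standing in for CLIP/T5; names and dimensions are illustrative only:

    import random
    import torch

    def encode_with_dropout(tokens, attn_mask, encoder, hidden_dim, dropout_rate, enable_dropout=True):
        batch_size, seq_len = tokens.shape
        # decide per batch member whether to drop its embedding
        non_drop = [i for i in range(batch_size)
                    if not (enable_dropout and dropout_rate > 0.0 and random.random() < dropout_rate)]

        if len(non_drop) == batch_size:
            return encoder(tokens, attn_mask), attn_mask

        # zero-filled outputs for dropped members, real outputs scattered back in
        out = torch.zeros((batch_size, seq_len, hidden_dim), dtype=torch.float32)
        mask = torch.zeros_like(attn_mask)
        if len(non_drop) > 0:
            out[non_drop] = encoder(tokens[non_drop], attn_mask[non_drop]).to(torch.float32)
            mask[non_drop] = attn_mask[non_drop]
        return out, mask

    # toy "encoder"; the real code calls the CLIP / T5 models instead
    fake_encoder = lambda t, m: torch.randn(t.shape[0], t.shape[1], 8)
    out, mask = encode_with_dropout(torch.ones(4, 6, dtype=torch.long),
                                    torch.ones(4, 6, dtype=torch.long),
                                    fake_encoder, hidden_dim=8, dropout_rate=0.5)
    print(out.shape, mask.sum(dim=1))  # zero rows correspond to dropped members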
15 changes: 7 additions & 8 deletions sd3_train_network.py
@@ -300,7 +300,7 @@ def get_noise_pred_and_target(
         if args.gradient_checkpointing:
             noisy_model_input.requires_grad_(True)
             for t in text_encoder_conds:
-                if t.dtype.is_floating_point:
+                if t is not None and t.dtype.is_floating_point:
                     t.requires_grad_(True)
 
             # Predict the noise residual
@@ -415,13 +415,12 @@ def forward(hidden_states):
             prepare_fp8(text_encoder, weight_dtype)
 
     def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
-        # # drop cached text encoder outputs
-        # text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
-        # if text_encoder_outputs_list is not None:
-        #     text_encodoing_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
-        #     text_encoder_outputs_list = text_encodoing_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list)
-        #     batch["text_encoder_outputs_list"] = text_encoder_outputs_list
-        pass
+        # drop cached text encoder outputs
+        text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
+        if text_encoder_outputs_list is not None:
+            text_encodoing_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
+            text_encoder_outputs_list = text_encodoing_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list)
+            batch["text_encoder_outputs_list"] = text_encoder_outputs_list
 
 
 def setup_parser() -> argparse.ArgumentParser:
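With cached outputs, dropout can no longer happen inside encode_tokens, so on_step_start now applies it to the cached tensors each training step. The sketch below only illustrates the idea; the helper name, signature, and list layout are assumptions, and the repository's drop_cached_text_encoder_outputs may differ:

    import random
    import torch

    def drop_cached_outputs(outputs, dropout_rate):
        # 'outputs' is an assumed list of per-batch tensors (or None); each
        # batch member is zeroed independently with probability dropout_rate
        if dropout_rate <= 0.0:
            return outputs
        batch_size = next(o.shape[0] for o in outputs if o is not None)
        drop = [random.random() < dropout_rate for _ in range(batch_size)]
        for o in outputs:
            if o is not None:
                for i, dropped in enumerate(drop):
                    if dropped:
                        o[i].zero_()
        return outputs

    cached = [torch.randn(4, 77, 768), torch.randn(4, 768), None]
    cached = drop_cached_outputs(cached, dropout_rate=0.5)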
19 changes: 0 additions & 19 deletions train_network.py
@@ -1151,16 +1151,6 @@ def remove_model(old_ckpt_name):
                    text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
                    if text_encoder_outputs_list is not None:
                        text_encoder_conds = text_encoder_outputs_list  # List of text encoder outputs
 
-                    # if text_encoder_outputs_list is not None:
-                    #     lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask = text_encoder_outputs_list
-                    #     for i in range(len(lg_out)):
-                    #         print(
-                    #             f"[{i}] cached L: {lg_out[i,:,:768].max()}, {lg_pooled[i][:768].max()}, cached G: {lg_out[i,:,768:].max()}, {lg_pooled[i][768:].max()}, "
-                    #             f"cached T5: {t5_out[i].max()}, "
-                    #             f"attn mask: {l_attn_mask[i].max() if l_attn_mask is not None else 0},"
-                    #             f" {g_attn_mask[i].max() if g_attn_mask is not None else 0}, {t5_attn_mask[i].max() if t5_attn_mask is not None else 0}"
-                    #         )
-
                    if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder:
                        # TODO this does not work if 'some text_encoders are trained' and 'some are not and not cached'
@@ -1193,15 +1183,6 @@ def remove_model(old_ckpt_name):
                            if encoded_text_encoder_conds[i] is not None:
                                text_encoder_conds[i] = encoded_text_encoder_conds[i]
 
-                    # lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask = text_encoder_conds
-                    # for i in range(len(lg_out)):
-                    #     print(
-                    #         f"[{i}] train L: {lg_out[i,:,:768].max()}, {lg_pooled[i][:768].max()}, train G: {lg_out[i,:,768:].max()}, {lg_pooled[i][768:].max()}, "
-                    #         f"train T5: {t5_out[i].max()}, "
-                    #         f"attn mask: {l_attn_mask[i].max() if l_attn_mask is not None else 0},"
-                    #         f" {g_attn_mask[i].max() if g_attn_mask is not None else 0}, {t5_attn_mask[i].max() if t5_attn_mask is not None else 0}"
-                    #     )
-
                    # sample noise, call unet, get target
                    noise_pred, target, timesteps, huber_c, weighting = self.get_noise_pred_and_target(
                        args,
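Besides deleting the debug prints, train_network.py keeps the cached conditioning list and only overwrites the entries that were re-encoded in the current step (for example when some text encoders are trained). A hedged sketch of that merge with made-up shapes, not the repository's exact code:

    import torch

    def merge_conds(cached_conds, freshly_encoded):
        # start from the cached list and replace only the re-encoded entries
        merged = list(cached_conds)
        for i, cond in enumerate(freshly_encoded):
            if cond is not None:
                merged[i] = cond
        return merged

    cached = [torch.zeros(2, 77, 2048), torch.zeros(2, 512, 4096), torch.zeros(2, 2048)]
    fresh = [torch.randn(2, 77, 2048), None, None]  # only the CLIP output was re-encoded
    conds = merge_conds(cached, fresh)
    print([c.abs().sum().item() > 0 for c in conds])  # [True, False, False]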
