
Commit

support separate learning rates for TE1/2
kohya-ss committed Oct 29, 2023
1 parent a9ed4ed commit 01d929e
Showing 1 changed file with 68 additions and 30 deletions.
98 changes: 68 additions & 30 deletions sdxl_train.py
@@ -10,10 +10,13 @@

from tqdm import tqdm
import torch

try:
import intel_extension_for_pytorch as ipex

if torch.xpu.is_available():
from library.ipex import ipex_init

ipex_init()
except Exception:
pass
@@ -272,20 +275,32 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
accelerator.wait_for_everyone()

# prepare for training: put the models into the appropriate state
training_models = []
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
training_models.append(unet)
train_unet = args.learning_rate > 0
train_text_encoder1 = False
train_text_encoder2 = False

if args.train_text_encoder:
# TODO each option for two text encoders?
accelerator.print("enable text encoder training")
if args.gradient_checkpointing:
text_encoder1.gradient_checkpointing_enable()
text_encoder2.gradient_checkpointing_enable()
training_models.append(text_encoder1)
training_models.append(text_encoder2)
# set requires_grad=True later
lr_te1 = args.learning_rate_te1 if args.learning_rate_te1 is not None else args.learning_rate # 0 means not train
lr_te2 = args.learning_rate_te2 if args.learning_rate_te2 is not None else args.learning_rate # 0 means not train
train_text_encoder1 = lr_te1 > 0
train_text_encoder2 = lr_te2 > 0

# caching one text encoder output is not supported
if not train_text_encoder1:
text_encoder1.to(weight_dtype)
if not train_text_encoder2:
text_encoder2.to(weight_dtype)
text_encoder1.requires_grad_(train_text_encoder1)
text_encoder2.requires_grad_(train_text_encoder2)
text_encoder1.train(train_text_encoder1)
text_encoder2.train(train_text_encoder2)
else:
text_encoder1.to(weight_dtype)
text_encoder2.to(weight_dtype)
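
For reference, the resolution rule introduced above can be shown in isolation: `None` means the per-encoder option was not given and falls back to `--learning_rate`, while `0` freezes that text encoder. This is a minimal sketch, not code from the commit; the function name and values are hypothetical.

```python
def resolve_te_lr(lr_te, base_lr):
    # None means "not specified": fall back to the base learning rate.
    # 0 (or any non-positive value) disables training for that text encoder.
    lr = lr_te if lr_te is not None else base_lr
    return lr, lr > 0

# hypothetical values: --learning_rate 1e-5, --learning_rate_te1 omitted, --learning_rate_te2 0
print(resolve_te_lr(None, 1e-5))  # (1e-05, True)  -> falls back to the base LR and is trained
print(resolve_te_lr(0.0, 1e-5))   # (0.0, False)   -> frozen
```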
@@ -313,28 +328,33 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
vae.eval()
vae.to(accelerator.device, dtype=vae_dtype)

for m in training_models:
m.requires_grad_(True)
unet.requires_grad_(train_unet)
if not train_unet:
unet.to(accelerator.device, dtype=weight_dtype)  # because unet is not prepared

if block_lrs is None:
params_to_optimize = [
{"params": list(training_models[0].parameters()), "lr": args.learning_rate},
]
else:
params_to_optimize = get_block_params_to_optimize(training_models[0], block_lrs) # U-Net
training_models = []
params_to_optimize = []
if train_unet:
training_models.append(unet)
if block_lrs is None:
params_to_optimize.append({"params": list(unet.parameters()), "lr": args.learning_rate})
else:
params_to_optimize.extend(get_block_params_to_optimize(unet, block_lrs))

for m in training_models[1:]: # Text Encoders if exists
params_to_optimize.append({
"params": list(m.parameters()),
"lr": args.learning_rate_te or args.learning_rate
})
if train_text_encoder1:
training_models.append(text_encoder1)
params_to_optimize.append({"params": list(text_encoder1.parameters()), "lr": args.learning_rate_te1 or args.learning_rate})
if train_text_encoder2:
training_models.append(text_encoder2)
params_to_optimize.append({"params": list(text_encoder2.parameters()), "lr": args.learning_rate_te2 or args.learning_rate})

# calculate number of trainable parameters
n_params = 0
for params in params_to_optimize:
for p in params["params"]:
n_params += p.numel()

accelerator.print(f"train unet: {train_unet}, text_encoder1: {train_text_encoder1}, text_encoder2: {train_text_encoder2}")
accelerator.print(f"number of models: {len(training_models)}")
accelerator.print(f"number of trainable parameters: {n_params}")

@@ -386,16 +406,17 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
text_encoder2.to(weight_dtype)

# accelerator apparently takes care of this for us
if args.train_text_encoder:
unet, text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler
)

# transform DDP after prepare
text_encoder1, text_encoder2, unet = train_util.transform_models_if_DDP([text_encoder1, text_encoder2, unet])
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
if train_unet:
unet = accelerator.prepare(unet)
(unet,) = train_util.transform_models_if_DDP([unet])
if train_text_encoder1:
text_encoder1 = accelerator.prepare(text_encoder1)
(text_encoder1,) = train_util.transform_models_if_DDP([text_encoder1])
if train_text_encoder2:
text_encoder2 = accelerator.prepare(text_encoder2)
(text_encoder2,) = train_util.transform_models_if_DDP([text_encoder2])

optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
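
As background on the restructured preparation above: `accelerate`'s `prepare` can be called separately for whichever objects are actually trained, and a single input returns a single prepared object, which is what allows each model to be prepared only when its flag is set. A minimal sketch with a toy model (assumes `accelerate` is installed; not the trainer's models):

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# prepare() can be invoked incrementally; one argument in, one prepared object out
model = accelerator.prepare(model)
optimizer = accelerator.prepare(optimizer)
```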

# move to CPU when caching the TextEncoder outputs
if args.cache_text_encoder_outputs:
@@ -461,7 +482,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

for step, batch in enumerate(train_dataloader):
current_step.value = global_step
with accelerator.accumulate(training_models[0]):  # it does not seem to support multiple models, but leave it like this for now
with accelerator.accumulate(*training_models):
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
else:
@@ -547,7 +568,12 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

target = noise

if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.v_pred_like_loss or args.debiased_estimation_loss:
if (
args.min_snr_gamma
or args.scale_v_pred_loss_like_noise_pred
or args.v_pred_like_loss
or args.debiased_estimation_loss
):
# do not mean over batch dimension for snr weight or scale v-pred loss
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])
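
As a shape note on the branch above: `reduction="none"` followed by `.mean([1, 2, 3])` keeps one loss value per sample, which the SNR-based weighting options need before any batch mean is taken. A small standalone check with hypothetical tensor sizes:

```python
import torch

# hypothetical shapes: (batch, channels, height, width)
noise_pred = torch.randn(4, 4, 64, 64)
target = torch.randn(4, 4, 64, 64)

loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])  # average over channel/spatial dims only, keep the batch dim
print(loss.shape)  # torch.Size([4]) -> one loss value per sample
```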
@@ -725,7 +751,19 @@ def setup_parser() -> argparse.ArgumentParser:
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)
sdxl_train_util.add_sdxl_training_arguments(parser)
parser.add_argument("--learning_rate_te", type=float, default=0.0, help="learning rate for text encoder")

parser.add_argument(
"--learning_rate_te1",
type=float,
default=None,
help="learning rate for text encoder 1 (ViT-L) / text encoder 1 (ViT-L)の学習率",
)
parser.add_argument(
"--learning_rate_te2",
type=float,
default=None,
help="learning rate for text encoder 2 (BiG-G) / text encoder 2 (BiG-G)の学習率",
)

parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する")
parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
