support separate LR for Text Encoder for SD1/2
kohya-ss committed Oct 29, 2023
1 parent e72020a commit 96d877b
Showing 2 changed files with 39 additions and 9 deletions.
27 changes: 21 additions & 6 deletions fine_tune.py
@@ -10,10 +10,13 @@

 from tqdm import tqdm
 import torch
+
 try:
     import intel_extension_for_pytorch as ipex
+
     if torch.xpu.is_available():
         from library.ipex import ipex_init
+
         ipex_init()
 except Exception:
     pass
@@ -193,14 +196,20 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

 for m in training_models:
     m.requires_grad_(True)
-params = []
-for m in training_models:
-    params.extend(m.parameters())
-params_to_optimize = params
+
+trainable_params = []
+if args.learning_rate_te is None or not args.train_text_encoder:
+    for m in training_models:
+        trainable_params.extend(m.parameters())
+else:
+    trainable_params = [
+        {"params": list(unet.parameters()), "lr": args.learning_rate},
+        {"params": list(text_encoder.parameters()), "lr": args.learning_rate_te},
+    ]

 # prepare the classes needed for training
 accelerator.print("prepare optimizer, data loader etc.")
-_, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
+_, _, optimizer = train_util.get_optimizer(args, trainable_params=trainable_params)

 # prepare the dataloader
 # number of DataLoader processes: 0 means everything runs in the main process
@@ -340,7 +349,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

 else:
     target = noise

-if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss,:
+if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss:
     # do not mean over batch dimension for snr weight or scale v-pred loss
     loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
     loss = loss.mean([1, 2, 3])
@@ -476,6 +485,12 @@ def setup_parser() -> argparse.ArgumentParser:

 parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する")
 parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
+parser.add_argument(
+    "--learning_rate_te",
+    type=float,
+    default=None,
+    help="learning rate for text encoder, default is same as unet / Text Encoderの学習率、デフォルトはunetと同じ",
+)

 return parser

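The heart of the fine_tune.py change is that `trainable_params` handed to `train_util.get_optimizer` may now be a list of parameter-group dicts instead of a flat parameter list. PyTorch optimizers accept either form; when groups are passed, each group's own "lr" key overrides the optimizer's default, so the U-Net keeps training at `--learning_rate` while the text encoder uses `--learning_rate_te`. Below is a minimal sketch of that mechanism, assuming a plain `torch.optim.AdamW` in place of whatever optimizer `train_util.get_optimizer` actually builds, with toy `TinyUNet` / `TinyTextEncoder` modules standing in for the real models:

    # Sketch of the per-group learning-rate mechanism used above.
    # TinyUNet / TinyTextEncoder are toy stand-ins for the real models, and
    # torch.optim.AdamW stands in for the optimizer train_util.get_optimizer builds.
    import torch

    class TinyUNet(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = torch.nn.Linear(8, 8)

    class TinyTextEncoder(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.embed = torch.nn.Linear(8, 8)

    unet = TinyUNet()
    text_encoder = TinyTextEncoder()

    # Same structure as the new trainable_params in fine_tune.py: one dict per
    # model, each carrying its own "lr" key.
    trainable_params = [
        {"params": list(unet.parameters()), "lr": 1e-5},          # --learning_rate
        {"params": list(text_encoder.parameters()), "lr": 5e-6},  # --learning_rate_te
    ]
    optimizer = torch.optim.AdamW(trainable_params)

    # Each group keeps its own learning rate inside the optimizer.
    for group in optimizer.param_groups:
        print(group["lr"])  # 1e-05, then 5e-06

Because `--learning_rate_te` defaults to None, the old single-rate path is taken unless the flag is given, so existing fine_tune.py command lines keep their previous behavior.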
21 changes: 18 additions & 3 deletions train_db.py
@@ -11,10 +11,13 @@

 from tqdm import tqdm
 import torch
+
 try:
     import intel_extension_for_pytorch as ipex
+
     if torch.xpu.is_available():
         from library.ipex import ipex_init
+
         ipex_init()
 except Exception:
     pass
@@ -164,11 +167,17 @@ def train(args):

 # prepare the classes needed for training
 accelerator.print("prepare optimizer, data loader etc.")
 if train_text_encoder:
-    # without wrapping in a list, adamw8bit crashes
-    trainable_params = list(itertools.chain(unet.parameters(), text_encoder.parameters()))
+    if args.learning_rate_te is None:
+        # without wrapping in a list, adamw8bit crashes
+        trainable_params = list(itertools.chain(unet.parameters(), text_encoder.parameters()))
+    else:
+        trainable_params = [
+            {"params": list(unet.parameters()), "lr": args.learning_rate},
+            {"params": list(text_encoder.parameters()), "lr": args.learning_rate_te},
+        ]
 else:
     trainable_params = unet.parameters()

 _, _, optimizer = train_util.get_optimizer(args, trainable_params)

 # prepare the dataloader
@@ -461,6 +470,12 @@ def setup_parser() -> argparse.ArgumentParser:

 config_util.add_config_arguments(parser)
 custom_train_functions.add_custom_train_arguments(parser)

+parser.add_argument(
+    "--learning_rate_te",
+    type=float,
+    default=None,
+    help="learning rate for text encoder, default is same as unet / Text Encoderの学習率、デフォルトはunetと同じ",
+)
 parser.add_argument(
     "--no_token_padding",
     action="store_true",
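train_db.py gains the same branch, but keeps the existing `itertools.chain(...)` path (materialized with `list()`, since, per the code comment, adamw8bit crashes on a bare generator) whenever `--learning_rate_te` is not given. The sketch below shows the two shapes `trainable_params` can take after this commit; it is only an illustration, with a plain `torch.optim.AdamW` and small `torch.nn.Linear` modules standing in for the real U-Net, text encoder, and `train_util.get_optimizer`:

    # Illustration of the two trainable_params shapes in train_db.py after this
    # commit. Toy torch.nn.Linear modules stand in for the real unet / text_encoder.
    import itertools
    import torch

    unet = torch.nn.Linear(4, 4)
    text_encoder = torch.nn.Linear(4, 4)

    learning_rate = 1e-5
    learning_rate_te = None  # value of --learning_rate_te

    if learning_rate_te is None:
        # Old behavior: one flat list, a single learning rate for both models.
        # The list() wrapper matters: a bare generator from itertools.chain is
        # not accepted by adamw8bit, per the comment in the diff above.
        trainable_params = list(itertools.chain(unet.parameters(), text_encoder.parameters()))
        optimizer = torch.optim.AdamW(trainable_params, lr=learning_rate)
    else:
        # New behavior: two parameter groups, each with its own learning rate.
        trainable_params = [
            {"params": list(unet.parameters()), "lr": learning_rate},
            {"params": list(text_encoder.parameters()), "lr": learning_rate_te},
        ]
        optimizer = torch.optim.AdamW(trainable_params)

    print(len(optimizer.param_groups))  # 1 when learning_rate_te is None, else 2

In practice a DreamBooth run would pass something like `--learning_rate 1e-5 --learning_rate_te 5e-6`; leaving `--learning_rate_te` unset preserves the old single-rate behavior in both scripts.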
