diff --git a/fine_tune.py b/fine_tune.py
index 893066f70..a86a483a0 100644
--- a/fine_tune.py
+++ b/fine_tune.py
@@ -10,10 +10,13 @@
 from tqdm import tqdm
 
 import torch
+
 try:
     import intel_extension_for_pytorch as ipex
+
     if torch.xpu.is_available():
         from library.ipex import ipex_init
+
         ipex_init()
 except Exception:
     pass
@@ -193,14 +196,20 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
 
     for m in training_models:
         m.requires_grad_(True)
-    params = []
-    for m in training_models:
-        params.extend(m.parameters())
-    params_to_optimize = params
+
+    trainable_params = []
+    if args.learning_rate_te is None or not args.train_text_encoder:
+        for m in training_models:
+            trainable_params.extend(m.parameters())
+    else:
+        trainable_params = [
+            {"params": list(unet.parameters()), "lr": args.learning_rate},
+            {"params": list(text_encoder.parameters()), "lr": args.learning_rate_te},
+        ]
 
     # 学習に必要なクラスを準備する
     accelerator.print("prepare optimizer, data loader etc.")
-    _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
+    _, _, optimizer = train_util.get_optimizer(args, trainable_params=trainable_params)
 
     # dataloaderを準備する
     # DataLoaderのプロセス数：0はメインプロセスになる
@@ -340,7 +349,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
                 else:
                     target = noise
 
-                if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss,:
+                if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss:
                     # do not mean over batch dimension for snr weight or scale v-pred loss
                     loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
                     loss = loss.mean([1, 2, 3])
@@ -476,6 +485,12 @@ def setup_parser() -> argparse.ArgumentParser:
 
     parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する")
     parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
+    parser.add_argument(
+        "--learning_rate_te",
+        type=float,
+        default=None,
+        help="learning rate for text encoder, default is same as unet / Text Encoderの学習率、デフォルトはunetと同じ",
+    )
 
     return parser
 
diff --git a/train_db.py b/train_db.py
index 59a124a26..fd8e466e5 100644
--- a/train_db.py
+++ b/train_db.py
@@ -11,10 +11,13 @@
 from tqdm import tqdm
 
 import torch
+
 try:
     import intel_extension_for_pytorch as ipex
+
     if torch.xpu.is_available():
         from library.ipex import ipex_init
+
         ipex_init()
 except Exception:
     pass
@@ -164,11 +167,17 @@ def train(args):
     # 学習に必要なクラスを準備する
     accelerator.print("prepare optimizer, data loader etc.")
     if train_text_encoder:
-        # wightout list, adamw8bit is crashed
-        trainable_params = list(itertools.chain(unet.parameters(), text_encoder.parameters()))
+        if args.learning_rate_te is None:
+            # wightout list, adamw8bit is crashed
+            trainable_params = list(itertools.chain(unet.parameters(), text_encoder.parameters()))
+        else:
+            trainable_params = [
+                {"params": list(unet.parameters()), "lr": args.learning_rate},
+                {"params": list(text_encoder.parameters()), "lr": args.learning_rate_te},
+            ]
     else:
         trainable_params = unet.parameters()
-
+
     _, _, optimizer = train_util.get_optimizer(args, trainable_params)
 
     # dataloaderを準備する
@@ -461,6 +470,12 @@ def setup_parser() -> argparse.ArgumentParser:
     config_util.add_config_arguments(parser)
     custom_train_functions.add_custom_train_arguments(parser)
 
+    parser.add_argument(
+        "--learning_rate_te",
+        type=float,
+        default=None,
+        help="learning rate for text encoder, default is same as unet / Text Encoderの学習率、デフォルトはunetと同じ",
+    )
    parser.add_argument(
        "--no_token_padding",
        action="store_true",
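
Note on the parameter groups introduced above: train_util.get_optimizer itself is not part of this diff, but the trainable_params structure it now receives is the standard PyTorch parameter-group form, where each dict's "lr" overrides the optimizer's default learning rate for that group. The snippet below is a minimal sketch of that behaviour only, not code from this PR: plain torch.optim.AdamW stands in for get_optimizer, two small nn.Linear modules stand in for the real unet and text_encoder, and the learning-rate values are made up.

# Minimal sketch (not from this PR): per-group learning rates as built in the diff above.
# torch.optim.AdamW stands in for train_util.get_optimizer; the Linear modules stand in
# for the real unet / text_encoder.
import torch

unet = torch.nn.Linear(4, 4)          # stand-in for the U-Net
text_encoder = torch.nn.Linear(4, 4)  # stand-in for the Text Encoder

learning_rate = 1e-5     # corresponds to --learning_rate
learning_rate_te = 5e-6  # corresponds to --learning_rate_te (None means "same as unet")

if learning_rate_te is None:
    # single flat list: every parameter uses the optimizer's default lr
    trainable_params = list(unet.parameters()) + list(text_encoder.parameters())
else:
    # two groups: each dict's "lr" overrides the optimizer default for that group
    trainable_params = [
        {"params": list(unet.parameters()), "lr": learning_rate},
        {"params": list(text_encoder.parameters()), "lr": learning_rate_te},
    ]

optimizer = torch.optim.AdamW(trainable_params, lr=learning_rate)
print([group["lr"] for group in optimizer.param_groups])  # [1e-05, 5e-06]

When --learning_rate_te is left at its default of None, both scripts fall back to the single flat parameter list, so the optimizer behaves exactly as it did before this change.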