From 2bb0f547d72cd0256cafebd46d0f61fbe54012ac Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 14 Nov 2024 19:33:12 +0900 Subject: [PATCH] update grad hook creation to fix TE lr in sd3 fine tuning --- flux_train.py | 19 ++++++++++++------- library/train_util.py | 1 + sd3_train.py | 15 +++++++++------ 3 files changed, 22 insertions(+), 13 deletions(-) diff --git a/flux_train.py b/flux_train.py index ad2c7722b..a89e2f139 100644 --- a/flux_train.py +++ b/flux_train.py @@ -80,7 +80,9 @@ def train(args): assert ( args.blocks_to_swap is None or args.blocks_to_swap == 0 - ) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません" + ) or not args.cpu_offload_checkpointing, ( + "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません" + ) cache_latents = args.cache_latents use_dreambooth_method = args.in_json is None @@ -480,13 +482,16 @@ def train(args): for parameter, param_name in zip(param_group["params"], param_name_group): if parameter.requires_grad: - def grad_hook(tensor: torch.Tensor, param_group=param_group): - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - accelerator.clip_grad_norm_(tensor, args.max_grad_norm) - optimizer.step_param(tensor, param_group) - tensor.grad = None + def create_grad_hook(p_name, p_group): + def grad_hook(tensor: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, p_group) + tensor.grad = None + + return grad_hook - parameter.register_post_accumulate_grad_hook(grad_hook) + parameter.register_post_accumulate_grad_hook(create_grad_hook(param_name, param_group)) elif args.blockwise_fused_optimizers: # prepare for additional optimizers and lr schedulers diff --git a/library/train_util.py b/library/train_util.py index e1dfeecdb..25cf7640d 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5913,6 +5913,7 @@ def append_lr_to_logs(logs, lr_scheduler, optimizer_type, including_unet=True): names.append("unet") names.append("text_encoder1") names.append("text_encoder2") + names.append("text_encoder3") # SD3 append_lr_to_logs_with_names(logs, lr_scheduler, optimizer_type, names) diff --git a/sd3_train.py b/sd3_train.py index a4fc2eec8..96ec951b9 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -606,13 +606,16 @@ def train(args): for parameter, param_name in zip(param_group["params"], param_name_group): if parameter.requires_grad: - def grad_hook(tensor: torch.Tensor, param_group=param_group): - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - accelerator.clip_grad_norm_(tensor, args.max_grad_norm) - optimizer.step_param(tensor, param_group) - tensor.grad = None + def create_grad_hook(p_name, p_group): + def grad_hook(tensor: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, p_group) + tensor.grad = None + + return grad_hook - parameter.register_post_accumulate_grad_hook(grad_hook) + parameter.register_post_accumulate_grad_hook(create_grad_hook(param_name, param_group)) elif args.blockwise_fused_optimizers: # prepare for additional optimizers and lr schedulers