
Commit

update grad hook creation to fix TE lr in sd3 fine tuning
kohya-ss committed Nov 14, 2024
1 parent 2cb7a6d commit 2bb0f54
Showing 3 changed files with 22 additions and 13 deletions.
19 changes: 12 additions & 7 deletions flux_train.py
@@ -80,7 +80,9 @@ def train(args):
 
     assert (
         args.blocks_to_swap is None or args.blocks_to_swap == 0
-    ) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません"
+    ) or not args.cpu_offload_checkpointing, (
+        "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません"
+    )
 
     cache_latents = args.cache_latents
     use_dreambooth_method = args.in_json is None
@@ -480,13 +482,16 @@ def train(args):
             for parameter, param_name in zip(param_group["params"], param_name_group):
                 if parameter.requires_grad:
 
-                    def grad_hook(tensor: torch.Tensor, param_group=param_group):
-                        if accelerator.sync_gradients and args.max_grad_norm != 0.0:
-                            accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
-                        optimizer.step_param(tensor, param_group)
-                        tensor.grad = None
+                    def create_grad_hook(p_name, p_group):
+                        def grad_hook(tensor: torch.Tensor):
+                            if accelerator.sync_gradients and args.max_grad_norm != 0.0:
+                                accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
+                            optimizer.step_param(tensor, p_group)
+                            tensor.grad = None
+
+                        return grad_hook
 
-                    parameter.register_post_accumulate_grad_hook(grad_hook)
+                    parameter.register_post_accumulate_grad_hook(create_grad_hook(param_name, param_group))
 
     elif args.blockwise_fused_optimizers:
         # prepare for additional optimizers and lr schedulers
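Why a factory: the hooks are registered inside a loop, and a nested function that merely reads loop variables sees whatever value they hold when the hook eventually fires, not when it was defined. The previous version bound param_group through a default argument; the commit switches to create_grad_hook(p_name, p_group), which also receives the parameter name, so each hook closes over its own values at registration time. The sketch below is illustrative only (not from the repository, with made-up names) and shows the difference between late binding and factory binding:

def make_hook(group_id):  # factory: binds group_id when called
    def hook(_grad):
        return f"group {group_id}"
    return hook

late_bound = []
factory_bound = []
for group_id in range(3):
    late_bound.append(lambda _grad: f"group {group_id}")  # closes over the loop variable
    factory_bound.append(make_hook(group_id))             # freezes the current value

print([h(None) for h in late_bound])     # ['group 2', 'group 2', 'group 2']
print([h(None) for h in factory_bound])  # ['group 0', 'group 1', 'group 2']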
1 change: 1 addition & 0 deletions library/train_util.py
@@ -5913,6 +5913,7 @@ def append_lr_to_logs(logs, lr_scheduler, optimizer_type, including_unet=True):
         names.append("unet")
     names.append("text_encoder1")
     names.append("text_encoder2")
+    names.append("text_encoder3")  # SD3
 
     append_lr_to_logs_with_names(logs, lr_scheduler, optimizer_type, names)
 
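With "text_encoder3" appended, SD3 runs get a learning-rate entry for the third text encoder in the training logs. Below is a rough sketch of how such a names list typically maps onto per-group learning rates; this is an assumption for illustration, and the real append_lr_to_logs_with_names in library/train_util.py may differ:

def append_lr_to_logs_with_names_sketch(logs, lr_scheduler, names):
    # get_last_lr() returns one learning rate per optimizer param group
    lrs = lr_scheduler.get_last_lr()
    for name, lr in zip(names, lrs):
        logs[f"lr/{name}"] = float(lr)  # e.g. logs["lr/text_encoder3"] for SD3
    return logs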
15 changes: 9 additions & 6 deletions sd3_train.py
@@ -606,13 +606,16 @@ def train(args):
             for parameter, param_name in zip(param_group["params"], param_name_group):
                 if parameter.requires_grad:
 
-                    def grad_hook(tensor: torch.Tensor, param_group=param_group):
-                        if accelerator.sync_gradients and args.max_grad_norm != 0.0:
-                            accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
-                        optimizer.step_param(tensor, param_group)
-                        tensor.grad = None
+                    def create_grad_hook(p_name, p_group):
+                        def grad_hook(tensor: torch.Tensor):
+                            if accelerator.sync_gradients and args.max_grad_norm != 0.0:
+                                accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
+                            optimizer.step_param(tensor, p_group)
+                            tensor.grad = None
+
+                        return grad_hook
 
-                    parameter.register_post_accumulate_grad_hook(grad_hook)
+                    parameter.register_post_accumulate_grad_hook(create_grad_hook(param_name, param_group))
 
     elif args.blockwise_fused_optimizers:
         # prepare for additional optimizers and lr schedulers
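This is the same factory change as in flux_train.py, applied to the SD3 script. For reference, here is a self-contained sketch (PyTorch >= 2.1, using plain SGD instead of the script's fused optimizer and accelerator) of stepping each parameter from a post-accumulate-grad hook, which is the mechanism both training scripts rely on:

import torch

model = torch.nn.Linear(4, 4)
# one optimizer per parameter, following the fused-backward pattern
optimizers = {p: torch.optim.SGD([p], lr=0.1) for p in model.parameters()}

def create_grad_hook(opt):
    def grad_hook(param: torch.Tensor):
        torch.nn.utils.clip_grad_norm_([param], max_norm=1.0)
        opt.step()       # step this parameter as soon as its grad is accumulated
        opt.zero_grad()  # drop the gradient immediately to save memory
    return grad_hook

for p in model.parameters():
    if p.requires_grad:
        p.register_post_accumulate_grad_hook(create_grad_hook(optimizers[p]))

loss = model(torch.randn(2, 4)).sum()
loss.backward()  # hooks fire per parameter; no separate optimizer.step() afterwards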
