diff --git a/library/train_util.py b/library/train_util.py
index 31e37bf70..64a112597 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -4100,8 +4100,10 @@ def get_optimizer(args, trainable_params):
             optimizer_class = sf.SGDScheduleFree
             logger.info(f"use SGDScheduleFree optimizer | {optimizer_kwargs}")
         else:
-            raise ValueError(f"Unknown optimizer type: {optimizer_type}")
-        optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
+            optimizer_class = None
+
+        if optimizer_class is not None:
+            optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
 
     if optimizer is None:
         # 任意のoptimizerを使う
diff --git a/train_network.py b/train_network.py
index f287acacd..8c2abda58 100644
--- a/train_network.py
+++ b/train_network.py
@@ -53,7 +53,7 @@ def __init__(self):
 
     # TODO 他のスクリプトと共通化する
     def generate_step_logs(
-        self, args: argparse.Namespace, current_loss, avr_loss, lr_scheduler, keys_scaled=None, mean_norm=None, maximum_norm=None
+        self, args: argparse.Namespace, current_loss, avr_loss, lr_scheduler, optimizer, keys_scaled=None, mean_norm=None, maximum_norm=None
     ):
         logs = {"loss/current": current_loss, "loss/average": avr_loss}
 
@@ -79,6 +79,12 @@ def generate_step_logs(
                 logs["lr/d*lr"] = (
                     lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"]
                 )
+            if (
+                args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower())
+            ):  # tracking d*lr value of unet.
+                logs["lr/d*lr"] = (
+                    optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"]
+                )
         else:
             idx = 0
             if not args.network_train_unet_only:
@@ -91,6 +97,12 @@ def generate_step_logs(
                     logs[f"lr/d*lr/group{i}"] = (
                         lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
                     )
+                if (
+                    args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower())
+                ):
+                    logs[f"lr/d*lr/group{i}"] = (
+                        optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"]
+                    )
 
         return logs
 
@@ -965,7 +977,7 @@ def remove_model(old_ckpt_name):
                     progress_bar.set_postfix(**{**max_mean_logs, **logs})
 
                 if args.logging_dir is not None:
-                    logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm)
+                    logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, optimizer, keys_scaled, mean_norm, maximum_norm)
                     accelerator.log(logs, step=global_step)
 
             if global_step >= args.max_train_steps: