From 420a180d938c7b5a6e3006b1719dbfeaae72a2cc Mon Sep 17 00:00:00 2001 From: recris Date: Wed, 27 Nov 2024 18:11:51 +0000 Subject: [PATCH] Implement pseudo Huber loss for Flux and SD3 --- fine_tune.py | 6 +-- flux_train.py | 2 +- flux_train_network.py | 2 +- library/train_util.py | 74 ++++++++++++++++------------ sd3_train.py | 2 +- sd3_train_network.py | 2 +- sdxl_train.py | 6 +-- sdxl_train_control_net.py | 4 +- sdxl_train_control_net_lllite.py | 4 +- sdxl_train_control_net_lllite_old.py | 6 ++- train_controlnet.py | 6 +-- train_db.py | 4 +- train_network.py | 9 ++-- train_textual_inversion.py | 4 +- train_textual_inversion_XTI.py | 6 ++- 15 files changed, 76 insertions(+), 61 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index 0090bd190..70959a751 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -380,7 +380,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps( + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps( args, noise_scheduler, latents ) @@ -397,7 +397,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss: # do not mean over batch dimension for snr weight or scale v-pred loss loss = train_util.conditional_loss( - noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c + args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler ) loss = loss.mean([1, 2, 3]) @@ -411,7 +411,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): loss = loss.mean() # mean over batch dimension else: loss = train_util.conditional_loss( - noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c + args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler ) accelerator.backward(loss) diff --git a/flux_train.py b/flux_train.py index a89e2f139..f6e43b27a 100644 --- a/flux_train.py +++ b/flux_train.py @@ -667,7 +667,7 @@ def grad_hook(parameter: torch.Tensor): # calculate loss loss = train_util.conditional_loss( - model_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=None + args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler ) if weighting is not None: loss = loss * weighting diff --git a/flux_train_network.py b/flux_train_network.py index 679db62b6..04287f399 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -468,7 +468,7 @@ def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t ) target[diff_output_pr_indices] = model_pred_prior.to(target.dtype) - return model_pred, target, timesteps, None, weighting + return model_pred, target, timesteps, weighting def post_process_loss(self, loss, args, timesteps, noise_scheduler): return loss diff --git a/library/train_util.py b/library/train_util.py index 25cf7640d..c204ebd38 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3905,7 +3905,14 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: "--huber_c", type=float, default=0.1, - help="The huber loss parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 0.1 / Huber損失のパラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは0.1", + help="The Huber loss decay parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 0.1 / Huber損失のパラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは0.1", + ) + + parser.add_argument( + "--huber_scale", + type=float, + default=1.0, + help="The Huber loss scale parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 1.0 / Huber損失のパラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは0.1", ) parser.add_argument( @@ -5821,29 +5828,10 @@ def save_sd_model_on_train_end_common( huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True) -def get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, device): - timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu") - - if args.loss_type == "huber" or args.loss_type == "smooth_l1": - if args.huber_schedule == "exponential": - alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps - huber_c = torch.exp(-alpha * timesteps) - elif args.huber_schedule == "snr": - alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps) - sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5 - huber_c = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c - elif args.huber_schedule == "constant": - huber_c = torch.full((b_size,), args.huber_c) - else: - raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!") - huber_c = huber_c.to(device) - elif args.loss_type == "l2": - huber_c = None # may be anything, as it's not used - else: - raise NotImplementedError(f"Unknown loss type {args.loss_type}") - - timesteps = timesteps.long().to(device) - return timesteps, huber_c +def get_timesteps(min_timestep, max_timestep, b_size, device): + timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=device) + timesteps = timesteps.long() + return timesteps def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents): @@ -5865,7 +5853,7 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents): min_timestep = 0 if args.min_timestep is None else args.min_timestep max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep - timesteps, huber_c = get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, latents.device) + timesteps = get_timesteps(min_timestep, max_timestep, b_size, latents.device) # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) @@ -5878,24 +5866,46 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents): else: noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - return noise, noisy_latents, timesteps, huber_c + return noise, noisy_latents, timesteps + + +def get_huber_threshold(args, timesteps: torch.Tensor, noise_scheduler) -> torch.Tensor: + b_size = timesteps.shape[0] + if args.huber_schedule == "exponential": + alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps + result = torch.exp(-alpha * timesteps) * args.huber_scale + elif args.huber_schedule == "snr": + if not hasattr(noise_scheduler, 'alphas_cumprod'): + raise NotImplementedError(f"Huber schedule 'snr' is not supported with the current model.") + alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps.cpu()) + sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5 + result = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c + result = result.to(timesteps.device) + elif args.huber_schedule == "constant": + result = torch.full((b_size,), args.huber_c * args.huber_scale, device=timesteps.device) + else: + raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!") + + return result def conditional_loss( - model_pred: torch.Tensor, target: torch.Tensor, reduction: str, loss_type: str, huber_c: Optional[torch.Tensor] + args, model_pred: torch.Tensor, target: torch.Tensor, timesteps: torch.Tensor, reduction: str, noise_scheduler ): - if loss_type == "l2": + if args.loss_type == "l2": loss = torch.nn.functional.mse_loss(model_pred, target, reduction=reduction) - elif loss_type == "l1": + elif args.loss_type == "l1": loss = torch.nn.functional.l1_loss(model_pred, target, reduction=reduction) - elif loss_type == "huber": + elif args.loss_type == "huber": + huber_c = get_huber_threshold(args, timesteps, noise_scheduler) huber_c = huber_c.view(-1, 1, 1, 1) loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c) if reduction == "mean": loss = torch.mean(loss) elif reduction == "sum": loss = torch.sum(loss) - elif loss_type == "smooth_l1": + elif args.loss_type == "smooth_l1": + huber_c = get_huber_threshold(args, timesteps, noise_scheduler) huber_c = huber_c.view(-1, 1, 1, 1) loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c) if reduction == "mean": @@ -5903,7 +5913,7 @@ def conditional_loss( elif reduction == "sum": loss = torch.sum(loss) else: - raise NotImplementedError(f"Unsupported Loss Type {loss_type}") + raise NotImplementedError(f"Unsupported Loss Type: {args.loss_type}") return loss diff --git a/sd3_train.py b/sd3_train.py index 96ec951b9..cf2bdf938 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -845,7 +845,7 @@ def grad_hook(parameter: torch.Tensor): # ) # calculate loss loss = train_util.conditional_loss( - model_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=None + args, model_pred.float(), target.float(), timesteps, "none", noise_scheduler ) if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) diff --git a/sd3_train_network.py b/sd3_train_network.py index 1726e325f..fb7711bda 100644 --- a/sd3_train_network.py +++ b/sd3_train_network.py @@ -378,7 +378,7 @@ def get_noise_pred_and_target( target[diff_output_pr_indices] = model_pred_prior.to(target.dtype) - return model_pred, target, timesteps, None, weighting + return model_pred, target, timesteps, weighting def post_process_loss(self, loss, args, timesteps, noise_scheduler): return loss diff --git a/sdxl_train.py b/sdxl_train.py index e26f4aa19..1bc27ec6c 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -695,7 +695,7 @@ def optimizer_hook(parameter: torch.Tensor): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps( + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps( args, noise_scheduler, latents ) @@ -720,7 +720,7 @@ def optimizer_hook(parameter: torch.Tensor): ): # do not mean over batch dimension for snr weight or scale v-pred loss loss = train_util.conditional_loss( - noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c + args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler ) if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) @@ -738,7 +738,7 @@ def optimizer_hook(parameter: torch.Tensor): loss = loss.mean() # mean over batch dimension else: loss = train_util.conditional_loss( - noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c + args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler ) accelerator.backward(loss) diff --git a/sdxl_train_control_net.py b/sdxl_train_control_net.py index 24080afbd..d0051d18f 100644 --- a/sdxl_train_control_net.py +++ b/sdxl_train_control_net.py @@ -512,7 +512,7 @@ def remove_model(old_ckpt_name): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps( + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps( args, noise_scheduler, latents ) @@ -534,7 +534,7 @@ def remove_model(old_ckpt_name): target = noise loss = train_util.conditional_loss( - noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c + args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler ) loss = loss.mean([1, 2, 3]) diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 2946c97d4..66214f5df 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -463,7 +463,7 @@ def remove_model(old_ckpt_name): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps( + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps( args, noise_scheduler, latents ) @@ -485,7 +485,7 @@ def remove_model(old_ckpt_name): target = noise loss = train_util.conditional_loss( - noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c + args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler ) loss = loss.mean([1, 2, 3]) diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py index 2d4465234..5e10654b9 100644 --- a/sdxl_train_control_net_lllite_old.py +++ b/sdxl_train_control_net_lllite_old.py @@ -406,7 +406,7 @@ def remove_model(old_ckpt_name): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype @@ -426,7 +426,9 @@ def remove_model(old_ckpt_name): else: target = noise - loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) + loss = train_util.conditional_loss( + args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler + ) loss = loss.mean([1, 2, 3]) loss_weights = batch["loss_weights"] # 各sampleごとのweight diff --git a/train_controlnet.py b/train_controlnet.py index 8c7882c8f..da7a08d69 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -464,8 +464,8 @@ def remove_model(old_ckpt_name): ) # Sample a random timestep for each image - timesteps, huber_c = train_util.get_timesteps_and_huber_c( - args, 0, noise_scheduler.config.num_train_timesteps, noise_scheduler, b_size, latents.device + timesteps = train_util.get_timesteps( + 0, noise_scheduler.config.num_train_timesteps, b_size, latents.device ) # Add noise to the latents according to the noise magnitude at each timestep @@ -499,7 +499,7 @@ def remove_model(old_ckpt_name): target = noise loss = train_util.conditional_loss( - noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c + args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler ) loss = loss.mean([1, 2, 3]) diff --git a/train_db.py b/train_db.py index 51e209f34..a185b31b3 100644 --- a/train_db.py +++ b/train_db.py @@ -370,7 +370,7 @@ def train(args): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps( + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps( args, noise_scheduler, latents ) @@ -385,7 +385,7 @@ def train(args): target = noise loss = train_util.conditional_loss( - noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c + args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler ) if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) diff --git a/train_network.py b/train_network.py index bbf381f99..c7d4f5dc5 100644 --- a/train_network.py +++ b/train_network.py @@ -192,7 +192,7 @@ def get_noise_pred_and_target( ): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) # ensure the hidden state will require grad if args.gradient_checkpointing: @@ -244,7 +244,7 @@ def get_noise_pred_and_target( network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype) - return noise_pred, target, timesteps, huber_c, None + return noise_pred, target, timesteps, None def post_process_loss(self, loss, args, timesteps, noise_scheduler): if args.min_snr_gamma: @@ -806,6 +806,7 @@ def load_model_hook(models, input_dir): "ss_ip_noise_gamma_random_strength": args.ip_noise_gamma_random_strength, "ss_loss_type": args.loss_type, "ss_huber_schedule": args.huber_schedule, + "ss_huber_scale": args.huber_scale, "ss_huber_c": args.huber_c, "ss_fp8_base": bool(args.fp8_base), "ss_fp8_base_unet": bool(args.fp8_base_unet), @@ -1193,7 +1194,7 @@ def remove_model(old_ckpt_name): text_encoder_conds[i] = encoded_text_encoder_conds[i] # sample noise, call unet, get target - noise_pred, target, timesteps, huber_c, weighting = self.get_noise_pred_and_target( + noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target( args, accelerator, noise_scheduler, @@ -1207,7 +1208,7 @@ def remove_model(old_ckpt_name): ) loss = train_util.conditional_loss( - noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c + args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler ) if weighting is not None: loss = loss * weighting diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 5f4657eb9..9e1e57c48 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -585,7 +585,7 @@ def remove_model(old_ckpt_name): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps( + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps( args, noise_scheduler, latents ) @@ -602,7 +602,7 @@ def remove_model(old_ckpt_name): target = noise loss = train_util.conditional_loss( - noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c + args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler ) if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 52d525fc5..944733602 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -461,7 +461,7 @@ def remove_model(old_ckpt_name): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) # Predict the noise residual with accelerator.autocast(): @@ -473,7 +473,9 @@ def remove_model(old_ckpt_name): else: target = noise - loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) + loss = train_util.conditional_loss( + args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler + ) if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3])