Implement pseudo Huber loss for Flux and SD3
recris committed Nov 27, 2024
1 parent 2a61fc0 commit 420a180
Showing 15 changed files with 76 additions and 61 deletions.
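For context, the pseudo-Huber loss being wired in here is quadratic for small residuals and linear for large ones, with the transition controlled by a per-timestep threshold. A minimal standalone sketch (the function name is illustrative; the actual implementation is the conditional_loss change in library/train_util.py below):

    import torch

    def pseudo_huber(pred: torch.Tensor, target: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        # ~ (pred - target)**2 when |pred - target| << c, and ~ 2*c*|pred - target|
        # when |pred - target| >> c: a smooth approximation of the Huber loss.
        return 2 * c * (torch.sqrt((pred - target) ** 2 + c**2) - c)

The smooth_l1 mode in the diff keeps the same smoothing but omits the outer huber_c factor, so its tails have slope of roughly 2 regardless of the threshold.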
6 changes: 3 additions & 3 deletions fine_tune.py
@@ -380,7 +380,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
- noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
+ noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(
args, noise_scheduler, latents
)

@@ -397,7 +397,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss:
# do not mean over batch dimension for snr weight or scale v-pred loss
loss = train_util.conditional_loss(
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler
)
loss = loss.mean([1, 2, 3])

@@ -411,7 +411,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
loss = loss.mean() # mean over batch dimension
else:
loss = train_util.conditional_loss(
noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c
args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler
)

accelerator.backward(loss)
2 changes: 1 addition & 1 deletion flux_train.py
@@ -667,7 +667,7 @@ def grad_hook(parameter: torch.Tensor):

# calculate loss
loss = train_util.conditional_loss(
model_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=None
args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler
)
if weighting is not None:
loss = loss * weighting
2 changes: 1 addition & 1 deletion flux_train_network.py
@@ -468,7 +468,7 @@ def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t
)
target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)

- return model_pred, target, timesteps, None, weighting
+ return model_pred, target, timesteps, weighting

def post_process_loss(self, loss, args, timesteps, noise_scheduler):
return loss
74 changes: 42 additions & 32 deletions library/train_util.py
@@ -3905,7 +3905,14 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
"--huber_c",
type=float,
default=0.1,
help="The huber loss parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 0.1 / Huber損失のパラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは0.1",
help="The Huber loss decay parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 0.1 / Huber損失のパラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは0.1",
)

+ parser.add_argument(
+     "--huber_scale",
+     type=float,
+     default=1.0,
+     help="The Huber loss scale parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 1.0 / Huber損失のスケールパラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは1.0",
+ )

parser.add_argument(
@@ -5821,29 +5828,10 @@ def save_sd_model_on_train_end_common(
huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True)


- def get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, device):
-     timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu")
-
-     if args.loss_type == "huber" or args.loss_type == "smooth_l1":
-         if args.huber_schedule == "exponential":
-             alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps
-             huber_c = torch.exp(-alpha * timesteps)
-         elif args.huber_schedule == "snr":
-             alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps)
-             sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5
-             huber_c = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c
-         elif args.huber_schedule == "constant":
-             huber_c = torch.full((b_size,), args.huber_c)
-         else:
-             raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!")
-         huber_c = huber_c.to(device)
-     elif args.loss_type == "l2":
-         huber_c = None  # may be anything, as it's not used
-     else:
-         raise NotImplementedError(f"Unknown loss type {args.loss_type}")
-
-     timesteps = timesteps.long().to(device)
-     return timesteps, huber_c
+ def get_timesteps(min_timestep, max_timestep, b_size, device):
+     timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=device)
+     timesteps = timesteps.long()
+     return timesteps


def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
@@ -5865,7 +5853,7 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
min_timestep = 0 if args.min_timestep is None else args.min_timestep
max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep

- timesteps, huber_c = get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, latents.device)
+ timesteps = get_timesteps(min_timestep, max_timestep, b_size, latents.device)

# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
@@ -5878,32 +5866,54 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
else:
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

- return noise, noisy_latents, timesteps, huber_c
+ return noise, noisy_latents, timesteps


+ def get_huber_threshold(args, timesteps: torch.Tensor, noise_scheduler) -> torch.Tensor:
+     b_size = timesteps.shape[0]
+     if args.huber_schedule == "exponential":
+         alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps
+         result = torch.exp(-alpha * timesteps) * args.huber_scale
+     elif args.huber_schedule == "snr":
+         if not hasattr(noise_scheduler, 'alphas_cumprod'):
+             raise NotImplementedError("Huber schedule 'snr' is not supported with the current model.")
+         alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps.cpu())
+         sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5
+         result = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c
+         result = result.to(timesteps.device)
+     elif args.huber_schedule == "constant":
+         result = torch.full((b_size,), args.huber_c * args.huber_scale, device=timesteps.device)
+     else:
+         raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!")
+
+     return result


def conditional_loss(
-     model_pred: torch.Tensor, target: torch.Tensor, reduction: str, loss_type: str, huber_c: Optional[torch.Tensor]
+     args, model_pred: torch.Tensor, target: torch.Tensor, timesteps: torch.Tensor, reduction: str, noise_scheduler
):
if loss_type == "l2":
if args.loss_type == "l2":
loss = torch.nn.functional.mse_loss(model_pred, target, reduction=reduction)
elif loss_type == "l1":
elif args.loss_type == "l1":
loss = torch.nn.functional.l1_loss(model_pred, target, reduction=reduction)
elif loss_type == "huber":
elif args.loss_type == "huber":
huber_c = get_huber_threshold(args, timesteps, noise_scheduler)
huber_c = huber_c.view(-1, 1, 1, 1)
loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
if reduction == "mean":
loss = torch.mean(loss)
elif reduction == "sum":
loss = torch.sum(loss)
elif loss_type == "smooth_l1":
elif args.loss_type == "smooth_l1":
huber_c = get_huber_threshold(args, timesteps, noise_scheduler)
huber_c = huber_c.view(-1, 1, 1, 1)
loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
if reduction == "mean":
loss = torch.mean(loss)
elif reduction == "sum":
loss = torch.sum(loss)
else:
raise NotImplementedError(f"Unsupported Loss Type {loss_type}")
raise NotImplementedError(f"Unsupported Loss Type: {args.loss_type}")
return loss
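The new get_huber_threshold derives the per-sample threshold from the timestep. Here is a self-contained sketch of how the schedules behave, assuming a 1000-step scheduler and illustrative argument values (the SimpleNamespace stands in for the real parsed training args and is not part of the commit):

    import math
    import torch
    from types import SimpleNamespace

    # Illustrative stand-ins for the parsed training arguments.
    args = SimpleNamespace(huber_c=0.1, huber_scale=1.0, huber_schedule="exponential")
    num_train_timesteps = 1000  # would come from noise_scheduler.config
    timesteps = torch.tensor([10, 500, 990])

    if args.huber_schedule == "exponential":
        # Threshold decays from ~1.0 at t=0 toward huber_c at the final
        # timestep, then is multiplied by huber_scale.
        alpha = -math.log(args.huber_c) / num_train_timesteps
        threshold = torch.exp(-alpha * timesteps) * args.huber_scale
    elif args.huber_schedule == "constant":
        # Effective threshold is simply huber_c * huber_scale at every timestep.
        threshold = torch.full((len(timesteps),), args.huber_c * args.huber_scale)

    print(threshold)  # exponential: tensor([0.9772, 0.3162, 0.1023])

The "snr" schedule is omitted from the sketch because it also needs noise_scheduler.alphas_cumprod (hence the hasattr guard above); note it is the one schedule that does not apply huber_scale.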
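Every call site below follows the same new pattern: pass args, the timesteps, and the scheduler, and let conditional_loss derive the threshold internally. A hedged sketch of that calling convention (train_util, args, noise_scheduler, and the tensors are assumed to come from the surrounding sd-scripts training step):

    # model_pred and target have shape (B, C, H, W); timesteps has shape (B,).
    loss = train_util.conditional_loss(
        args, model_pred.float(), target.float(), timesteps, "none", noise_scheduler
    )
    loss = loss.mean([1, 2, 3])  # one loss value per sample, as the trainers do
    loss = loss.mean()  # then mean over the batch, after any per-sample weighting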


2 changes: 1 addition & 1 deletion sd3_train.py
@@ -845,7 +845,7 @@ def grad_hook(parameter: torch.Tensor):
# )
# calculate loss
loss = train_util.conditional_loss(
model_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=None
args, model_pred.float(), target.float(), timesteps, "none", noise_scheduler
)
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch)
2 changes: 1 addition & 1 deletion sd3_train_network.py
@@ -378,7 +378,7 @@ def get_noise_pred_and_target(

target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)

- return model_pred, target, timesteps, None, weighting
+ return model_pred, target, timesteps, weighting

def post_process_loss(self, loss, args, timesteps, noise_scheduler):
return loss
6 changes: 3 additions & 3 deletions sdxl_train.py
@@ -695,7 +695,7 @@ def optimizer_hook(parameter: torch.Tensor):

# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
- noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
+ noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(
args, noise_scheduler, latents
)

@@ -720,7 +720,7 @@ def optimizer_hook(parameter: torch.Tensor):
):
# do not mean over batch dimension for snr weight or scale v-pred loss
loss = train_util.conditional_loss(
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler
)
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch)
@@ -738,7 +738,7 @@ def optimizer_hook(parameter: torch.Tensor):
loss = loss.mean() # mean over batch dimension
else:
loss = train_util.conditional_loss(
noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c
args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler
)

accelerator.backward(loss)
4 changes: 2 additions & 2 deletions sdxl_train_control_net.py
@@ -512,7 +512,7 @@ def remove_model(old_ckpt_name):

# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
- noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
+ noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(
args, noise_scheduler, latents
)

@@ -534,7 +534,7 @@ def remove_model(old_ckpt_name):
target = noise

loss = train_util.conditional_loss(
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler
)
loss = loss.mean([1, 2, 3])

4 changes: 2 additions & 2 deletions sdxl_train_control_net_lllite.py
@@ -463,7 +463,7 @@ def remove_model(old_ckpt_name):

# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
- noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
+ noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(
args, noise_scheduler, latents
)

@@ -485,7 +485,7 @@ def remove_model(old_ckpt_name):
target = noise

loss = train_util.conditional_loss(
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler
)
loss = loss.mean([1, 2, 3])

6 changes: 4 additions & 2 deletions sdxl_train_control_net_lllite_old.py
@@ -406,7 +406,7 @@ def remove_model(old_ckpt_name):

# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
- noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
+ noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)

noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype

@@ -426,7 +426,9 @@ def remove_model(old_ckpt_name):
else:
target = noise

- loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
+ loss = train_util.conditional_loss(
+     args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler
+ )
loss = loss.mean([1, 2, 3])

loss_weights = batch["loss_weights"] # 各sampleごとのweight
6 changes: 3 additions & 3 deletions train_controlnet.py
@@ -464,8 +464,8 @@ def remove_model(old_ckpt_name):
)

# Sample a random timestep for each image
- timesteps, huber_c = train_util.get_timesteps_and_huber_c(
-     args, 0, noise_scheduler.config.num_train_timesteps, noise_scheduler, b_size, latents.device
+ timesteps = train_util.get_timesteps(
+     0, noise_scheduler.config.num_train_timesteps, b_size, latents.device
)

# Add noise to the latents according to the noise magnitude at each timestep
@@ -499,7 +499,7 @@ def remove_model(old_ckpt_name):
target = noise

loss = train_util.conditional_loss(
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler
)
loss = loss.mean([1, 2, 3])

4 changes: 2 additions & 2 deletions train_db.py
@@ -370,7 +370,7 @@ def train(args):

# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
- noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
+ noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(
args, noise_scheduler, latents
)

@@ -385,7 +385,7 @@ def train(args):
target = noise

loss = train_util.conditional_loss(
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler
)
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch)
9 changes: 5 additions & 4 deletions train_network.py
@@ -192,7 +192,7 @@ def get_noise_pred_and_target(
):
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
- noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
+ noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)

# ensure the hidden state will require grad
if args.gradient_checkpointing:
@@ -244,7 +244,7 @@ def get_noise_pred_and_target(
network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step
target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype)

- return noise_pred, target, timesteps, huber_c, None
+ return noise_pred, target, timesteps, None

def post_process_loss(self, loss, args, timesteps, noise_scheduler):
if args.min_snr_gamma:
@@ -806,6 +806,7 @@ def load_model_hook(models, input_dir):
"ss_ip_noise_gamma_random_strength": args.ip_noise_gamma_random_strength,
"ss_loss_type": args.loss_type,
"ss_huber_schedule": args.huber_schedule,
"ss_huber_scale": args.huber_scale,
"ss_huber_c": args.huber_c,
"ss_fp8_base": bool(args.fp8_base),
"ss_fp8_base_unet": bool(args.fp8_base_unet),
@@ -1193,7 +1194,7 @@ def remove_model(old_ckpt_name):
text_encoder_conds[i] = encoded_text_encoder_conds[i]

# sample noise, call unet, get target
- noise_pred, target, timesteps, huber_c, weighting = self.get_noise_pred_and_target(
+ noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target(
args,
accelerator,
noise_scheduler,
@@ -1207,7 +1208,7 @@ def remove_model(old_ckpt_name):
)

loss = train_util.conditional_loss(
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler
)
if weighting is not None:
loss = loss * weighting
4 changes: 2 additions & 2 deletions train_textual_inversion.py
@@ -585,7 +585,7 @@ def remove_model(old_ckpt_name):

# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
- noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
+ noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(
args, noise_scheduler, latents
)

@@ -602,7 +602,7 @@ def remove_model(old_ckpt_name):
target = noise

loss = train_util.conditional_loss(
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler
)
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch)
