Implement pseudo Huber loss for Flux and SD3 #1808

Merged · 3 commits · Dec 1, 2024
Changes from 1 commit
fine_tune.py (6 changes: 3 additions & 3 deletions)
@@ -380,7 +380,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
-noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
+noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(
args, noise_scheduler, latents
)

@@ -397,7 +397,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss:
# do not mean over batch dimension for snr weight or scale v-pred loss
loss = train_util.conditional_loss(
-noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
+args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler
)
loss = loss.mean([1, 2, 3])

@@ -411,7 +411,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
loss = loss.mean() # mean over batch dimension
else:
loss = train_util.conditional_loss(
-noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c
+args, noise_pred.float(), target.float(), timesteps, "mean", noise_scheduler
)

accelerator.backward(loss)
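The same mechanical change repeats in every trainer below: conditional_loss now receives args, the sampled timesteps, and the noise_scheduler, and derives the per-timestep Huber threshold internally (see the get_huber_threshold addition in library/train_util.py), so call sites no longer obtain a huber_c from timestep sampling or thread it through.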
flux_train.py (2 changes: 1 addition & 1 deletion)
@@ -667,7 +667,7 @@ def grad_hook(parameter: torch.Tensor):

# calculate loss
loss = train_util.conditional_loss(
-model_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=None
+args, model_pred.float(), target.float(), timesteps, "none", noise_scheduler
)
if weighting is not None:
loss = loss * weighting
flux_train_network.py (2 changes: 1 addition & 1 deletion)
@@ -468,7 +468,7 @@ def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t
)
target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)

-return model_pred, target, timesteps, None, weighting
+return model_pred, target, timesteps, weighting

def post_process_loss(self, loss, args, timesteps, noise_scheduler):
return loss
library/train_util.py (74 changes: 42 additions & 32 deletions)
@@ -3905,7 +3905,14 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
"--huber_c",
type=float,
default=0.1,
help="The huber loss parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 0.1 / Huber損失のパラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは0.1",
help="The Huber loss decay parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 0.1 / Huber損失のパラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは0.1",
)

parser.add_argument(
"--huber_scale",
type=float,
default=1.0,
help="The Huber loss scale parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 1.0 / Huber損失のパラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは0.1",
)

parser.add_argument(
@@ -5821,29 +5828,10 @@ def save_sd_model_on_train_end_common(
        huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True)


-def get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, device):
-    timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu")
-
-    if args.loss_type == "huber" or args.loss_type == "smooth_l1":
-        if args.huber_schedule == "exponential":
-            alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps
-            huber_c = torch.exp(-alpha * timesteps)
-        elif args.huber_schedule == "snr":
-            alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps)
-            sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5
-            huber_c = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c
-        elif args.huber_schedule == "constant":
-            huber_c = torch.full((b_size,), args.huber_c)
-        else:
-            raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!")
-        huber_c = huber_c.to(device)
-    elif args.loss_type == "l2":
-        huber_c = None  # may be anything, as it's not used
-    else:
-        raise NotImplementedError(f"Unknown loss type {args.loss_type}")
-
-    timesteps = timesteps.long().to(device)
-    return timesteps, huber_c
+def get_timesteps(min_timestep, max_timestep, b_size, device):
+    timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=device)
+    timesteps = timesteps.long()
+    return timesteps


@@ -5865,7 +5853,7 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
    min_timestep = 0 if args.min_timestep is None else args.min_timestep
    max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep

-    timesteps, huber_c = get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, latents.device)
+    timesteps = get_timesteps(min_timestep, max_timestep, b_size, latents.device)

    # Add noise to the latents according to the noise magnitude at each timestep
    # (this is the forward diffusion process)
@@ -5878,32 +5866,54 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
    else:
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

-    return noise, noisy_latents, timesteps, huber_c
+    return noise, noisy_latents, timesteps


+def get_huber_threshold(args, timesteps: torch.Tensor, noise_scheduler) -> torch.Tensor:
+    b_size = timesteps.shape[0]
+    if args.huber_schedule == "exponential":
+        alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps
+        result = torch.exp(-alpha * timesteps) * args.huber_scale
+    elif args.huber_schedule == "snr":
+        if not hasattr(noise_scheduler, "alphas_cumprod"):
+            raise NotImplementedError("Huber schedule 'snr' is not supported with the current model.")
+        alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps.cpu())
+        sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5
+        result = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c
+        result = result.to(timesteps.device)
+    elif args.huber_schedule == "constant":
+        result = torch.full((b_size,), args.huber_c * args.huber_scale, device=timesteps.device)
+    else:
+        raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!")
+
+    return result


def conditional_loss(
-    model_pred: torch.Tensor, target: torch.Tensor, reduction: str, loss_type: str, huber_c: Optional[torch.Tensor]
+    args, model_pred: torch.Tensor, target: torch.Tensor, timesteps: torch.Tensor, reduction: str, noise_scheduler
):
-    if loss_type == "l2":
+    if args.loss_type == "l2":
        loss = torch.nn.functional.mse_loss(model_pred, target, reduction=reduction)
-    elif loss_type == "l1":
+    elif args.loss_type == "l1":
        loss = torch.nn.functional.l1_loss(model_pred, target, reduction=reduction)
-    elif loss_type == "huber":
+    elif args.loss_type == "huber":
+        huber_c = get_huber_threshold(args, timesteps, noise_scheduler)
+        huber_c = huber_c.view(-1, 1, 1, 1)
        loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
        if reduction == "mean":
            loss = torch.mean(loss)
        elif reduction == "sum":
            loss = torch.sum(loss)
-    elif loss_type == "smooth_l1":
+    elif args.loss_type == "smooth_l1":
+        huber_c = get_huber_threshold(args, timesteps, noise_scheduler)
+        huber_c = huber_c.view(-1, 1, 1, 1)
        loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
        if reduction == "mean":
            loss = torch.mean(loss)
        elif reduction == "sum":
            loss = torch.sum(loss)
    else:
-        raise NotImplementedError(f"Unsupported Loss Type {loss_type}")
+        raise NotImplementedError(f"Unsupported Loss Type: {args.loss_type}")
    return loss


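For reference, the pseudo-Huber loss computed above is 2*c*(sqrt((model_pred - target)^2 + c^2) - c): for residuals much smaller than c it approximates the squared error, and for much larger ones it approximates 2*c*|residual|, so c is the per-sample transition point between L2-like and L1-like behavior. get_huber_threshold schedules that c over timesteps: "exponential" decays it geometrically from huber_scale at t=0 to huber_c * huber_scale at t=T, "snr" ties it to the noise level via sigma_t = sqrt((1 - alphas_cumprod_t) / alphas_cumprod_t) (note this branch, as written, does not apply huber_scale), and "constant" uses huber_c * huber_scale everywhere. The sketch below restates both pieces as a standalone reference; the function names and keyword arguments are illustrative, since the real code reads its parameters from args:

import math

import torch


def huber_threshold(timesteps, num_train_timesteps, schedule="exponential",
                    huber_c=0.1, huber_scale=1.0, alphas_cumprod=None):
    # Per-sample transition point c between the L2-like and L1-like regimes.
    if schedule == "exponential":
        # equivalent to huber_scale * huber_c ** (t / T): decays from huber_scale to huber_c * huber_scale
        alpha = -math.log(huber_c) / num_train_timesteps
        return torch.exp(-alpha * timesteps) * huber_scale
    if schedule == "snr":
        # c approaches 1.0 at low-noise timesteps and huber_c at high-noise ones
        ac = alphas_cumprod[timesteps.cpu()]
        sigmas = ((1.0 - ac) / ac) ** 0.5
        return ((1 - huber_c) / (1 + sigmas) ** 2 + huber_c).to(timesteps.device)
    if schedule == "constant":
        return torch.full((timesteps.shape[0],), huber_c * huber_scale, device=timesteps.device)
    raise NotImplementedError(f"Unknown Huber loss schedule {schedule}!")


def pseudo_huber_loss(model_pred, target, c):
    # 2c(sqrt(r^2 + c^2) - c): ~ r^2 when |r| << c, ~ 2c|r| when |r| >> c
    c = c.view(-1, 1, 1, 1)  # broadcast the per-sample threshold over (C, H, W)
    return 2 * c * (torch.sqrt((model_pred - target) ** 2 + c**2) - c)


# usage sketch: loss = pseudo_huber_loss(pred, target, huber_threshold(t, 1000)).mean()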
sd3_train.py (2 changes: 1 addition & 1 deletion)
@@ -845,7 +845,7 @@ def grad_hook(parameter: torch.Tensor):
# )
# calculate loss
loss = train_util.conditional_loss(
-model_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=None
+args, model_pred.float(), target.float(), timesteps, "none", noise_scheduler
)
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch)
sd3_train_network.py (2 changes: 1 addition & 1 deletion)
@@ -378,7 +378,7 @@ def get_noise_pred_and_target(

target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)

-return model_pred, target, timesteps, None, weighting
+return model_pred, target, timesteps, weighting

def post_process_loss(self, loss, args, timesteps, noise_scheduler):
return loss
sdxl_train.py (6 changes: 3 additions & 3 deletions)
@@ -695,7 +695,7 @@ def optimizer_hook(parameter: torch.Tensor):

# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
-noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
+noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(
args, noise_scheduler, latents
)

@@ -720,7 +720,7 @@ def optimizer_hook(parameter: torch.Tensor):
):
# do not mean over batch dimension for snr weight or scale v-pred loss
loss = train_util.conditional_loss(
-noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
+args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler
)
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch)
@@ -738,7 +738,7 @@ def optimizer_hook(parameter: torch.Tensor):
loss = loss.mean() # mean over batch dimension
else:
loss = train_util.conditional_loss(
-noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c
+args, noise_pred.float(), target.float(), timesteps, "mean", noise_scheduler
)

accelerator.backward(loss)
sdxl_train_control_net.py (4 changes: 2 additions & 2 deletions)
@@ -512,7 +512,7 @@ def remove_model(old_ckpt_name):

# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
-noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
+noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(
args, noise_scheduler, latents
)

@@ -534,7 +534,7 @@ def remove_model(old_ckpt_name):
target = noise

loss = train_util.conditional_loss(
-noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
+args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler
)
loss = loss.mean([1, 2, 3])

sdxl_train_control_net_lllite.py (4 changes: 2 additions & 2 deletions)
@@ -463,7 +463,7 @@ def remove_model(old_ckpt_name):

# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
-noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
+noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(
args, noise_scheduler, latents
)

@@ -485,7 +485,7 @@ def remove_model(old_ckpt_name):
target = noise

loss = train_util.conditional_loss(
-noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
+args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler
)
loss = loss.mean([1, 2, 3])

sdxl_train_control_net_lllite_old.py (6 changes: 4 additions & 2 deletions)
@@ -406,7 +406,7 @@ def remove_model(old_ckpt_name):

# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
-noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
+noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)

noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype

@@ -426,7 +426,9 @@ def remove_model(old_ckpt_name):
else:
target = noise

-loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
+loss = train_util.conditional_loss(
+    args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler
+)
loss = loss.mean([1, 2, 3])

loss_weights = batch["loss_weights"] # 各sampleごとのweight
train_controlnet.py (6 changes: 3 additions & 3 deletions)
@@ -464,8 +464,8 @@ def remove_model(old_ckpt_name):
)

# Sample a random timestep for each image
-timesteps, huber_c = train_util.get_timesteps_and_huber_c(
-    args, 0, noise_scheduler.config.num_train_timesteps, noise_scheduler, b_size, latents.device
+timesteps = train_util.get_timesteps(
+    0, noise_scheduler.config.num_train_timesteps, b_size, latents.device
)

# Add noise to the latents according to the noise magnitude at each timestep
@@ -499,7 +499,7 @@ def remove_model(old_ckpt_name):
target = noise

loss = train_util.conditional_loss(
-noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
+args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler
)
loss = loss.mean([1, 2, 3])

train_db.py (4 changes: 2 additions & 2 deletions)
@@ -370,7 +370,7 @@ def train(args):

# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
-noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
+noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(
args, noise_scheduler, latents
)

@@ -385,7 +385,7 @@ def train(args):
target = noise

loss = train_util.conditional_loss(
-noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
+args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler
)
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch)
train_network.py (9 changes: 5 additions & 4 deletions)
@@ -192,7 +192,7 @@ def get_noise_pred_and_target(
):
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
-noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
+noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)

# ensure the hidden state will require grad
if args.gradient_checkpointing:
@@ -244,7 +244,7 @@ def get_noise_pred_and_target(
network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step
target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype)

-return noise_pred, target, timesteps, huber_c, None
+return noise_pred, target, timesteps, None

def post_process_loss(self, loss, args, timesteps, noise_scheduler):
if args.min_snr_gamma:
@@ -806,6 +806,7 @@ def load_model_hook(models, input_dir):
"ss_ip_noise_gamma_random_strength": args.ip_noise_gamma_random_strength,
"ss_loss_type": args.loss_type,
"ss_huber_schedule": args.huber_schedule,
"ss_huber_scale": args.huber_scale,
"ss_huber_c": args.huber_c,
"ss_fp8_base": bool(args.fp8_base),
"ss_fp8_base_unet": bool(args.fp8_base_unet),
@@ -1193,7 +1194,7 @@ def remove_model(old_ckpt_name):
text_encoder_conds[i] = encoded_text_encoder_conds[i]

# sample noise, call unet, get target
-noise_pred, target, timesteps, huber_c, weighting = self.get_noise_pred_and_target(
+noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target(
args,
accelerator,
noise_scheduler,
@@ -1207,7 +1208,7 @@ def remove_model(old_ckpt_name):
)

loss = train_util.conditional_loss(
-noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
+args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler
)
if weighting is not None:
loss = loss * weighting
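Besides the call-site update, train_network.py records the new option in the saved metadata (ss_huber_scale, alongside the existing ss_huber_schedule and ss_huber_c), and get_noise_pred_and_target now returns four values instead of five because huber_c is no longer part of the interface.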
train_textual_inversion.py (4 changes: 2 additions & 2 deletions)
@@ -585,7 +585,7 @@ def remove_model(old_ckpt_name):

# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
-noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
+noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(
args, noise_scheduler, latents
)

@@ -602,7 +602,7 @@ def remove_model(old_ckpt_name):
target = noise

loss = train_util.conditional_loss(
-noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
+args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler
)
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch)
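With this in place, pseudo-Huber loss can be selected on the Flux and SD3 trainers with the same flags as the SD/SDXL scripts. An illustrative invocation (loss-related flags only; every other required argument omitted):

accelerate launch flux_train_network.py \
    --loss_type huber \
    --huber_schedule exponential \
    --huber_c 0.1 \
    --huber_scale 1.0 \
    ...

Note that --huber_schedule snr requires the scheduler to expose alphas_cumprod, so get_huber_threshold raises NotImplementedError under the flow-matching setups used by Flux and SD3; the exponential and constant schedules work there.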