Skip to content

Commit

Permalink
Merge pull request #934 from feffy380/fix-minsnr-vpred-zsnr
Browse files Browse the repository at this point in the history
Fix min-snr-gamma for v-prediction and ZSNR.
  • Loading branch information
kohya-ss authored Nov 25, 2023
2 parents 0fb9ecf + 6b3148f commit fc8649d
Show file tree
Hide file tree
Showing 9 changed files with 14 additions and 11 deletions.
2 changes: 1 addition & 1 deletion fine_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
loss = loss.mean([1, 2, 3])

if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.debiased_estimation_loss:
Expand Down
9 changes: 6 additions & 3 deletions library/custom_train_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,13 @@ def enforce_zero_terminal_snr(betas):
noise_scheduler.alphas_cumprod = alphas_cumprod


def apply_snr_weight(loss, timesteps, noise_scheduler, gamma):
def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False):
snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float().to(loss.device) # from paper
min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma))
if v_prediction:
snr_weight = torch.div(min_snr_gamma, snr+1).float().to(loss.device)
else:
snr_weight = torch.div(min_snr_gamma, snr).float().to(loss.device)
loss = loss * snr_weight
return loss

Expand Down
2 changes: 1 addition & 1 deletion sdxl_train_control_net_lllite.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,7 @@ def remove_model(old_ckpt_name):
loss = loss * loss_weights

if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.v_pred_like_loss:
Expand Down
2 changes: 1 addition & 1 deletion sdxl_train_control_net_lllite_old.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,7 @@ def remove_model(old_ckpt_name):
loss = loss * loss_weights

if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.v_pred_like_loss:
Expand Down
2 changes: 1 addition & 1 deletion train_controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ def remove_model(old_ckpt_name):
loss = loss * loss_weights

if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)

loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし

Expand Down
2 changes: 1 addition & 1 deletion train_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ def train(args):
loss = loss * loss_weights

if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.debiased_estimation_loss:
Expand Down
2 changes: 1 addition & 1 deletion train_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -812,7 +812,7 @@ def remove_model(old_ckpt_name):
loss = loss * loss_weights

if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.v_pred_like_loss:
Expand Down
2 changes: 1 addition & 1 deletion train_textual_inversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,7 +578,7 @@ def remove_model(old_ckpt_name):
loss = loss * loss_weights

if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.v_pred_like_loss:
Expand Down
2 changes: 1 addition & 1 deletion train_textual_inversion_XTI.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,7 @@ def remove_model(old_ckpt_name):

loss = loss * loss_weights
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.debiased_estimation_loss:
Expand Down

0 comments on commit fc8649d

Please sign in to comment.