Merge branch 'pr/1228' into master3
gesen2egee committed Apr 3, 2024
2 parents 3cfeebd + 47fb1a6 commit aa80cc8
Showing 2 changed files with 34 additions and 17 deletions.
49 changes: 33 additions & 16 deletions library/train_util.py
@@ -3402,14 +3402,21 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
         "--loss_type",
         type=str,
         default="l2",
-        choices=["l2", "huber", "huber_scheduled"],
+        choices=["l2", "huber", "smooth_l1"],
         help="The type of loss to use and whether it's scheduled based on the timestep"
     )
+    parser.add_argument(
+        "--huber_schedule",
+        type=str,
+        default="exponential",
+        choices=["constant", "exponential", "snr"],
+        help="The schedule for the Huber loss parameter across timesteps. Only used when loss_type is huber or smooth_l1.",
+    )
     parser.add_argument(
         "--huber_c",
         type=float,
         default=0.1,
-        help="The huber loss parameter. Only used if one of the huber loss modes is selected with loss_type.",
+        help="The Huber loss parameter. Only used if one of the Huber loss modes (huber or smooth_l1) is selected with loss_type.",
     )
     parser.add_argument(
         "--lowram",
@@ -5141,26 +5148,30 @@ def save_sd_model_on_train_end_common(
     if args.huggingface_repo_id is not None:
         huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True)

-def get_timesteps_and_huber_c(args, min_timestep, max_timestep, b_size, device):
+def get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, device):

     # TODO: if a Huber loss is selected, a single random timestep is used for the
     # whole batch. In the future there may be a smarter way.
-    if args.loss_type == 'huber_scheduled':
+    if args.loss_type == 'huber' or args.loss_type == 'smooth_l1':
         timesteps = torch.randint(
             min_timestep, max_timestep, (1,), device='cpu'
         )
         timestep = timesteps.item()

-        alpha = - math.log(args.huber_c) / max_timestep
-        huber_c = math.exp(-alpha * timestep)
-        timesteps = timesteps.repeat(b_size).to(device)
-    elif args.loss_type == 'huber':
-        # for fairness in comparison
-        timesteps = torch.randint(
-            min_timestep, max_timestep, (1,), device='cpu'
-        )
+        if args.huber_schedule == "exponential":
+            alpha = - math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps
+            huber_c = math.exp(-alpha * timestep)
+        elif args.huber_schedule == "snr":
+            alphas_cumprod = noise_scheduler.alphas_cumprod[timestep]
+            sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5
+            huber_c = (1 - args.huber_c) / (1 + sigmas)**2 + args.huber_c
+        elif args.huber_schedule == "constant":
+            huber_c = args.huber_c
+        else:
+            raise NotImplementedError(f'Unknown Huber loss schedule {args.huber_schedule}!')

         timesteps = timesteps.repeat(b_size).to(device)
-        huber_c = args.huber_c
     elif args.loss_type == 'l2':
         timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=device)
         huber_c = 1 # may be anything, as it's not used
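
To make the three schedules concrete, here is a self-contained sketch (not from the diff) that replaces noise_scheduler with a toy linear-beta DDPM alphas_cumprod and evaluates huber_c at a few timesteps. T, the beta range, and base_c=0.1 are illustrative assumptions:

```python
import math
import torch

T = 1000  # stand-in for noise_scheduler.config.num_train_timesteps
betas = torch.linspace(1e-4, 0.02, T)           # toy linear-beta schedule
alphas_cumprod = torch.cumprod(1.0 - betas, 0)  # stand-in for noise_scheduler.alphas_cumprod

def scheduled_huber_c(schedule: str, timestep: int, base_c: float = 0.1) -> float:
    if schedule == "exponential":
        alpha = -math.log(base_c) / T
        return math.exp(-alpha * timestep)  # 1.0 at t=0, decays to base_c at t=T
    if schedule == "snr":
        ac = alphas_cumprod[timestep].item()
        sigma = ((1.0 - ac) / ac) ** 0.5
        return (1 - base_c) / (1 + sigma) ** 2 + base_c  # ~1 at low noise, -> base_c at high noise
    if schedule == "constant":
        return base_c
    raise NotImplementedError(schedule)

for t in (0, T // 2, T - 1):
    print(t, {s: round(scheduled_huber_c(s, t), 4) for s in ("exponential", "snr", "constant")})
```

Both non-constant schedules keep huber_c near 1 at low-noise timesteps (where the pseudo-Huber loss below behaves like L2) and shrink it toward --huber_c at high-noise timesteps, where robustness to outliers matters more.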
Expand Down Expand Up @@ -5189,7 +5200,7 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
min_timestep = 0 if args.min_timestep is None else args.min_timestep
max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep

timesteps, huber_c = get_timesteps_and_huber_c(args, min_timestep, max_timestep, b_size, latents.device)
timesteps, huber_c = get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, latents.device)

# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
@@ -5209,8 +5220,14 @@ def conditional_loss(model_pred:torch.Tensor, target:torch.Tensor, reduction:str

     if loss_type == 'l2':
         loss = torch.nn.functional.mse_loss(model_pred, target, reduction=reduction)
-    elif loss_type == 'huber' or loss_type == 'huber_scheduled':
-        loss = huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
+    elif loss_type == 'huber':
+        loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
+        if reduction == "mean":
+            loss = torch.mean(loss)
+        elif reduction == "sum":
+            loss = torch.sum(loss)
+    elif loss_type == 'smooth_l1':
+        loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
         if reduction == "mean":
             loss = torch.mean(loss)
         elif reduction == "sum":
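The huber branch above is the pseudo-Huber loss 2*c*(sqrt(d**2 + c**2) - c) with d = model_pred - target, and the smooth_l1 branch is the same expression without the outer c factor. A quick standalone check (not from the diff) of its limiting behaviour:

```python
import torch

def pseudo_huber(d: torch.Tensor, c: float) -> torch.Tensor:
    # Same expression as the 'huber' branch above, with d = model_pred - target.
    return 2 * c * (torch.sqrt(d ** 2 + c ** 2) - c)

c = 0.1
d = torch.tensor([0.01, 0.05, 1.0, 10.0])
print(pseudo_huber(d, c))      # smooth for all residuals
print(d ** 2)                  # matches pseudo_huber when |d| << c (quadratic, L2-like regime)
print(2 * c * (d.abs() - c))   # matches pseudo_huber when |d| >> c (linear, L1-like regime)
```

This is why the scheduled huber_c acts as a smooth interpolation knob: a large c makes the loss quadratic over a wide range of residuals, while a small c makes it nearly linear (outlier-robust) outside a narrow well.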
2 changes: 1 addition & 1 deletion train_controlnet.py
@@ -421,7 +421,7 @@ def remove_model(old_ckpt_name):
     )

     # Sample a random timestep for each image
-    timesteps, huber_c = train_util.get_timesteps_and_huber_c(args, 0, noise_scheduler.config.num_train_timesteps, b_size, latents.device)
+    timesteps, huber_c = train_util.get_timesteps_and_huber_c(args, 0, noise_scheduler.config.num_train_timesteps, noise_scheduler, b_size, latents.device)

     # Add noise to the latents according to the noise magnitude at each timestep
     # (this is the forward diffusion process)
