Commit

support gradient accumulation
maxin-cn committed Sep 9, 2024
1 parent c598ddd commit 75e5bd6
Showing 2 changed files with 11 additions and 8 deletions.
train.py: 10 changes (6 additions & 4 deletions)
@@ -219,18 +219,20 @@ def main(args):

t = torch.randint(0, diffusion.num_timesteps, (x.shape[0],), device=device)
loss_dict = diffusion.training_losses(model, x, t, model_kwargs)
- loss = loss_dict["loss"].mean()
+ loss = loss_dict["loss"].mean() / args.gradient_accumulation_steps
loss.backward()

if train_steps < args.start_clip_iter: # if train_steps >= start_clip_iter, will clip gradient
    gradient_norm = clip_grad_norm_(model.module.parameters(), args.clip_max_norm, clip_grad=False)
else:
    gradient_norm = clip_grad_norm_(model.module.parameters(), args.clip_max_norm, clip_grad=True)

- opt.step()

lr_scheduler.step()
- opt.zero_grad()
- update_ema(ema, model.module)
+ if train_steps % args.gradient_accumulation_steps == 0 and train_steps > 0:
+     opt.step()
+     opt.zero_grad()
+     update_ema(ema, model.module)

# Log loss values:
running_loss += loss.item()
train_with_img.py: 9 changes (5 additions & 4 deletions)
@@ -239,18 +239,19 @@ def main(args):

t = torch.randint(0, diffusion.num_timesteps, (x.shape[0],), device=device)
loss_dict = diffusion.training_losses(model, x, t, model_kwargs)
- loss = loss_dict["loss"].mean()
+ loss = loss_dict["loss"].mean() / args.gradient_accumulation_steps
loss.backward()

if train_steps < args.start_clip_iter: # if train_steps >= start_clip_iter, will clip gradient
    gradient_norm = clip_grad_norm_(model.module.parameters(), args.clip_max_norm, clip_grad=False)
else:
    gradient_norm = clip_grad_norm_(model.module.parameters(), args.clip_max_norm, clip_grad=True)

- opt.step()
lr_scheduler.step()
- opt.zero_grad()
- update_ema(ema, model.module)
+ if train_steps % args.gradient_accumulation_steps == 0 and train_steps > 0:
+     opt.step()
+     opt.zero_grad()
+     update_ema(ema, model.module)

# Log loss values:
running_loss += loss.item()
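Both files apply the same pattern: the per-micro-batch loss is divided by args.gradient_accumulation_steps, backward() is called on every iteration so gradients accumulate in the parameters' .grad buffers, and the optimizer step, gradient reset, and EMA update only run once every args.gradient_accumulation_steps iterations. Below is a minimal, self-contained PyTorch sketch of that pattern for context; the model, optimizer, data, and step count here are placeholder stand-ins, not code from this repository.

import torch
import torch.nn as nn

# Hypothetical stand-ins for the training setup (placeholders, not from the repo).
model = nn.Linear(16, 1)
opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
loss_fn = nn.MSELoss()
gradient_accumulation_steps = 4  # micro-batches per optimizer update

opt.zero_grad()
for step in range(1, 101):
    x = torch.randn(8, 16)  # micro-batch of inputs
    y = torch.randn(8, 1)   # micro-batch of targets

    # Scale the loss so the accumulated gradient averages over the micro-batches.
    loss = loss_fn(model(x), y) / gradient_accumulation_steps
    loss.backward()  # gradients keep accumulating in .grad across iterations

    # Apply the optimizer only once per accumulation window, then reset gradients.
    if step % gradient_accumulation_steps == 0:
        opt.step()
        opt.zero_grad()

With this scheme the effective batch size is the micro-batch size times gradient_accumulation_steps, at the cost of proportionally fewer optimizer updates per pass over the data.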
