From 75e5bd6442ad773f8fb27504ef8ed711e383217f Mon Sep 17 00:00:00 2001
From: maxin-cn
Date: Mon, 9 Sep 2024 16:21:09 +1000
Subject: [PATCH] support gradient accumulation

---
 train.py          | 10 ++++++----
 train_with_img.py |  9 +++++----
 2 files changed, 11 insertions(+), 8 deletions(-)

diff --git a/train.py b/train.py
index 11193d3..5bb9370 100644
--- a/train.py
+++ b/train.py
@@ -219,7 +219,7 @@ def main(args):
 
             t = torch.randint(0, diffusion.num_timesteps, (x.shape[0],), device=device)
             loss_dict = diffusion.training_losses(model, x, t, model_kwargs)
-            loss = loss_dict["loss"].mean()
+            loss = loss_dict["loss"].mean() / args.gradient_accumulation_steps
             loss.backward()
 
             if train_steps < args.start_clip_iter: # if train_steps >= start_clip_iter, will clip gradient
@@ -227,10 +227,12 @@ def main(args):
             else:
                 gradient_norm = clip_grad_norm_(model.module.parameters(), args.clip_max_norm, clip_grad=True)
 
-            opt.step()
+            lr_scheduler.step()
-            opt.zero_grad()
-            update_ema(ema, model.module)
+            if train_steps % args.gradient_accumulation_steps == 0 and train_steps > 0:
+                opt.step()
+                opt.zero_grad()
+                update_ema(ema, model.module)
 
             # Log loss values:
             running_loss += loss.item()
 
diff --git a/train_with_img.py b/train_with_img.py
index d3f8685..ea164e1 100644
--- a/train_with_img.py
+++ b/train_with_img.py
@@ -239,7 +239,7 @@ def main(args):
 
             t = torch.randint(0, diffusion.num_timesteps, (x.shape[0],), device=device)
             loss_dict = diffusion.training_losses(model, x, t, model_kwargs)
-            loss = loss_dict["loss"].mean()
+            loss = loss_dict["loss"].mean() / args.gradient_accumulation_steps
             loss.backward()
 
             if train_steps < args.start_clip_iter: # if train_steps >= start_clip_iter, will clip gradient
@@ -247,10 +247,11 @@ def main(args):
             else:
                 gradient_norm = clip_grad_norm_(model.module.parameters(), args.clip_max_norm, clip_grad=True)
 
-            opt.step()
             lr_scheduler.step()
-            opt.zero_grad()
-            update_ema(ema, model.module)
+            if train_steps % args.gradient_accumulation_steps == 0 and train_steps > 0:
+                opt.step()
+                opt.zero_grad()
+                update_ema(ema, model.module)
 
             # Log loss values:
             running_loss += loss.item()