From 475a6043127a2a03bdeb31338112c5f63657719b Mon Sep 17 00:00:00 2001
From: Tanya
Date: Mon, 16 May 2022 22:18:19 +0530
Subject: [PATCH 1/2] Add --grad_accum_batches argument

---
 src/args.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/src/args.py b/src/args.py
index 538c65a..00f12ba 100644
--- a/src/args.py
+++ b/src/args.py
@@ -5,7 +5,9 @@ def get_args_parser():
     parser = argparse.ArgumentParser('Set transformer detector', add_help=False)
     parser.add_argument('--lr', default=1e-4, type=float)
     parser.add_argument('--lr_backbone', default=1e-5, type=float)
-    parser.add_argument('--batch_size', default=2, type=int)
+    parser.add_argument('--batch_size', default=2, type=int)
+    parser.add_argument('--grad_accum_batches', default=1, type=int,
+                        help="Number of batches to accumulate gradients over before each optimizer step")
     parser.add_argument('--weight_decay', default=1e-4, type=float)
     parser.add_argument('--epochs', default=300, type=int)
     parser.add_argument('--lr_drop', default=200, type=int)
@@ -115,4 +117,4 @@ def get_args_parser():
     parser.add_argument('--eval', action='store_true')
     parser.add_argument('--dataset', default='train', type=str,
                         choices=('train', 'val'))
-    return parser
\ No newline at end of file
+    return parser

From 31b02dfc20ad90aacfbb484b6231d548ad69db3b Mon Sep 17 00:00:00 2001
From: Tanya
Date: Mon, 16 May 2022 22:18:54 +0530
Subject: [PATCH 2/2] Implement gradient accumulation in train_one_epoch

---
 src/engine.py | 15 ++++++++++-----
 1 file changed, 10 insertions(+), 5 deletions(-)

diff --git a/src/engine.py b/src/engine.py
index 7cf71d4..7b0137e 100644
--- a/src/engine.py
+++ b/src/engine.py
@@ -25,6 +25,7 @@ def train_one_epoch(model, criterion, postprocessors, data_loader, optimizer, de
 
     counter = 0
+    batch_idx = 0
 
     torch.cuda.empty_cache()
     for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
         samples = samples.to(device)
@@ -59,14 +60,18 @@
             print(loss_dict_reduced)
             sys.exit(1)
 
-        optimizer.zero_grad()
+        losses /= args.grad_accum_batches
         losses.backward()
-        if max_norm > 0:
-            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
-        optimizer.step()
+
+        if (batch_idx + 1) % args.grad_accum_batches == 0:
+            if max_norm > 0:
+                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
+            optimizer.step()
+            optimizer.zero_grad()
         metric_logger.update(loss=loss_value, **loss_dict_reduced_scaled,
                              **loss_dict_reduced_unscaled)
         metric_logger.update(lr=optimizer.param_groups[0]["lr"])
+        batch_idx += 1
     # gather the stats from all processes
     metric_logger.synchronize_between_processes()
     print("Averaged stats:", metric_logger)
@@ -163,4 +168,4 @@ def evaluate(model, criterion, postprocessors, data_loader, base_ds, device, out
 
     # accumulate predictions from all images
     stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()}
-    return stats
\ No newline at end of file
+    return stats
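
For reference, here is the accumulation pattern that PATCH 2/2 applies to
train_one_epoch, shown standalone. This is a minimal sketch on a toy model:
the model, optimizer, data, and the names accum_steps and max_norm below are
illustrative stand-ins for --grad_accum_batches and the existing clipping
threshold, not code from this series.

import torch
from torch import nn

# Toy stand-ins for the real model, optimizer, criterion, and data loader.
model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
criterion = nn.MSELoss()
data = [(torch.randn(2, 10), torch.randn(2, 1)) for _ in range(8)]

accum_steps = 4   # plays the role of --grad_accum_batches
max_norm = 0.1    # plays the role of the clipping threshold

optimizer.zero_grad()
for batch_idx, (x, y) in enumerate(data):
    loss = criterion(model(x), y)
    # Dividing by accum_steps makes the accumulated gradient equal to the
    # gradient of the mean loss over accum_steps micro-batches, so the
    # effective batch size becomes batch_size * accum_steps.
    (loss / accum_steps).backward()

    # Clip, step, and reset only once every accum_steps micro-batches.
    if (batch_idx + 1) % accum_steps == 0:
        if max_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        optimizer.step()
        optimizer.zero_grad()

One caveat, which applies to the patch above as well: if the number of batches
in an epoch is not divisible by grad_accum_batches, the final partial
accumulation never reaches optimizer.step(), so those gradients are dropped or
carried into the next epoch unless they are flushed after the loop.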