From 0c2273e6768e36e6477503ab416731f761384a62 Mon Sep 17 00:00:00 2001
From: ekurtulus
Date: Sun, 1 Oct 2023 03:54:00 +0300
Subject: [PATCH] Implement Tied-Augment

---
 timm/data/loader.py     | 36 ++++++++++++++++++++----
 timm/data/mixup.py      | 14 ++++++++--
 timm/loss/__init__.py   |  1 +
 timm/models/__init__.py |  1 +
 timm/utils/model.py     |  3 ++
 train.py                | 61 +++++++++++++++++++++++++++++++----------
 6 files changed, 93 insertions(+), 23 deletions(-)

diff --git a/timm/data/loader.py b/timm/data/loader.py
index 7020deb7f7..8192a77256 100644
--- a/timm/data/loader.py
+++ b/timm/data/loader.py
@@ -14,6 +14,7 @@
 
 import torch
 import torch.utils.data
+from torch.utils.data import Dataset
 import numpy as np
 
 from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
@@ -87,7 +88,9 @@ def __init__(
             re_prob=0.,
             re_mode='const',
             re_count=1,
-            re_num_splits=0):
+            re_num_splits=0,
+            tied_weight=None,
+    ):
         mean = adapt_to_chs(mean, channels)
         std = adapt_to_chs(std, channels)
 
@@ -125,13 +128,16 @@ def __iter__(self):
             stream_context = suppress
 
         for next_input, next_target in self.loader:
+            def _preprocess(image):
+                image = image.to(device=self.device, non_blocking=True)
+                image = image.to(self.img_dtype).sub_(self.mean).div_(self.std)
+                if self.random_erasing is not None:
+                    image = self.random_erasing(image)
+                return image
 
             with stream_context():
-                next_input = next_input.to(device=self.device, non_blocking=True)
+                next_input = _preprocess(next_input)
                 next_target = next_target.to(device=self.device, non_blocking=True)
-                next_input = next_input.to(self.img_dtype).sub_(self.mean).div_(self.std)
-                if self.random_erasing is not None:
-                    next_input = self.random_erasing(next_input)
 
             if not first:
                 yield input, target
@@ -185,6 +191,19 @@ def _worker_init(worker_id, worker_seeding='all'):
         if worker_seeding == 'all':
             np.random.seed(worker_info.seed % (2 ** 32 - 1))
 
+class TiedDatasetWrapper(Dataset):
+    def __init__(self, dataset):
+        self.dataset = dataset
+        self.transform = self.dataset.transform
+        self.dataset.transform = None
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def __getitem__(self, idx):
+        input, target = self.dataset[idx]
+        input1, input2 = self.transform(input), self.transform(input)
+        return (input1, input2), target
 
 def create_loader(
         dataset,
@@ -205,6 +224,7 @@
         auto_augment=None,
         num_aug_repeats=0,
         num_aug_splits=0,
+        tied_weight=None,
         interpolation='bilinear',
         mean=IMAGENET_DEFAULT_MEAN,
         std=IMAGENET_DEFAULT_STD,
@@ -255,6 +275,9 @@
        # are correct before worker processes are launched
         dataset.set_loader_cfg(num_workers=num_workers)
 
+    if tied_weight is not None:
+        dataset = TiedDatasetWrapper(dataset)
+
     sampler = None
     if distributed and not isinstance(dataset, torch.utils.data.IterableDataset):
         if is_training:
@@ -305,7 +328,8 @@
             re_prob=prefetch_re_prob,
             re_mode=re_mode,
             re_count=re_count,
-            re_num_splits=re_num_splits
+            re_num_splits=re_num_splits,
+            tied_weight=tied_weight,
         )
 
     return loader
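Note (not part of the patch): the new TiedDatasetWrapper takes over the dataset's transform and applies it twice per sample, so every item becomes a pair of independently augmented views of the same image plus the original target. A minimal usage sketch, assuming the patch is applied; the torchvision FakeData dataset and the transform below are illustrative stand-ins, not part of the author's code:

# Usage sketch only -- FakeData and the Compose pipeline are placeholders.
from torchvision import datasets, transforms

from timm.data.loader import TiedDatasetWrapper

base = datasets.FakeData(size=8, transform=transforms.Compose([
    transforms.RandomResizedCrop(32),
    transforms.ToTensor(),
]))
tied = TiedDatasetWrapper(base)   # takes ownership of base.transform, sets base.transform = None

(view1, view2), target = tied[0]  # same underlying image, two different augmentations
print(view1.shape, view2.shape, target)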
diff --git a/timm/data/mixup.py b/timm/data/mixup.py
index 26dc239152..ada4a20a22 100644
--- a/timm/data/mixup.py
+++ b/timm/data/mixup.py
@@ -102,7 +102,7 @@ class Mixup:
         num_classes (int): number of classes for target
     """
     def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0, switch_prob=0.5,
-                 mode='batch', correct_lam=True, label_smoothing=0.1, num_classes=1000):
+                 mode='batch', correct_lam=True, label_smoothing=0.1, num_classes=1000, tied_augment=False):
         self.mixup_alpha = mixup_alpha
         self.cutmix_alpha = cutmix_alpha
         self.cutmix_minmax = cutmix_minmax
@@ -117,6 +117,7 @@ def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0
         self.mode = mode
         self.correct_lam = correct_lam  # correct lambda based on clipped area for cutmix
         self.mixup_enabled = True  # set to false to disable mixing (intended tp be set by train loop)
+        self.tied_augment = tied_augment
 
     def _params_per_elem(self, batch_size):
         lam = np.ones(batch_size, dtype=np.float32)
@@ -206,7 +207,7 @@ def _mix_batch(self, x):
             x.mul_(lam).add_(x_flipped)
         return lam
 
-    def __call__(self, x, target):
+    def handle_mixup(self, x, target):
         assert len(x) % 2 == 0, 'Batch size should be even when using this'
         if self.mode == 'elem':
             lam = self._mix_elem(x)
@@ -215,8 +216,15 @@ def __call__(self, x, target):
         else:
             lam = self._mix_batch(x)
         target = mixup_target(target, self.num_classes, lam, self.label_smoothing)
-        return x, target
+        return x, target, lam
 
+    def __call__(self, x, target):
+        if self.tied_augment:
+            input1, input2 = x
+            input1, mixup_target, lam = self.handle_mixup(input1, target.clone())
+            return (input1, input2), (mixup_target, one_hot(target, self.num_classes), lam)
+
+        return self.handle_mixup(x, target)
 
 class FastCollateMixup(Mixup):
     """ Fast Collate w/ Mixup/Cutmix that applies different params to each element or whole batch
diff --git a/timm/loss/__init__.py b/timm/loss/__init__.py
index ea7f15f2f7..cfdecf128a 100644
--- a/timm/loss/__init__.py
+++ b/timm/loss/__init__.py
@@ -2,3 +2,4 @@
 from .binary_cross_entropy import BinaryCrossEntropy
 from .cross_entropy import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
 from .jsd import JsdCrossEntropy
+from .tied_loss import TiedLoss
diff --git a/timm/models/__init__.py b/timm/models/__init__.py
index 0eb9561d54..63e62e2a2f 100644
--- a/timm/models/__init__.py
+++ b/timm/models/__init__.py
@@ -1,3 +1,4 @@
+from .tied_model import TiedModelWrapper
 from .beit import *
 from .byoanet import *
 from .byobnet import *
diff --git a/timm/utils/model.py b/timm/utils/model.py
index 894453a856..9065c1d02f 100644
--- a/timm/utils/model.py
+++ b/timm/utils/model.py
@@ -10,12 +10,15 @@
 from timm.layers import BatchNormAct2d, SyncBatchNormAct, FrozenBatchNormAct2d,\
     freeze_batch_norm_2d, unfreeze_batch_norm_2d
+from timm.models import TiedModelWrapper
 
 from .model_ema import ModelEma
 
 
 def unwrap_model(model):
     if isinstance(model, ModelEma):
         return unwrap_model(model.ema)
+    elif isinstance(model, TiedModelWrapper):
+        return model.model
     else:
         return model.module if hasattr(model, 'module') else model
 
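Note (not part of the patch): timm/models/tied_model.py is imported above, but the new file itself is not included in this diff. The sketch below is only a guess at a minimal TiedModelWrapper consistent with how the rest of the patch uses it: train.py builds it as TiedModelWrapper(model, single_forward=...), unwrap_model() expects the wrapped network at .model, and the training loop unpacks its output as (feat1, feat2, logit1, logit2). forward_features/forward_head/get_classifier are timm's usual model interface; everything else here is an assumption, not the author's implementation.

# Sketch only -- the real timm/models/tied_model.py is missing from this patch.
import torch
import torch.nn as nn


class TiedModelWrapper(nn.Module):
    def __init__(self, model, single_forward=False):
        super().__init__()
        self.model = model                      # unwrap_model() relies on this attribute name
        self.single_forward = single_forward

    def _features_and_logits(self, x):
        # pooled, pre-classifier features plus the classifier output
        feat = self.model.forward_head(self.model.forward_features(x), pre_logits=True)
        return feat, self.model.get_classifier()(feat)

    def forward(self, x):
        if not isinstance(x, (tuple, list)):    # assumption: plain forward pass as a fallback,
            return self.model(x)                # so validation with a normal loader still works
        x1, x2 = x
        if self.single_forward:
            # one batched pass over both views, then split the results
            feat, logits = self._features_and_logits(torch.cat([x1, x2], dim=0))
            feat1, feat2 = feat.chunk(2, dim=0)
            logit1, logit2 = logits.chunk(2, dim=0)
        else:
            feat1, logit1 = self._features_and_logits(x1)
            feat2, logit2 = self._features_and_logits(x2)
        return feat1, feat2, logit1, logit2

The single_forward path trades memory for speed by pushing both views through one batched pass, which matches the intent of the --tied-single-pass option added to train.py below.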
diff --git a/train.py b/train.py
index ec344a64cd..f0d3f31365 100755
--- a/train.py
+++ b/train.py
@@ -32,8 +32,8 @@
 from timm import utils
 from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
 from timm.layers import convert_splitbn_model, convert_sync_batchnorm, set_fast_norm
-from timm.loss import JsdCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntropy, LabelSmoothingCrossEntropy
-from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint, model_parameters
+from timm.loss import JsdCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntropy, LabelSmoothingCrossEntropy, TiedLoss
+from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint, model_parameters, TiedModelWrapper
 from timm.optim import create_optimizer_v2, optimizer_kwargs
 from timm.scheduler import create_scheduler_v2, scheduler_kwargs
 from timm.utils import ApexScaler, NativeScaler
@@ -288,6 +288,14 @@
                    help='Drop path rate (default: None)')
 group.add_argument('--drop-block', type=float, default=None, metavar='PCT',
                    help='Drop block rate (default: None)')
+group.add_argument('--tied-weight', type=float, default=None,
+                   help='Weight of the tied (feature similarity) loss term. If set, the script uses Tied-Augment.')
+group.add_argument('--tied-double-supervision', action='store_true',
+                   help='If set, both branches of Tied-Augment are fed to the classification loss. '
+                        'Setting this is generally advised.')
+group.add_argument('--tied-single-pass', action='store_true',
+                   help='If set, the two branches of Tied-Augment are processed in a single forward pass. '
+                        'This can speed up training on high-end GPUs with enough memory.')
 
 # Batch norm parameters (only works with gen_efficientnet based models currently)
 group = parser.add_argument_group('Batch norm parameters', 'Only works with gen_efficientnet based models currently.')
@@ -381,7 +389,7 @@
         torch.backends.cuda.matmul.allow_tf32 = True
         torch.backends.cudnn.benchmark = True
 
-    args.prefetcher = not args.no_prefetcher
+    args.prefetcher = not args.no_prefetcher and args.tied_weight is None
     args.grad_accum_steps = max(1, args.grad_accum_steps)
     device = utils.init_distributed_device(args)
     if args.distributed:
@@ -486,6 +494,9 @@
             'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
             'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
 
+    if args.tied_weight is not None:
+        model = TiedModelWrapper(model, single_forward=args.tied_single_pass)
+
     if args.torchscript:
         assert not args.torchcompile
         assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
@@ -613,9 +624,10 @@
             switch_prob=args.mixup_switch_prob,
             mode=args.mixup_mode,
             label_smoothing=args.smoothing,
-            num_classes=args.num_classes
+            num_classes=args.num_classes,
+            tied_augment=(args.tied_weight is not None)
         )
-        if args.prefetcher:
+        if args.prefetcher and args.tied_weight is None:
             assert not num_aug_splits  # collate conflict (need to support deinterleaving in collate mixup)
             collate_fn = FastCollateMixup(**mixup_args)
         else:
@@ -649,6 +661,7 @@
         num_aug_repeats=args.aug_repeats,
         num_aug_splits=num_aug_splits,
         interpolation=train_interpolation,
+        tied_weight=args.tied_weight,
         mean=data_config['mean'],
         std=data_config['std'],
         num_workers=args.workers,
@@ -699,6 +712,11 @@
            train_loss_fn = nn.CrossEntropyLoss()
     train_loss_fn = train_loss_fn.to(device=device)
     validate_loss_fn = nn.CrossEntropyLoss().to(device=device)
+
+    if args.tied_weight is not None:
+        train_loss_fn = TiedLoss(train_loss_fn,
+                                 args.tied_weight,
+                                 both_supervised=args.tied_double_supervision)
 
     # setup checkpoint saver and eval metric tracking
     eval_metric = args.eval_metric
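Note (not part of the patch): timm/loss/tied_loss.py is likewise referenced but missing from this diff. Below is a minimal sketch consistent with how TiedLoss is constructed above and called in _forward further down, i.e. loss_fn((feat1, feat2, logit1, logit2), (target1, target2)). The L2 feature-matching term weighted by tied_weight follows the general Tied-Augment idea; the exact formulation in the missing file may differ.

# Sketch only -- the real timm/loss/tied_loss.py is missing from this patch.
import torch.nn as nn
import torch.nn.functional as F


class TiedLoss(nn.Module):
    def __init__(self, base_loss, tied_weight, both_supervised=False):
        super().__init__()
        self.base_loss = base_loss            # e.g. cross entropy / label smoothing / soft-target CE
        self.tied_weight = tied_weight
        self.both_supervised = both_supervised

    def forward(self, output, target):
        feat1, feat2, logit1, logit2 = output
        # targets may differ per branch (e.g. mixup target vs. plain one-hot target)
        target1, target2 = target if isinstance(target, (tuple, list)) else (target, target)
        loss = self.base_loss(logit1, target1)
        if self.both_supervised:
            loss = loss + self.base_loss(logit2, target2)
        # tie the two branches together by penalizing feature disagreement
        return loss + self.tied_weight * F.mse_loss(feat1, feat2)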
@@ -884,15 +902,24 @@
         last_batch = batch_idx == last_batch_idx
         need_update = last_batch or (batch_idx + 1) % accum_steps == 0
         update_idx = batch_idx // accum_steps
+        size = input.size(0) if args.tied_weight is None else input[0].size(0)
         if batch_idx >= last_batch_idx_to_accum:
             accum_steps = last_accum_steps
 
         if not args.prefetcher:
-            input, target = input.to(device), target.to(device)
-            if mixup_fn is not None:
-                input, target = mixup_fn(input, target)
+            if args.tied_weight is None:
+                input, target = input.to(device), target.to(device)
+                if mixup_fn is not None:
+                    input, target = mixup_fn(input, target)
+            else:
+                input = (input[0].to(device), input[1].to(device))
+                target = target.to(device)
+                if mixup_fn is not None:
+                    (input1, input2), (mixup_target, normal_target, lam) = mixup_fn(input, target)
+
         if args.channels_last:
-            input = input.contiguous(memory_format=torch.channels_last)
+            input = input.contiguous(memory_format=torch.channels_last) if args.tied_weight is None else \
+                (input[0].contiguous(memory_format=torch.channels_last), input[1].contiguous(memory_format=torch.channels_last))
 
         # multiply by accum steps to get equivalent for full update
         data_time_m.update(accum_steps * (time.time() - data_start_time))
@@ -900,6 +927,11 @@
         def _forward():
             with amp_autocast():
                 output = model(input)
-                loss = loss_fn(output, target)
+                if mixup_fn is not None and args.tied_weight is not None:
+                    feat1, feat2, logit1, logit2 = output
+                    feat2 = lam * feat2 + (1 - lam) * feat2.flip(0)
+                    loss = loss_fn((feat1, feat2, logit1, logit2), (mixup_target, normal_target))
+                else:
+                    loss = loss_fn(output, target)
             if accum_steps > 1:
                 loss /= accum_steps
@@ -936,8 +968,8 @@ def _backward(_loss):
             _backward(loss)
 
         if not args.distributed:
-            losses_m.update(loss.item() * accum_steps, input.size(0))
-            update_sample_count += input.size(0)
+            losses_m.update(loss.item() * accum_steps, size)
+            update_sample_count += size
 
         if not need_update:
             data_start_time = time.time()
@@ -960,7 +992,7 @@ def _backward(_loss):
 
         if args.distributed:
             reduced_loss = utils.reduce_tensor(loss.data, args.world_size)
-            losses_m.update(reduced_loss.item() * accum_steps, input.size(0))
+            losses_m.update(reduced_loss.item() * accum_steps, size)
             update_sample_count *= args.world_size
 
         if utils.is_primary(args):
@@ -1012,13 +1044,14 @@ def validate(
     losses_m = utils.AverageMeter()
     top1_m = utils.AverageMeter()
     top5_m = utils.AverageMeter()
 
     model.eval()
 
     end = time.time()
     last_idx = len(loader) - 1
     with torch.no_grad():
         for batch_idx, (input, target) in enumerate(loader):
+            size = input.size(0) if args.tied_weight is None else input[0].size(0)
             last_batch = batch_idx == last_idx
             if not args.prefetcher:
                 input = input.to(device)
@@ -1050,7 +1083,7 @@
             if device.type == 'cuda':
                 torch.cuda.synchronize()
 
-            losses_m.update(reduced_loss.item(), input.size(0))
+            losses_m.update(reduced_loss.item(), size)
             top1_m.update(acc1.item(), output.size(0))
             top5_m.update(acc5.item(), output.size(0))
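Note (not part of the patch): once the two missing tied_* modules are in place, a Tied-Augment run would be launched with the new flags roughly as follows; the dataset path, model name, and weight value are placeholders:

    python train.py /path/to/imagenet --model resnet50 --mixup 0.2 --tied-weight 1.0 --tied-double-supervision

Because --tied-weight disables the prefetcher and FastCollateMixup, mixing is performed by the Mixup class on the first branch only, as implemented in timm/data/mixup.py above.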