tied augment implemented
ekurtulus committed Oct 1, 2023
1 parent 054c763 commit 0c2273e
Showing 6 changed files with 93 additions and 23 deletions.
36 changes: 30 additions & 6 deletions timm/data/loader.py
@@ -14,6 +14,7 @@

import torch
import torch.utils.data
from torch.utils.data import Dataset
import numpy as np

from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
@@ -87,7 +88,9 @@ def __init__(
re_prob=0.,
re_mode='const',
re_count=1,
re_num_splits=0):
re_num_splits=0,
tied_weight=None,
):

mean = adapt_to_chs(mean, channels)
std = adapt_to_chs(std, channels)
@@ -125,13 +128,16 @@ def __iter__(self):
stream_context = suppress

for next_input, next_target in self.loader:
def _preprocess(image):
image = image.to(device=self.device, non_blocking=True)
image = image.to(self.img_dtype).sub_(self.mean).div_(self.std)
if self.random_erasing is not None:
image = self.random_erasing(image)
return image

with stream_context():
next_input = next_input.to(device=self.device, non_blocking=True)
next_input = _preprocess(next_input)
next_target = next_target.to(device=self.device, non_blocking=True)
next_input = next_input.to(self.img_dtype).sub_(self.mean).div_(self.std)
if self.random_erasing is not None:
next_input = self.random_erasing(next_input)

if not first:
yield input, target
@@ -185,6 +191,19 @@ def _worker_init(worker_id, worker_seeding='all'):
if worker_seeding == 'all':
np.random.seed(worker_info.seed % (2 ** 32 - 1))

class TiedDatasetWrapper(Dataset):
"""Wrap a dataset so each item is returned as two independently augmented views of the same sample (its transform is applied twice)."""
def __init__(self, dataset):
self.dataset = dataset
# take over the transform so the wrapped dataset returns untransformed samples
self.transform = self.dataset.transform
self.dataset.transform = None

def __len__(self):
return len(self.dataset)

def __getitem__(self, idx):
input, target = self.dataset[idx]
input1, input2 = self.transform(input), self.transform(input)
return (input1, input2), target

def create_loader(
dataset,
@@ -205,6 +224,7 @@
auto_augment=None,
num_aug_repeats=0,
num_aug_splits=0,
tied_weight=None,
interpolation='bilinear',
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD,
@@ -255,6 +275,9 @@
# are correct before worker processes are launched
dataset.set_loader_cfg(num_workers=num_workers)

if tied_weight is not None:
dataset = TiedDatasetWrapper(dataset)

sampler = None
if distributed and not isinstance(dataset, torch.utils.data.IterableDataset):
if is_training:
@@ -305,7 +328,8 @@
re_prob=prefetch_re_prob,
re_mode=re_mode,
re_count=re_count,
re_num_splits=re_num_splits
re_num_splits=re_num_splits,
tied_weight=tied_weight,
)

return loader
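For context, a minimal usage sketch of the TiedDatasetWrapper added above. The torchvision FakeData dataset and the augmentation pipeline here are illustrative placeholders, not part of this commit; any dataset exposing a .transform attribute behaves the same way:

from torchvision import datasets, transforms

from timm.data.loader import TiedDatasetWrapper

aug = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])
base = datasets.FakeData(size=4, transform=aug)

tied = TiedDatasetWrapper(base)   # takes over `aug`; the wrapped dataset now returns raw PIL images
(view1, view2), target = tied[0]  # two independent augmentations of the same sample
assert view1.shape == view2.shape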
14 changes: 11 additions & 3 deletions timm/data/mixup.py
@@ -102,7 +102,7 @@ class Mixup:
num_classes (int): number of classes for target
"""
def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0, switch_prob=0.5,
mode='batch', correct_lam=True, label_smoothing=0.1, num_classes=1000):
mode='batch', correct_lam=True, label_smoothing=0.1, num_classes=1000, tied_augment=False):
self.mixup_alpha = mixup_alpha
self.cutmix_alpha = cutmix_alpha
self.cutmix_minmax = cutmix_minmax
@@ -117,6 +117,7 @@ def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0
self.mode = mode
self.correct_lam = correct_lam # correct lambda based on clipped area for cutmix
self.mixup_enabled = True # set to false to disable mixing (intended to be set by train loop)
self.tied_augment = tied_augment

def _params_per_elem(self, batch_size):
lam = np.ones(batch_size, dtype=np.float32)
@@ -206,7 +207,7 @@ def _mix_batch(self, x):
x.mul_(lam).add_(x_flipped)
return lam

def __call__(self, x, target):
def handle_mixup(self, x, target):
assert len(x) % 2 == 0, 'Batch size should be even when using this'
if self.mode == 'elem':
lam = self._mix_elem(x)
@@ -215,8 +216,15 @@ def __call__(self, x, target):
else:
lam = self._mix_batch(x)
target = mixup_target(target, self.num_classes, lam, self.label_smoothing)
return x, target
return x, target, lam

def __call__(self, x, target):
if self.tied_augment:
# mix only the first view; keep the second view un-mixed and pair it with a plain one-hot target
input1, input2 = x
input1, mixup_target, lam = self.handle_mixup(input1, target.clone())
return (input1, input2), (mixup_target, one_hot(target, self.num_classes), lam)

return self.handle_mixup(x, target)

class FastCollateMixup(Mixup):
""" Fast Collate w/ Mixup/Cutmix that applies different params to each element or whole batch
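To make the new two-branch contract concrete, a small sketch of calling the patched Mixup in tied mode (shapes and hyperparameters are illustrative; tensors are placed on CUDA, as the training loop does). Only the first view is mixed; the returned target tuple carries the soft mixup target, a plain one-hot target for the un-mixed view, and the mixing coefficient lam:

import torch
from timm.data.mixup import Mixup

mixup_fn = Mixup(mixup_alpha=0.8, num_classes=10, tied_augment=True)
view1 = torch.randn(8, 3, 224, 224, device='cuda')   # branch that gets mixed
view2 = torch.randn(8, 3, 224, 224, device='cuda')   # branch left un-mixed
labels = torch.randint(0, 10, (8,), device='cuda')

(view1, view2), (mix_target, plain_target, lam) = mixup_fn((view1, view2), labels)
# mix_target: soft targets for the mixed branch; plain_target: one-hot targets for
# the un-mixed branch; lam is reused later in train.py to mix branch-2 features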
1 change: 1 addition & 0 deletions timm/loss/__init__.py
@@ -2,3 +2,4 @@
from .binary_cross_entropy import BinaryCrossEntropy
from .cross_entropy import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from .jsd import JsdCrossEntropy
from .tied_loss import TiedLoss
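The new tied_loss.py module itself is not included in this diff. As a rough, hypothetical sketch only (the actual TiedLoss may differ): a tied-augment style loss combines the supervised loss on the mixed branch, optionally on both branches, with a feature-similarity penalty scaled by tied_weight, matching the (feat1, feat2, logit1, logit2) output and (mixup_target, normal_target) target tuple produced in train.py below:

import torch.nn as nn


class TiedLossSketch(nn.Module):
    """Hypothetical stand-in for TiedLoss; illustrates the expected interface only."""

    def __init__(self, base_loss, tied_weight, both_supervised=False):
        super().__init__()
        self.base_loss = base_loss          # e.g. SoftTargetCrossEntropy when mixup is active
        self.tied_weight = tied_weight
        self.both_supervised = both_supervised

    def forward(self, output, target):
        feat1, feat2, logit1, logit2 = output
        mixed_target, plain_target = target
        loss = self.base_loss(logit1, mixed_target)
        if self.both_supervised:
            loss = loss + self.base_loss(logit2, plain_target)
        # tie the branches together by penalizing feature distance
        return loss + self.tied_weight * (feat1 - feat2).pow(2).mean()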
1 change: 1 addition & 0 deletions timm/models/__init__.py
@@ -1,3 +1,4 @@
from .tied_model import TiedModelWrapper
from .beit import *
from .byoanet import *
from .byobnet import *
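tied_model.py is likewise not shown in this diff. A plausible sketch of the wrapper's forward contract, inferred from how train.py unpacks its output and from the single_forward flag; forward_features, forward_head, and get_classifier are standard timm model methods, but the real TiedModelWrapper may be implemented differently:

import torch
import torch.nn as nn


class TiedModelWrapperSketch(nn.Module):
    """Hypothetical stand-in: runs both views through one shared backbone and
    returns per-branch pre-logits features and logits."""

    def __init__(self, model, single_forward=False):
        super().__init__()
        self.model = model
        self.single_forward = single_forward

    def _features_and_logits(self, x):
        feats = self.model.forward_head(self.model.forward_features(x), pre_logits=True)
        return feats, self.model.get_classifier()(feats)

    def forward(self, inputs):
        x1, x2 = inputs
        if self.single_forward:
            # one concatenated pass, then split back into the two branches
            feats, logits = self._features_and_logits(torch.cat([x1, x2], dim=0))
            feat1, feat2 = feats.chunk(2, dim=0)
            logit1, logit2 = logits.chunk(2, dim=0)
        else:
            feat1, logit1 = self._features_and_logits(x1)
            feat2, logit2 = self._features_and_logits(x2)
        return feat1, feat2, logit1, logit2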
3 changes: 3 additions & 0 deletions timm/utils/model.py
@@ -10,12 +10,15 @@

from timm.layers import BatchNormAct2d, SyncBatchNormAct, FrozenBatchNormAct2d,\
freeze_batch_norm_2d, unfreeze_batch_norm_2d
from timm.models import TiedModelWrapper
from .model_ema import ModelEma


def unwrap_model(model):
if isinstance(model, ModelEma):
return unwrap_model(model.ema)
elif isinstance(model, TiedModelWrapper):
return model.model
else:
return model.module if hasattr(model, 'module') else model

61 changes: 47 additions & 14 deletions train.py
@@ -32,8 +32,8 @@
from timm import utils
from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
from timm.layers import convert_splitbn_model, convert_sync_batchnorm, set_fast_norm
from timm.loss import JsdCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntropy, LabelSmoothingCrossEntropy
from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint, model_parameters
from timm.loss import JsdCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntropy, LabelSmoothingCrossEntropy, TiedLoss
from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint, model_parameters, TiedModelWrapper
from timm.optim import create_optimizer_v2, optimizer_kwargs
from timm.scheduler import create_scheduler_v2, scheduler_kwargs
from timm.utils import ApexScaler, NativeScaler
@@ -288,6 +288,14 @@
help='Drop path rate (default: None)')
group.add_argument('--drop-block', type=float, default=None, metavar='PCT',
help='Drop block rate (default: None)')
group.add_argument('--tied-weight', type=float, default=None,
help='Weight of the tied loss term. If set, the script uses Tied-Augment.')
group.add_argument('--tied-double-supervision', action='store_true',
help='If true, both branches of Tied-Augment are fed to the cross-entropy loss. '
'Generally it is advised to set this to true.')
group.add_argument('--tied-single-pass', action='store_true',
help='If true, the two branches of Tied-Augment are processed in a single forward pass. '
'This can speed up training on high-end GPUs with ample memory.')

# Batch norm parameters (only works with gen_efficientnet based models currently)
group = parser.add_argument_group('Batch norm parameters', 'Only works with gen_efficientnet based models currently.')
@@ -381,7 +389,7 @@ def main():
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True

args.prefetcher = not args.no_prefetcher
args.prefetcher = not args.no_prefetcher and args.tied_weight is None
args.grad_accum_steps = max(1, args.grad_accum_steps)
device = utils.init_distributed_device(args)
if args.distributed:
@@ -486,6 +494,9 @@ def main():
'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')

if args.tied_weight is not None:
model = TiedModelWrapper(model, single_forward=args.tied_single_pass)

if args.torchscript:
assert not args.torchcompile
assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
@@ -613,9 +624,10 @@ def main():
switch_prob=args.mixup_switch_prob,
mode=args.mixup_mode,
label_smoothing=args.smoothing,
num_classes=args.num_classes
num_classes=args.num_classes,
tied_augment=(args.tied_weight is not None)
)
if args.prefetcher:
if args.prefetcher and args.tied_weight is None:
assert not num_aug_splits # collate conflict (need to support deinterleaving in collate mixup)
collate_fn = FastCollateMixup(**mixup_args)
else:
Expand Down Expand Up @@ -649,6 +661,7 @@ def main():
num_aug_repeats=args.aug_repeats,
num_aug_splits=num_aug_splits,
interpolation=train_interpolation,
tied_weight=args.tied_weight,
mean=data_config['mean'],
std=data_config['std'],
num_workers=args.workers,
@@ -699,6 +712,11 @@ def main():
train_loss_fn = nn.CrossEntropyLoss()
train_loss_fn = train_loss_fn.to(device=device)
validate_loss_fn = nn.CrossEntropyLoss().to(device=device)

if args.tied_weight is not None:
train_loss_fn = TiedLoss(train_loss_fn,
args.tied_weight,
both_supervised=args.tied_double_supervision)

# setup checkpoint saver and eval metric tracking
eval_metric = args.eval_metric
@@ -884,22 +902,36 @@ def train_one_epoch(
last_batch = batch_idx == last_batch_idx
need_update = last_batch or (batch_idx + 1) % accum_steps == 0
update_idx = batch_idx // accum_steps
size = input.size(0) if args.tied_weight is None else input[0].size(0)
if batch_idx >= last_batch_idx_to_accum:
accum_steps = last_accum_steps

if not args.prefetcher:
input, target = input.to(device), target.to(device)
if mixup_fn is not None:
input, target = mixup_fn(input, target)
if args.tied_weight is None:
input, target = input.to(device), target.to(device)
if mixup_fn is not None:
input, target = mixup_fn(input, target)
else:
input = (input[0].to(device), input[1].to(device))
target = target.to(device)
if mixup_fn is not None:
(input1, input2), (mixup_target, normal_target, lam) = mixup_fn(input, target)

if args.channels_last:
input = input.contiguous(memory_format=torch.channels_last)
input = input.contiguous(memory_format=torch.channels_last) if args.tied_weight is None else \
(input[0].contiguous(memory_format=torch.channels_last), input[1].contiguous(memory_format=torch.channels_last))

# multiply by accum steps to get equivalent for full update
data_time_m.update(accum_steps * (time.time() - data_start_time))

def _forward():
with amp_autocast():
output = model(input)
loss_target = target  # local alias; assigning to `target` here would make it local to _forward and break the non-tied path
if mixup_fn is not None and args.tied_weight is not None:
feat1, feat2, logit1, logit2 = output
feat2 = lam * feat2 + (1 - lam) * feat2.flip(0)  # re-mix the un-mixed branch's features with the same lam
output = (feat1, feat2, logit1, logit2)
loss_target = (mixup_target, normal_target)
loss = loss_fn(output, loss_target)
if accum_steps > 1:
loss /= accum_steps
@@ -936,8 +968,8 @@ def _backward(_loss):
_backward(loss)

if not args.distributed:
losses_m.update(loss.item() * accum_steps, input.size(0))
update_sample_count += input.size(0)
losses_m.update(loss.item() * accum_steps, size)
update_sample_count += size

if not need_update:
data_start_time = time.time()
@@ -960,7 +992,7 @@ def _backward(_loss):

if args.distributed:
reduced_loss = utils.reduce_tensor(loss.data, args.world_size)
losses_m.update(reduced_loss.item() * accum_steps, input.size(0))
losses_m.update(reduced_loss.item() * accum_steps, size)
update_sample_count *= args.world_size

if utils.is_primary(args):
@@ -1012,13 +1044,14 @@ def validate(
losses_m = utils.AverageMeter()
top1_m = utils.AverageMeter()
top5_m = utils.AverageMeter()

model.eval()

end = time.time()
last_idx = len(loader) - 1
with torch.no_grad():
for batch_idx, (input, target) in enumerate(loader):
size = input.size(0) if args.tied_weight is None else input[0].size(0)
last_batch = batch_idx == last_idx
if not args.prefetcher:
input = input.to(device)
@@ -1050,7 +1083,7 @@ def validate(
if device.type == 'cuda':
torch.cuda.synchronize()

losses_m.update(reduced_loss.item(), input.size(0))
losses_m.update(reduced_loss.item(), size)
top1_m.update(acc1.item(), output.size(0))
top5_m.update(acc5.item(), output.size(0))

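Two practical notes on the train.py changes. Tied-Augment is opt-in: it activates only when --tied-weight is passed (optionally with --tied-double-supervision and --tied-single-pass), and the prefetcher is bypassed in that case. The feature re-mixing inside _forward exists because branch one sees mixed inputs lam * x + (1 - lam) * x.flip(0) while branch two does not; applying the same convex combination to branch two's features keeps the tied comparison aligned, as in this small illustration (sizes are arbitrary):

import torch

lam = 0.7                                    # mixing coefficient returned by Mixup
x1 = torch.randn(4, 3, 8, 8)                 # first view, before mixing
x1_mixed = lam * x1 + (1 - lam) * x1.flip(0)            # what the model sees on branch 1
feat2 = torch.randn(4, 512)                  # features of the un-mixed second view
feat2_mixed = lam * feat2 + (1 - lam) * feat2.flip(0)   # aligned the same way before the tied term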
