tied augment implemented #1974

Open · wants to merge 1 commit into main
36 changes: 30 additions & 6 deletions timm/data/loader.py
@@ -14,6 +14,7 @@

import torch
import torch.utils.data
from torch.utils.data import Dataset
import numpy as np

from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
@@ -87,7 +88,9 @@ def __init__(
re_prob=0.,
re_mode='const',
re_count=1,
re_num_splits=0):
re_num_splits=0,
tied_weight=None,
):

mean = adapt_to_chs(mean, channels)
std = adapt_to_chs(std, channels)
@@ -125,13 +128,16 @@ def __iter__(self):
stream_context = suppress

for next_input, next_target in self.loader:
def _preprocess(image):
image = image.to(device=self.device, non_blocking=True)
image = image.to(self.img_dtype).sub_(self.mean).div_(self.std)
if self.random_erasing is not None:
image = self.random_erasing(image)
return image

with stream_context():
next_input = next_input.to(device=self.device, non_blocking=True)
next_input = _preprocess(next_input)
next_target = next_target.to(device=self.device, non_blocking=True)
next_input = next_input.to(self.img_dtype).sub_(self.mean).div_(self.std)
if self.random_erasing is not None:
next_input = self.random_erasing(next_input)

if not first:
yield input, target
@@ -185,6 +191,19 @@ def _worker_init(worker_id, worker_seeding='all'):
if worker_seeding == 'all':
np.random.seed(worker_info.seed % (2 ** 32 - 1))

class TiedDatasetWrapper(Dataset):
def __init__(self, dataset):
self.dataset = dataset
self.transform = self.dataset.transform
self.dataset.transform = None

def __len__(self):
return len(self.dataset)

def __getitem__(self, idx):
input, target = self.dataset[idx]
input1, input2 = self.transform(input), self.transform(input)
return (input1, input2), target

def create_loader(
dataset,
@@ -205,6 +224,7 @@ def create_loader(
auto_augment=None,
num_aug_repeats=0,
num_aug_splits=0,
tied_weight=None,
interpolation='bilinear',
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD,
@@ -255,6 +275,9 @@
# are correct before worker processes are launched
dataset.set_loader_cfg(num_workers=num_workers)

if tied_weight is not None:
dataset = TiedDatasetWrapper(dataset)

sampler = None
if distributed and not isinstance(dataset, torch.utils.data.IterableDataset):
if is_training:
@@ -305,7 +328,8 @@
re_prob=prefetch_re_prob,
re_mode=re_mode,
re_count=re_count,
re_num_splits=re_num_splits
re_num_splits=re_num_splits,
tied_weight=tied_weight,
)

return loader
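The TiedDatasetWrapper added above takes over the dataset's transform and applies it twice per sample, so each item becomes ((view1, view2), target) and the default PyTorch collate stacks the two views into separate batch tensors; this is also why train.py only enables the prefetcher when --tied-weight is unset. A minimal, illustrative sketch of the resulting batch layout (the fake dataset below is a stand-in, not part of the PR):

import torch
from torch.utils.data import DataLoader, Dataset

class _FakeImageDataset(Dataset):
    # Stand-in for a timm ImageDataset: raw samples plus a .transform attribute.
    def __init__(self):
        self.transform = lambda img: img + 0.01 * torch.randn_like(img)  # toy "augmentation"

    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return torch.zeros(3, 224, 224), idx % 2

wrapped = TiedDatasetWrapper(_FakeImageDataset())  # wrapper defined in timm/data/loader.py in the diff above
loader = DataLoader(wrapped, batch_size=4)
(view1, view2), target = next(iter(loader))
# view1.shape == view2.shape == (4, 3, 224, 224); target.shape == (4,)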
14 changes: 11 additions & 3 deletions timm/data/mixup.py
@@ -102,7 +102,7 @@ class Mixup:
num_classes (int): number of classes for target
"""
def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0, switch_prob=0.5,
mode='batch', correct_lam=True, label_smoothing=0.1, num_classes=1000):
mode='batch', correct_lam=True, label_smoothing=0.1, num_classes=1000, tied_augment=False):
self.mixup_alpha = mixup_alpha
self.cutmix_alpha = cutmix_alpha
self.cutmix_minmax = cutmix_minmax
@@ -117,6 +117,7 @@ def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0
self.mode = mode
self.correct_lam = correct_lam # correct lambda based on clipped area for cutmix
self.mixup_enabled = True # set to false to disable mixing (intended tp be set by train loop)
self.tied_augment = tied_augment

def _params_per_elem(self, batch_size):
lam = np.ones(batch_size, dtype=np.float32)
@@ -206,7 +207,7 @@ def _mix_batch(self, x):
x.mul_(lam).add_(x_flipped)
return lam

def __call__(self, x, target):
def handle_mixup(self, x, target):
assert len(x) % 2 == 0, 'Batch size should be even when using this'
if self.mode == 'elem':
lam = self._mix_elem(x)
@@ -215,8 +216,15 @@ def __call__(self, x, target):
else:
lam = self._mix_batch(x)
target = mixup_target(target, self.num_classes, lam, self.label_smoothing)
return x, target
return x, target, lam

def __call__(self, x, target):
if self.tied_augment:
input1, input2 = x
input1, mixup_target, lam = self.handle_mixup(input1, target.clone())
return (input1, input2), (mixup_target, one_hot(target, self.num_classes), lam)

return self.handle_mixup(x, target)

class FastCollateMixup(Mixup):
""" Fast Collate w/ Mixup/Cutmix that applies different params to each element or whole batch
Expand Down
1 change: 1 addition & 0 deletions timm/loss/__init__.py
@@ -2,3 +2,4 @@
from .binary_cross_entropy import BinaryCrossEntropy
from .cross_entropy import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from .jsd import JsdCrossEntropy
from .tied_loss import TiedLoss
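timm/loss/tied_loss.py itself is not included in the diff shown here. Based only on how TiedLoss is constructed and called in train.py below (TiedLoss(base_loss, tied_weight, both_supervised=...), with output = (feat1, feat2, logit1, logit2) and either a single target or a (mixup_target, normal_target) pair), a minimal hypothetical sketch could look like:

import torch
import torch.nn as nn

class TiedLoss(nn.Module):
    # Hypothetical sketch only; the PR's actual tied_loss.py may differ.
    def __init__(self, base_loss, tied_weight, both_supervised=False):
        super().__init__()
        self.base_loss = base_loss              # e.g. SoftTargetCrossEntropy or CrossEntropyLoss
        self.tied_weight = tied_weight          # weight of the feature-similarity term
        self.both_supervised = both_supervised

    def forward(self, output, target):
        feat1, feat2, logit1, logit2 = output
        # train.py passes (mixup_target, normal_target) when mixup is active, else a single tensor
        target1, target2 = target if isinstance(target, (tuple, list)) else (target, target)
        loss = self.base_loss(logit1, target1)
        if self.both_supervised:
            loss = loss + self.base_loss(logit2, target2)
        # tied term: pull the features of the two augmented views together
        loss = loss + self.tied_weight * (feat1 - feat2).pow(2).mean()
        return loss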
1 change: 1 addition & 0 deletions timm/models/__init__.py
@@ -1,3 +1,4 @@
from .tied_model import TiedModelWrapper
from .beit import *
from .byoanet import *
from .byobnet import *
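timm/models/tied_model.py is likewise not part of the diff. From its usage (TiedModelWrapper(model, single_forward=args.tied_single_pass), the .model attribute relied on by unwrap_model below, and the (feat1, feat2, logit1, logit2) tuple consumed in train.py), one plausible sketch, assuming the wrapped model exposes timm's forward_features/forward_head split, is:

import torch
import torch.nn as nn

class TiedModelWrapper(nn.Module):
    # Hypothetical sketch only; the PR's actual tied_model.py may differ.
    def __init__(self, model, single_forward=False):
        super().__init__()
        self.model = model                      # unwrap_model() in timm/utils/model.py uses this name
        self.single_forward = single_forward

    def _features_and_logits(self, x):
        feat = self.model.forward_features(x)
        return feat, self.model.forward_head(feat)

    def forward(self, x):
        if not isinstance(x, (tuple, list)):
            # validation batches are plain tensors; run the wrapped model unchanged
            return self.model(x)
        x1, x2 = x
        if self.single_forward:
            # single pass over the concatenated views, then split the results
            feat, logits = self._features_and_logits(torch.cat([x1, x2], dim=0))
            feat1, feat2 = feat.chunk(2, dim=0)
            logit1, logit2 = logits.chunk(2, dim=0)
        else:
            feat1, logit1 = self._features_and_logits(x1)
            feat2, logit2 = self._features_and_logits(x2)
        return feat1, feat2, logit1, logit2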
3 changes: 3 additions & 0 deletions timm/utils/model.py
@@ -10,12 +10,15 @@

from timm.layers import BatchNormAct2d, SyncBatchNormAct, FrozenBatchNormAct2d,\
freeze_batch_norm_2d, unfreeze_batch_norm_2d
from timm.models import TiedModelWrapper
from .model_ema import ModelEma


def unwrap_model(model):
if isinstance(model, ModelEma):
return unwrap_model(model.ema)
elif isinstance(model, TiedModelWrapper):
return model.model
else:
return model.module if hasattr(model, 'module') else model

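With this change, checkpoint saving and EMA code that goes through unwrap_model() sees the underlying timm model rather than the Tied-Augment wrapper. An illustrative check, assuming the PR branch is installed:

import timm
from timm.models import TiedModelWrapper          # added by this PR
from timm.utils.model import unwrap_model

base = timm.create_model('resnet18')
wrapped = TiedModelWrapper(base, single_forward=False)
assert unwrap_model(wrapped) is base               # the wrapper is transparent to checkpoint/EMA code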
61 changes: 47 additions & 14 deletions train.py
@@ -32,8 +32,8 @@
from timm import utils
from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
from timm.layers import convert_splitbn_model, convert_sync_batchnorm, set_fast_norm
from timm.loss import JsdCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntropy, LabelSmoothingCrossEntropy
from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint, model_parameters
from timm.loss import JsdCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntropy, LabelSmoothingCrossEntropy, TiedLoss
from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint, model_parameters, TiedModelWrapper
from timm.optim import create_optimizer_v2, optimizer_kwargs
from timm.scheduler import create_scheduler_v2, scheduler_kwargs
from timm.utils import ApexScaler, NativeScaler
@@ -288,6 +288,14 @@
help='Drop path rate (default: None)')
group.add_argument('--drop-block', type=float, default=None, metavar='PCT',
help='Drop block rate (default: None)')
group.add_argument('--tied-weight', type=float, default=None,
help='Weight of the Tied-Augment feature loss term. Tied-Augment is enabled only when this is set (> 0).')
group.add_argument('--tied-double-supervision', action='store_true',
help='If set, both branches of Tied-Augment are fed to the cross-entropy loss. '
'Enabling this is generally advised.')
group.add_argument('--tied-single-pass', action='store_true',
help='If set, the two branches of Tied-Augment are processed in a single forward pass. '
'This can speed up training on high-end GPUs with sufficient memory.')

# Batch norm parameters (only works with gen_efficientnet based models currently)
group = parser.add_argument_group('Batch norm parameters', 'Only works with gen_efficientnet based models currently.')
@@ -381,7 +389,7 @@ def main():
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True

args.prefetcher = not args.no_prefetcher
args.prefetcher = not args.no_prefetcher and args.tied_weight is None
args.grad_accum_steps = max(1, args.grad_accum_steps)
device = utils.init_distributed_device(args)
if args.distributed:
@@ -486,6 +494,9 @@ def main():
'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')

if args.tied_weight is not None:
model = TiedModelWrapper(model, single_forward=args.tied_single_pass)

if args.torchscript:
assert not args.torchcompile
assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
@@ -613,9 +624,10 @@ def main():
switch_prob=args.mixup_switch_prob,
mode=args.mixup_mode,
label_smoothing=args.smoothing,
num_classes=args.num_classes
num_classes=args.num_classes,
tied_augment=(args.tied_weight is not None)
)
if args.prefetcher:
if args.prefetcher and args.tied_weight is None:
assert not num_aug_splits # collate conflict (need to support deinterleaving in collate mixup)
collate_fn = FastCollateMixup(**mixup_args)
else:
@@ -649,6 +661,7 @@ def main():
num_aug_repeats=args.aug_repeats,
num_aug_splits=num_aug_splits,
interpolation=train_interpolation,
tied_weight=args.tied_weight,
mean=data_config['mean'],
std=data_config['std'],
num_workers=args.workers,
@@ -699,6 +712,11 @@ def main():
train_loss_fn = nn.CrossEntropyLoss()
train_loss_fn = train_loss_fn.to(device=device)
validate_loss_fn = nn.CrossEntropyLoss().to(device=device)

if args.tied_weight is not None:
train_loss_fn = TiedLoss(train_loss_fn,
args.tied_weight,
both_supervised=args.tied_double_supervision)

# setup checkpoint saver and eval metric tracking
eval_metric = args.eval_metric
@@ -884,22 +902,36 @@ def train_one_epoch(
last_batch = batch_idx == last_batch_idx
need_update = last_batch or (batch_idx + 1) % accum_steps == 0
update_idx = batch_idx // accum_steps
size = input.size(0) if args.tied_weight is None else input[0].size(0)
if batch_idx >= last_batch_idx_to_accum:
accum_steps = last_accum_steps

if not args.prefetcher:
input, target = input.to(device), target.to(device)
if mixup_fn is not None:
input, target = mixup_fn(input, target)
if args.tied_weight is None:
input, target = input.to(device), target.to(device)
if mixup_fn is not None:
input, target = mixup_fn(input, target)
else:
input = (input[0].to(device), input[1].to(device))
target = target.to(device)
if mixup_fn is not None:
(input1, input2), (mixup_target, normal_target, lam) = mixup_fn(input, target)

if args.channels_last:
input = input.contiguous(memory_format=torch.channels_last)
input = input.contiguous(memory_format=torch.channels_last) if args.tied_weight is None else \
(input[0].contiguous(memory_format=torch.channels_last), input[1].contiguous(memory_format=torch.channels_last))

# multiply by accum steps to get equivalent for full update
data_time_m.update(accum_steps * (time.time() - data_start_time))

def _forward():
with amp_autocast():
output = model(input)
if mixup_fn is not None and args.tied_weight is not None:
feat1, feat2, logit1, logit2 = output
feat2 = lam * feat2 + (1 - lam) * feat2.flip(0)
output = (feat1, feat2, logit1, logit2)
target = (mixup_target, normal_target)
loss = loss_fn(output, target)
if accum_steps > 1:
loss /= accum_steps
@@ -936,8 +968,8 @@ def _backward(_loss):
_backward(loss)

if not args.distributed:
losses_m.update(loss.item() * accum_steps, input.size(0))
update_sample_count += input.size(0)
losses_m.update(loss.item() * accum_steps, size)
update_sample_count += size

if not need_update:
data_start_time = time.time()
@@ -960,7 +992,7 @@ def _backward(_loss):

if args.distributed:
reduced_loss = utils.reduce_tensor(loss.data, args.world_size)
losses_m.update(reduced_loss.item() * accum_steps, input.size(0))
losses_m.update(reduced_loss.item() * accum_steps, size)
update_sample_count *= args.world_size

if utils.is_primary(args):
Expand Down Expand Up @@ -1012,13 +1044,14 @@ def validate(
losses_m = utils.AverageMeter()
top1_m = utils.AverageMeter()
top5_m = utils.AverageMeter()

model.eval()

end = time.time()
last_idx = len(loader) - 1
with torch.no_grad():
for batch_idx, (input, target) in enumerate(loader):
size = input.size(0) if args.tied_weight is None else input[0].size(0)
last_batch = batch_idx == last_idx
if not args.prefetcher:
input = input.to(device)
@@ -1050,7 +1083,7 @@ def validate(
if device.type == 'cuda':
torch.cuda.synchronize()

losses_m.update(reduced_loss.item(), input.size(0))
losses_m.update(reduced_loss.item(), size)
top1_m.update(acc1.item(), output.size(0))
top5_m.update(acc5.item(), output.size(0))

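Putting the pieces together: when --tied-weight is set, the loader yields two augmented views per sample, TiedModelWrapper returns (feat1, feat2, logit1, logit2), and TiedLoss adds a weighted feature-similarity term to the supervised loss. In _forward above, the second branch's features are mixed with the same lam (and batch flip) that mixup applied to the first branch, mirroring _mix_batch, so the tied term compares matching mixtures. A hypothetical invocation (dataset path and values are illustrative, not taken from the PR):

python train.py /path/to/imagenet --model resnet50 --tied-weight 1.0 --tied-double-supervision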