diff --git a/patchgan/disc.py b/patchgan/disc.py index b6f1ea6..83f48ae 100644 --- a/patchgan/disc.py +++ b/patchgan/disc.py @@ -5,7 +5,7 @@ class Discriminator(nn.Module, Transferable): """Defines a PatchGAN discriminator""" - def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.InstanceNorm2d): + def __init__(self, input_nc, ndf=64, n_layers=3, norm=False, norm_layer=nn.InstanceNorm2d): """Construct a PatchGAN discriminator Parameters: input_nc (int) -- the number of channels in input images @@ -27,8 +27,9 @@ def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.InstanceNorm2d): nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=False), nn.Tanh(), - norm_layer(ndf * nf_mult) ] + if norm: + sequence += [norm_layer(ndf * nf_mult)] nf_mult_prev = nf_mult nf_mult = min(2 ** n_layers, 8) @@ -36,8 +37,9 @@ def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.InstanceNorm2d): nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=False), nn.Tanh(), - norm_layer(ndf * nf_mult) ] + if norm: + sequence += [norm_layer(ndf * nf_mult)] # output 1 channel prediction map sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, diff --git a/patchgan/losses.py b/patchgan/losses.py index d120ce6..31ce19b 100644 --- a/patchgan/losses.py +++ b/patchgan/losses.py @@ -31,5 +31,9 @@ def fc_tversky(y_true, y_pred, beta, gamma=0.75, batch_mean=True): return torch.pow(focal_tversky_loss, gamma) +def MAE_loss(y_true, y_pred): + return torch.mean(torch.abs(y_true - y_pred)) + + # alias bce_loss = nn.BCELoss() diff --git a/patchgan/train.py b/patchgan/train.py index 5cd3ff2..7b846a5 100644 --- a/patchgan/train.py +++ b/patchgan/train.py @@ -74,14 +74,6 @@ def patchgan_train(): datagen = Dataset(data_paths['images'], data_paths['masks'], size=size, augmentation=augmentation, **dataset_kwargs) train_datagen, val_datagen = random_split(datagen, train_val_split) - model_params = config['model_params'] - gen_filts = model_params['gen_filts'] - disc_filts = model_params['disc_filts'] - n_disc_layers = model_params['n_disc_layers'] - activation = model_params['activation'] - use_dropout = model_params.get('use_dropout', True) - final_activation = model_params.get('final_activation', 'sigmoid') - dloader_kwargs = {} if args.dataloader_workers > 0: dloader_kwargs['num_workers'] = args.dataloader_workers @@ -90,14 +82,25 @@ def patchgan_train(): train_data = DataLoader(train_datagen, batch_size=args.batch_size, shuffle=True, pin_memory=True, **dloader_kwargs) val_data = DataLoader(val_datagen, batch_size=args.batch_size, shuffle=True, pin_memory=True, **dloader_kwargs) + model_params = config['model_params'] + generator_config = model_params['generator'] + discriminator_config = model_params['discriminator'] + # create the generator + gen_filts = generator_config['filters'] + activation = generator_config['activation'] + use_dropout = generator_config.get('use_dropout', True) + final_activation = generator_config.get('final_activation', 'sigmoid') generator = UNet(in_channels, out_channels, gen_filts, use_dropout=use_dropout, activation=activation, final_act=final_activation).to(device) # create the discriminator - discriminator = Discriminator(in_channels + out_channels, disc_filts, n_layers=n_disc_layers).to(device) + disc_filts = discriminator_config['filters'] + disc_norm = discriminator_config.get('norm', False) + n_disc_layers = discriminator_config['n_layers'] + discriminator = Discriminator(in_channels + out_channels, disc_filts, norm=disc_norm, n_layers=n_disc_layers).to(device) if args.summary: - summary(generator, [1, in_channels, size, size]) + summary(generator, [1, in_channels, size, size], depth=4) summary(discriminator, [1, in_channels + out_channels, size, size]) checkpoint_path = config.get('checkpoint_path', './checkpoints/') diff --git a/patchgan/trainer.py b/patchgan/trainer.py index e6b05ce..394d44b 100644 --- a/patchgan/trainer.py +++ b/patchgan/trainer.py @@ -5,7 +5,7 @@ import numpy as np from torch import optim from torch.optim.lr_scheduler import ExponentialLR, ReduceLROnPlateau -from .losses import fc_tversky, bce_loss +from .losses import fc_tversky, bce_loss, MAE_loss from torch.nn.functional import binary_cross_entropy from collections import defaultdict @@ -78,6 +78,8 @@ def batch(self, x, y, train=False): else: weight = torch.ones_like(target_tensor) gen_loss = binary_cross_entropy(gen_img, target_tensor, weight=weight) * self.seg_alpha + elif self.loss_type == 'MAE': + gen_loss = MAE_loss(gen_img, target_tensor) * self.seg_alpha gen_loss_disc = bce_loss(disc_fake, labels_real) gen_loss = gen_loss + gen_loss_disc @@ -164,9 +166,9 @@ def train(self, train_data, val_data, epochs, dsc_learning_rate=1.e-3, self.neptune_config['model/parameters/n_epochs'] = epochs # create the Adam optimzers - self.gen_optimizer = optim.NAdam( + self.gen_optimizer = optim.Adam( self.generator.parameters(), lr=gen_lr, betas=(0.9, 0.999)) - self.disc_optimizer = optim.NAdam( + self.disc_optimizer = optim.Adam( self.discriminator.parameters(), lr=dsc_lr, betas=(0.9, 0.999)) # set up the learning rate scheduler with exponential lr decay diff --git a/patchgan/unet.py b/patchgan/unet.py index dd55fd4..4805ce1 100755 --- a/patchgan/unet.py +++ b/patchgan/unet.py @@ -20,8 +20,8 @@ def __init__(self, input_filt, output_filt, activation, norm_layer, layer, use_d downnorm = norm_layer(output_filt) enc_sub = OrderedDict([(f'DownConv{layer}', downconv), - (f'DownAct{layer}', activation), (f'DownNorm{layer}', downnorm), + (f'DownAct{layer}', activation), ]) if use_dropout: enc_sub = OrderedDict(chain(enc_sub.items(), @@ -54,8 +54,9 @@ def __init__(self, input_filt, output_filt, activation, norm_layer, layer, batch if batch_norm: upnorm = norm_layer(output_filt) dec_sub = OrderedDict([(f'UpConv{layer}', upconv), + (f'UpNorm{layer}', upnorm), (f'UpAct{layer}', activation), - (f'UpNorm{layer}', upnorm)]) + ]) else: dec_sub = OrderedDict([(f'UpConv{layer}', upconv), (f'UpAct{layer}', activation)])