Merge pull request #8 from ramanakumars/mae_loss
Adding MAE loss and more customizability to discriminator
AgentM-GEG authored Oct 13, 2023
2 parents 24befbf + bedb441 commit 5d7866b
Showing 5 changed files with 30 additions and 18 deletions.
patchgan/disc.py (8 changes: 5 additions & 3 deletions)
@@ -5,7 +5,7 @@
 class Discriminator(nn.Module, Transferable):
     """Defines a PatchGAN discriminator"""
 
-    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.InstanceNorm2d):
+    def __init__(self, input_nc, ndf=64, n_layers=3, norm=False, norm_layer=nn.InstanceNorm2d):
         """Construct a PatchGAN discriminator
         Parameters:
             input_nc (int) -- the number of channels in input images
@@ -27,17 +27,19 @@ def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.InstanceNorm2d):
                 nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                           kernel_size=kw, stride=2, padding=padw, bias=False),
                 nn.Tanh(),
-                norm_layer(ndf * nf_mult)
             ]
+            if norm:
+                sequence += [norm_layer(ndf * nf_mult)]
 
         nf_mult_prev = nf_mult
         nf_mult = min(2 ** n_layers, 8)
         sequence += [
             nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                       kernel_size=kw, stride=1, padding=padw, bias=False),
             nn.Tanh(),
-            norm_layer(ndf * nf_mult)
         ]
+        if norm:
+            sequence += [norm_layer(ndf * nf_mult)]
 
         # output 1 channel prediction map
         sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw,
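With this change, normalization in the discriminator is opt-in: norm=False (the default) drops the norm_layer calls entirely, while norm=True keeps them in the same position as before, right after each nn.Tanh(). A minimal usage sketch, assuming the standard forward pass and placeholder channel counts (in train.py, input_nc is in_channels + out_channels):

```python
import torch
from patchgan.disc import Discriminator

# norm=True re-enables the InstanceNorm2d layers that this commit makes optional.
disc = Discriminator(input_nc=2, ndf=64, n_layers=3, norm=True)

x = torch.randn(1, 2, 256, 256)  # (batch, input_nc, H, W); sizes are placeholders
patch_scores = disc(x)           # per-patch, single-channel prediction map
```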
patchgan/losses.py (4 changes: 4 additions & 0 deletions)
@@ -31,5 +31,9 @@ def fc_tversky(y_true, y_pred, beta, gamma=0.75, batch_mean=True):
     return torch.pow(focal_tversky_loss, gamma)
 
 
+def MAE_loss(y_true, y_pred):
+    return torch.mean(torch.abs(y_true - y_pred))
+
+
 # alias
 bce_loss = nn.BCELoss()
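MAE_loss is the mean absolute error averaged over every element, so it is symmetric in its two arguments and matches PyTorch's built-in L1 loss with mean reduction. A quick sanity-check sketch:

```python
import torch
import torch.nn.functional as F
from patchgan.losses import MAE_loss

y_true = torch.rand(2, 1, 64, 64)
y_pred = torch.rand(2, 1, 64, 64)

# F.l1_loss defaults to reduction='mean', i.e. mean(|input - target|).
assert torch.allclose(MAE_loss(y_true, y_pred), F.l1_loss(y_pred, y_true))
```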
patchgan/train.py (23 changes: 13 additions & 10 deletions)
@@ -74,14 +74,6 @@ def patchgan_train():
     datagen = Dataset(data_paths['images'], data_paths['masks'], size=size, augmentation=augmentation, **dataset_kwargs)
     train_datagen, val_datagen = random_split(datagen, train_val_split)
 
-    model_params = config['model_params']
-    gen_filts = model_params['gen_filts']
-    disc_filts = model_params['disc_filts']
-    n_disc_layers = model_params['n_disc_layers']
-    activation = model_params['activation']
-    use_dropout = model_params.get('use_dropout', True)
-    final_activation = model_params.get('final_activation', 'sigmoid')
-
     dloader_kwargs = {}
     if args.dataloader_workers > 0:
         dloader_kwargs['num_workers'] = args.dataloader_workers
@@ -90,14 +82,25 @@
     train_data = DataLoader(train_datagen, batch_size=args.batch_size, shuffle=True, pin_memory=True, **dloader_kwargs)
     val_data = DataLoader(val_datagen, batch_size=args.batch_size, shuffle=True, pin_memory=True, **dloader_kwargs)
 
+    model_params = config['model_params']
+    generator_config = model_params['generator']
+    discriminator_config = model_params['discriminator']
+
     # create the generator
+    gen_filts = generator_config['filters']
+    activation = generator_config['activation']
+    use_dropout = generator_config.get('use_dropout', True)
+    final_activation = generator_config.get('final_activation', 'sigmoid')
     generator = UNet(in_channels, out_channels, gen_filts, use_dropout=use_dropout, activation=activation, final_act=final_activation).to(device)
 
     # create the discriminator
-    discriminator = Discriminator(in_channels + out_channels, disc_filts, n_layers=n_disc_layers).to(device)
+    disc_filts = discriminator_config['filters']
+    disc_norm = discriminator_config.get('norm', False)
+    n_disc_layers = discriminator_config['n_layers']
+    discriminator = Discriminator(in_channels + out_channels, disc_filts, norm=disc_norm, n_layers=n_disc_layers).to(device)
 
     if args.summary:
-        summary(generator, [1, in_channels, size, size])
+        summary(generator, [1, in_channels, size, size], depth=4)
         summary(discriminator, [1, in_channels + out_channels, size, size])
 
     checkpoint_path = config.get('checkpoint_path', './checkpoints/')
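The training script now pulls generator and discriminator settings from separate sub-sections of model_params instead of one flat block. The commit does not include a sample config, so the layout below is a hypothetical illustration built only from the keys the script reads; every value is a placeholder:

```python
# Hypothetical nested model_params layout after this change.
# Key names come from the diff above; the values are placeholders.
config = {
    'model_params': {
        'generator': {
            'filters': 32,                   # previously gen_filts
            'activation': 'relu',
            'use_dropout': True,             # optional, defaults to True
            'final_activation': 'sigmoid',   # optional, defaults to 'sigmoid'
        },
        'discriminator': {
            'filters': 32,                   # previously disc_filts
            'n_layers': 3,                   # previously n_disc_layers
            'norm': False,                   # optional, defaults to False
        },
    },
}
```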
patchgan/trainer.py (8 changes: 5 additions & 3 deletions)
@@ -5,7 +5,7 @@
 import numpy as np
 from torch import optim
 from torch.optim.lr_scheduler import ExponentialLR, ReduceLROnPlateau
-from .losses import fc_tversky, bce_loss
+from .losses import fc_tversky, bce_loss, MAE_loss
 from torch.nn.functional import binary_cross_entropy
 from collections import defaultdict
 
@@ -78,6 +78,8 @@ def batch(self, x, y, train=False):
             else:
                 weight = torch.ones_like(target_tensor)
             gen_loss = binary_cross_entropy(gen_img, target_tensor, weight=weight) * self.seg_alpha
+        elif self.loss_type == 'MAE':
+            gen_loss = MAE_loss(gen_img, target_tensor) * self.seg_alpha
 
         gen_loss_disc = bce_loss(disc_fake, labels_real)
         gen_loss = gen_loss + gen_loss_disc
@@ -164,9 +166,9 @@ def train(self, train_data, val_data, epochs, dsc_learning_rate=1.e-3,
             self.neptune_config['model/parameters/n_epochs'] = epochs
 
         # create the Adam optimzers
-        self.gen_optimizer = optim.NAdam(
+        self.gen_optimizer = optim.Adam(
             self.generator.parameters(), lr=gen_lr, betas=(0.9, 0.999))
-        self.disc_optimizer = optim.NAdam(
+        self.disc_optimizer = optim.Adam(
             self.discriminator.parameters(), lr=dsc_lr, betas=(0.9, 0.999))
 
         # set up the learning rate scheduler with exponential lr decay
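With the new loss type selected, the generator objective in Trainer.batch() becomes the MAE reconstruction term scaled by seg_alpha plus the usual adversarial BCE term against the discriminator. A condensed, hypothetical helper that mirrors those lines (this function itself is not part of the codebase):

```python
from patchgan.losses import MAE_loss, bce_loss

def generator_objective(gen_img, target_tensor, disc_fake, labels_real, seg_alpha):
    # Reconstruction term: pixel-wise mean absolute error, weighted by seg_alpha.
    reconstruction = MAE_loss(gen_img, target_tensor) * seg_alpha
    # Adversarial term: push the discriminator's output on fakes toward the "real" labels.
    adversarial = bce_loss(disc_fake, labels_real)
    return reconstruction + adversarial
```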
patchgan/unet.py (5 changes: 3 additions & 2 deletions)
@@ -20,8 +20,8 @@ def __init__(self, input_filt, output_filt, activation, norm_layer, layer, use_d
         downnorm = norm_layer(output_filt)
 
         enc_sub = OrderedDict([(f'DownConv{layer}', downconv),
-                               (f'DownAct{layer}', activation),
                                (f'DownNorm{layer}', downnorm),
+                               (f'DownAct{layer}', activation),
                                ])
         if use_dropout:
             enc_sub = OrderedDict(chain(enc_sub.items(),
@@ -54,8 +54,9 @@ def __init__(self, input_filt, output_filt, activation, norm_layer, layer, batch
         if batch_norm:
             upnorm = norm_layer(output_filt)
             dec_sub = OrderedDict([(f'UpConv{layer}', upconv),
+                                   (f'UpNorm{layer}', upnorm),
                                    (f'UpAct{layer}', activation),
-                                   (f'UpNorm{layer}', upnorm)])
+                                   ])
         else:
             dec_sub = OrderedDict([(f'UpConv{layer}', upconv),
                                    (f'UpAct{layer}', activation)])
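Both hunks move normalization ahead of the activation, so each encoder and decoder sub-block now runs Conv, then Norm, then Act. A tiny standalone sketch of the resulting ordering; the convolution and activation modules here are placeholders rather than the ones UNet actually constructs:

```python
from collections import OrderedDict
import torch
import torch.nn as nn

layer = 1
enc_sub = OrderedDict([
    (f'DownConv{layer}', nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)),
    (f'DownNorm{layer}', nn.InstanceNorm2d(128)),  # norm now precedes the activation
    (f'DownAct{layer}', nn.LeakyReLU(0.2)),
])
block = nn.Sequential(enc_sub)
out = block(torch.randn(1, 64, 32, 32))  # -> shape (1, 128, 16, 16)
```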
