Adding MAE loss and more customizability to discriminator #8

Merged · 4 commits · Oct 13, 2023
patchgan/disc.py (8 changes: 5 additions & 3 deletions)

@@ -5,7 +5,7 @@
 class Discriminator(nn.Module, Transferable):
     """Defines a PatchGAN discriminator"""

-    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.InstanceNorm2d):
+    def __init__(self, input_nc, ndf=64, n_layers=3, norm=False, norm_layer=nn.InstanceNorm2d):
         """Construct a PatchGAN discriminator
         Parameters:
             input_nc (int) -- the number of channels in input images
@@ -27,17 +27,19 @@ def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.InstanceNorm2d):
                 nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                           kernel_size=kw, stride=2, padding=padw, bias=False),
                 nn.Tanh(),
-                norm_layer(ndf * nf_mult)
             ]
+            if norm:
+                sequence += [norm_layer(ndf * nf_mult)]

         nf_mult_prev = nf_mult
         nf_mult = min(2 ** n_layers, 8)
         sequence += [
             nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                       kernel_size=kw, stride=1, padding=padw, bias=False),
             nn.Tanh(),
-            norm_layer(ndf * nf_mult)
         ]
+        if norm:
+            sequence += [norm_layer(ndf * nf_mult)]

         # output 1 channel prediction map
         sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw,
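
With this change, normalization layers are only inserted when norm=True, so existing callers keep the un-normalized behavior by default. For context, a minimal usage sketch (not part of the diff; the channel count and input size are placeholder values):

    # Illustrative only: building the discriminator with normalization enabled.
    # input_nc=4 and the 256x256 input below are assumptions, not PR values.
    import torch
    from torch import nn
    from patchgan.disc import Discriminator

    disc = Discriminator(input_nc=4, ndf=64, n_layers=3, norm=True,
                         norm_layer=nn.InstanceNorm2d)
    pred_map = disc(torch.randn(1, 4, 256, 256))  # 1-channel patch prediction map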
patchgan/losses.py (4 changes: 4 additions & 0 deletions)

@@ -31,5 +31,9 @@ def fc_tversky(y_true, y_pred, beta, gamma=0.75, batch_mean=True):
     return torch.pow(focal_tversky_loss, gamma)


+def MAE_loss(y_true, y_pred):
+    return torch.mean(torch.abs(y_true - y_pred))
+
+
 # alias
 bce_loss = nn.BCELoss()
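
The new MAE_loss is the mean absolute error over all elements, which is the same quantity as PyTorch's built-in L1 loss with its default mean reduction. A quick sanity check:

    import torch
    import torch.nn.functional as F
    from patchgan.losses import MAE_loss

    y_true = torch.rand(2, 1, 8, 8)
    y_pred = torch.rand(2, 1, 8, 8)
    # F.l1_loss defaults to reduction='mean', matching MAE_loss exactly
    assert torch.allclose(MAE_loss(y_true, y_pred), F.l1_loss(y_pred, y_true))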
patchgan/train.py (23 changes: 13 additions & 10 deletions)

@@ -74,14 +74,6 @@ def patchgan_train():
     datagen = Dataset(data_paths['images'], data_paths['masks'], size=size, augmentation=augmentation, **dataset_kwargs)
     train_datagen, val_datagen = random_split(datagen, train_val_split)

-    model_params = config['model_params']
-    gen_filts = model_params['gen_filts']
-    disc_filts = model_params['disc_filts']
-    n_disc_layers = model_params['n_disc_layers']
-    activation = model_params['activation']
-    use_dropout = model_params.get('use_dropout', True)
-    final_activation = model_params.get('final_activation', 'sigmoid')
-
     dloader_kwargs = {}
     if args.dataloader_workers > 0:
         dloader_kwargs['num_workers'] = args.dataloader_workers
@@ -90,14 +82,25 @@
     train_data = DataLoader(train_datagen, batch_size=args.batch_size, shuffle=True, pin_memory=True, **dloader_kwargs)
     val_data = DataLoader(val_datagen, batch_size=args.batch_size, shuffle=True, pin_memory=True, **dloader_kwargs)

+    model_params = config['model_params']
+    generator_config = model_params['generator']
+    discriminator_config = model_params['discriminator']
+
     # create the generator
+    gen_filts = generator_config['filters']
+    activation = generator_config['activation']
+    use_dropout = generator_config.get('use_dropout', True)
+    final_activation = generator_config.get('final_activation', 'sigmoid')
     generator = UNet(in_channels, out_channels, gen_filts, use_dropout=use_dropout, activation=activation, final_act=final_activation).to(device)

     # create the discriminator
-    discriminator = Discriminator(in_channels + out_channels, disc_filts, n_layers=n_disc_layers).to(device)
+    disc_filts = discriminator_config['filters']
+    disc_norm = discriminator_config.get('norm', False)
+    n_disc_layers = discriminator_config['n_layers']
+    discriminator = Discriminator(in_channels + out_channels, disc_filts, norm=disc_norm, n_layers=n_disc_layers).to(device)

     if args.summary:
-        summary(generator, [1, in_channels, size, size])
+        summary(generator, [1, in_channels, size, size], depth=4)
         summary(discriminator, [1, in_channels + out_channels, size, size])

     checkpoint_path = config.get('checkpoint_path', './checkpoints/')
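
This refactor splits model_params into separate generator and discriminator groups. The config file itself is not shown in the diff, so the following is a hypothetical fragment inferred from the lookups above, written as the Python dict the code expects; all values are placeholders:

    # Keys taken from the train.py lookups; values are illustrative only.
    config = {
        'model_params': {
            'generator': {
                'filters': 64,                  # read into gen_filts
                'activation': 'relu',           # assumed value
                'use_dropout': True,            # optional, defaults to True
                'final_activation': 'sigmoid',  # optional, defaults to 'sigmoid'
            },
            'discriminator': {
                'filters': 64,                  # read into disc_filts
                'n_layers': 3,
                'norm': False,                  # optional, defaults to False
            },
        },
    }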
patchgan/trainer.py (8 changes: 5 additions & 3 deletions)

@@ -5,7 +5,7 @@
 import numpy as np
 from torch import optim
 from torch.optim.lr_scheduler import ExponentialLR, ReduceLROnPlateau
-from .losses import fc_tversky, bce_loss
+from .losses import fc_tversky, bce_loss, MAE_loss
 from torch.nn.functional import binary_cross_entropy
 from collections import defaultdict

@@ -78,6 +78,8 @@ def batch(self, x, y, train=False):
             else:
                 weight = torch.ones_like(target_tensor)
             gen_loss = binary_cross_entropy(gen_img, target_tensor, weight=weight) * self.seg_alpha
+        elif self.loss_type == 'MAE':
+            gen_loss = MAE_loss(gen_img, target_tensor) * self.seg_alpha

         gen_loss_disc = bce_loss(disc_fake, labels_real)
         gen_loss = gen_loss + gen_loss_disc

@@ -164,9 +166,9 @@ def train(self, train_data, val_data, epochs, dsc_learning_rate=1.e-3,
             self.neptune_config['model/parameters/n_epochs'] = epochs

         # create the Adam optimizers
-        self.gen_optimizer = optim.NAdam(
+        self.gen_optimizer = optim.Adam(
             self.generator.parameters(), lr=gen_lr, betas=(0.9, 0.999))
-        self.disc_optimizer = optim.NAdam(
+        self.disc_optimizer = optim.Adam(
             self.discriminator.parameters(), lr=dsc_lr, betas=(0.9, 0.999))

         # set up the learning rate scheduler with exponential lr decay
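
To make the new branch concrete, here is a standalone sketch of the combined generator objective when self.loss_type == 'MAE': the MAE reconstruction term scaled by seg_alpha, plus the adversarial BCE term against the discriminator's prediction on the generated image. Tensor shapes and the seg_alpha value are assumptions, not values from this PR:

    import torch
    from patchgan.losses import MAE_loss, bce_loss

    gen_img = torch.rand(1, 1, 64, 64)        # generator output (placeholder shape)
    target_tensor = torch.rand(1, 1, 64, 64)  # ground-truth mask (placeholder shape)
    disc_fake = torch.rand(1, 1, 6, 6)        # discriminator patch scores (placeholder)
    labels_real = torch.ones_like(disc_fake)
    seg_alpha = 1.0                           # assumed weighting factor

    # reconstruction term scaled by seg_alpha, plus the adversarial term
    gen_loss = MAE_loss(gen_img, target_tensor) * seg_alpha
    gen_loss = gen_loss + bce_loss(disc_fake, labels_real)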
patchgan/unet.py (5 changes: 3 additions & 2 deletions)

@@ -20,8 +20,8 @@ def __init__(self, input_filt, output_filt, activation, norm_layer, layer, use_d
         downnorm = norm_layer(output_filt)

         enc_sub = OrderedDict([(f'DownConv{layer}', downconv),
-                               (f'DownAct{layer}', activation),
                                (f'DownNorm{layer}', downnorm),
+                               (f'DownAct{layer}', activation),
                                ])
         if use_dropout:
             enc_sub = OrderedDict(chain(enc_sub.items(),
@@ -54,8 +54,9 @@ def __init__(self, input_filt, output_filt, activation, norm_layer, layer, batch
         if batch_norm:
             upnorm = norm_layer(output_filt)
             dec_sub = OrderedDict([(f'UpConv{layer}', upconv),
+                                   (f'UpNorm{layer}', upnorm),
                                    (f'UpAct{layer}', activation),
-                                   (f'UpNorm{layer}', upnorm)])
+                                   ])
         else:
             dec_sub = OrderedDict([(f'UpConv{layer}', upconv),
                                    (f'UpAct{layer}', activation)])
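
Both edits swap the block ordering from Conv -> Act -> Norm to the more conventional Conv -> Norm -> Act, so the activation now sees normalized features. A minimal standalone equivalent of the new ordering (filter counts and activation choice are placeholders):

    import torch
    from torch import nn

    # new ordering: convolution, then normalization, then activation
    block = nn.Sequential(
        nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
        nn.InstanceNorm2d(128),
        nn.LeakyReLU(0.2),
    )
    out = block(torch.randn(1, 64, 32, 32))  # -> shape (1, 128, 16, 16)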