Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Converting trainer to PyTorch Lightning module #9

Draft
wants to merge 17 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -132,3 +132,5 @@ dmypy.json
*.pth
*.npz
*.png
*.ckpt
lightning_logs/
6 changes: 2 additions & 4 deletions patchgan/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from .unet import UNet
from .disc import Discriminator
from .trainer import Trainer
from .patchgan import PatchGAN
from .version import __version__

__all__ = [
'UNet', 'Discriminator', 'Trainer', '__version__'
'PatchGAN', '__version__'
]
98 changes: 50 additions & 48 deletions patchgan/unet.py → patchgan/conv_layers.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from torch import nn
from collections import OrderedDict
from itertools import chain
from .transfer import Transferable


class DownSampleBlock(nn.Module):
Expand All @@ -24,8 +23,7 @@ def __init__(self, input_filt, output_filt, activation, norm_layer, layer, use_d
(f'DownAct{layer}', activation),
])
if use_dropout:
enc_sub = OrderedDict(chain(enc_sub.items(),
[(f'DownDropout{layer}', nn.Dropout(0.2))]))
enc_sub = OrderedDict(chain(enc_sub.items(), [(f'DownDropout{layer}', nn.Dropout(0.2))]))

self.model = nn.Sequential(enc_sub)

Expand Down Expand Up @@ -61,8 +59,7 @@ def __init__(self, input_filt, output_filt, activation, norm_layer, layer, batch
dec_sub = OrderedDict([(f'UpConv{layer}', upconv),
(f'UpAct{layer}', activation)])
if use_dropout:
dec_sub = OrderedDict(chain(dec_sub.items(),
[(f'UpDropout{layer}', nn.Dropout(0.2))]))
dec_sub = OrderedDict(chain(dec_sub.items(), [(f'UpDropout{layer}', nn.Dropout(0.2))]))

self.model = nn.Sequential(dec_sub)

Expand All @@ -72,63 +69,68 @@ def forward(self, x):
return x


class UNet(nn.Module, Transferable):
def __init__(self, input_nc, output_nc, nf=64,
norm_layer=nn.InstanceNorm2d, use_dropout=False,
activation='tanh', final_act='softmax'):
super(UNet, self).__init__()

# custom weights initialization called on generator and discriminator
# scaling here means std
def weights_init(net, init_type='normal', scaling=0.02):
"""Initialize network weights.
Parameters:
net (network) -- network to be initialized
init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
init_gain (float) -- scaling factor for normal, xavier and orthogonal.
We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
work better for some applications. Feel free to try yourself.
"""
def init_func(m): # define the initialization function
classname = m.__class__.__name__
if hasattr(m, 'weight') and (classname.find('Conv')) != -1:
torch.nn.init.xavier_uniform_(m.weight.data)
# BatchNorm Layer's weight is not a matrix; only normal distribution applies.
elif classname.find('InstanceNorm') != -1:
torch.nn.init.xavier_uniform_(m.weight.data, 1.0)
torch.nn.init.constant_(m.bias.data, 0.0)


class Encoder(nn.ModuleList):
def __init__(self, input_channels: int, gen_filts: int, gen_activation: str, use_gen_dropout: bool):
kernel_size = 4
padding = 1

conv_filts = [nf, nf * 2, nf * 4, nf * 8, nf * 8, nf * 8, nf * 8]
conv_filts = [gen_filts, gen_filts * 2, gen_filts * 4, gen_filts * 8, gen_filts * 8, gen_filts * 8, gen_filts * 8]

encoder_layers = []

prev_filt = input_nc
prev_filt = input_channels
for i, filt in enumerate(conv_filts):
encoder_layers.append(DownSampleBlock(prev_filt, filt, activation, norm_layer, layer=i,
use_dropout=use_dropout, kernel_size=kernel_size, stride=2,
padding=padding, bias=False))
encoder_layers.append(DownSampleBlock(prev_filt, filt, gen_activation, nn.InstanceNorm2d, layer=i,
use_dropout=use_gen_dropout, kernel_size=kernel_size, stride=2,
padding=padding))
prev_filt = filt

decoder_layers = []
for i, filt in enumerate(conv_filts[:-1][::-1]):
if i == 0:
decoder_layers.append(UpSampleBlock(prev_filt, filt, activation, norm_layer, layer=i, batch_norm=False,
kernel_size=kernel_size, stride=2, padding=padding, bias=False))
else:
decoder_layers.append(UpSampleBlock(prev_filt * 2, filt, activation, norm_layer, layer=i, use_dropout=use_dropout,
batch_norm=True, kernel_size=kernel_size, stride=2, padding=padding, bias=False))

prev_filt = filt

decoder_layers.append(UpSampleBlock(nf * 2, output_nc, final_act, norm_layer, layer=i + 1, batch_norm=False,
kernel_size=kernel_size, stride=2, padding=padding, bias=False))

self.encoder = nn.ModuleList(encoder_layers)
self.decoder = nn.ModuleList(decoder_layers)
super().__init__(encoder_layers)
self.apply(weights_init)

def forward(self, x, return_hidden=False):
xencs = []

for i, layer in enumerate(self.encoder):
x = layer(x)
xencs.append(x)

hidden = xencs[-1]
class Decoder(nn.ModuleList):
def __init__(self, output_channels: int, gen_filts: int, gen_activation: str, final_activation: str, use_gen_dropout: bool):
kernel_size = 4
padding = 1

xencs = xencs[::-1]
conv_filts = [gen_filts, gen_filts * 2, gen_filts * 4, gen_filts * 8, gen_filts * 8, gen_filts * 8, gen_filts * 8]

for i, layer in enumerate(self.decoder):
prev_filt = conv_filts[-1]
decoder_layers = []
for i, filt in enumerate(conv_filts[:-1][::-1]):
if i == 0:
xinp = hidden
decoder_layers.append(UpSampleBlock(prev_filt, filt, gen_activation, nn.InstanceNorm2d, layer=i, batch_norm=False,
kernel_size=kernel_size, stride=2, padding=padding))
else:
xinp = torch.cat([x, xencs[i]], dim=1)
decoder_layers.append(UpSampleBlock(prev_filt * 2, filt, gen_activation, nn.InstanceNorm2d, layer=i, use_dropout=use_gen_dropout,
batch_norm=True, kernel_size=kernel_size, stride=2, padding=padding))

prev_filt = filt

x = layer(xinp)
decoder_layers.append(UpSampleBlock(gen_filts * 2, output_channels, final_activation, nn.InstanceNorm2d, layer=i + 1, batch_norm=False,
kernel_size=kernel_size, stride=2, padding=padding))

if return_hidden:
return x, hidden
else:
return x
super().__init__(decoder_layers)
self.apply(weights_init)
10 changes: 4 additions & 6 deletions patchgan/disc.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from torch import nn
from .transfer import Transferable
from .conv_layers import weights_init


class Discriminator(nn.Module, Transferable):
class Discriminator(nn.Sequential, Transferable):
"""Defines a PatchGAN discriminator"""

def __init__(self, input_nc, ndf=64, n_layers=3, norm=False, norm_layer=nn.InstanceNorm2d):
Expand All @@ -13,7 +14,6 @@ def __init__(self, input_nc, ndf=64, n_layers=3, norm=False, norm_layer=nn.Insta
n_layers (int) -- the number of conv layers in the discriminator
norm_layer -- normalization layer
"""
super(Discriminator, self).__init__()
kw = 4
padw = 1
sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw,
Expand Down Expand Up @@ -44,8 +44,6 @@ def __init__(self, input_nc, ndf=64, n_layers=3, norm=False, norm_layer=nn.Insta
# output 1 channel prediction map
sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw,
stride=1, padding=padw), nn.Sigmoid()]
self.model = nn.Sequential(*sequence)
super().__init__(*sequence)

def forward(self, input):
"""Standard forward."""
return self.model(input)
self.apply(weights_init)
Loading
Loading