Commit 245f0b3

cleaned up formatting. moving weights_init to unet.py
ramanakumars committed Nov 21, 2023
1 parent e31cab6 commit 245f0b3
Showing 2 changed files with 28 additions and 4 deletions.
3 changes: 3 additions & 0 deletions patchgan/disc.py
@@ -1,5 +1,6 @@
from torch import nn
from .transfer import Transferable
from .unet import weights_init


class Discriminator(nn.Module, Transferable):
@@ -46,6 +47,8 @@ def __init__(self, input_nc, ndf=64, n_layers=3, norm=False, norm_layer=nn.Insta
                               stride=1, padding=padw), nn.Sigmoid()]
        self.model = nn.Sequential(*sequence)

        self.apply(weights_init)

    def forward(self, input):
        """Standard forward."""
        return self.model(input)
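For reference, `nn.Module.apply` calls the given function on every submodule and then on the module itself, which is why a single `self.apply(weights_init)` at the end of `__init__` reaches every convolution in the discriminator. A minimal sketch of that traversal (the toy block and `name_printer` below are illustrative, not code from the repository):

```python
from torch import nn


def name_printer(m):
    # stand-in for weights_init: report which modules apply() visits
    print(m.__class__.__name__)


block = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.InstanceNorm2d(8), nn.LeakyReLU(0.2))
block.apply(name_printer)  # prints Conv2d, InstanceNorm2d, LeakyReLU, then Sequential
```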
29 changes: 25 additions & 4 deletions patchgan/unet.py
@@ -24,8 +24,7 @@ def __init__(self, input_filt, output_filt, activation, norm_layer, layer, use_d
                               (f'DownAct{layer}', activation),
                               ])
        if use_dropout:
            enc_sub = OrderedDict(chain(enc_sub.items(),
                                        [(f'DownDropout{layer}', nn.Dropout(0.2))]))
            enc_sub = OrderedDict(chain(enc_sub.items(), [(f'DownDropout{layer}', nn.Dropout(0.2))]))

        self.model = nn.Sequential(enc_sub)

@@ -61,8 +60,7 @@ def __init__(self, input_filt, output_filt, activation, norm_layer, layer, batch
        dec_sub = OrderedDict([(f'UpConv{layer}', upconv),
                               (f'UpAct{layer}', activation)])
        if use_dropout:
            dec_sub = OrderedDict(chain(dec_sub.items(),
                                        [(f'UpDropout{layer}', nn.Dropout(0.2))]))
            dec_sub = OrderedDict(chain(dec_sub.items(), [(f'UpDropout{layer}', nn.Dropout(0.2))]))

        self.model = nn.Sequential(dec_sub)

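Both hunks above use the same pattern: the block's layers live in an `OrderedDict` of named modules so that `nn.Sequential` keeps readable layer names, and a dropout entry is chained on only when `use_dropout` is set. A standalone sketch of that idea (the layer names and filter counts below are illustrative, not taken from the repository):

```python
from collections import OrderedDict
from itertools import chain

from torch import nn

layer = 1
use_dropout = True

enc_sub = OrderedDict([
    (f'DownConv{layer}', nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1)),
    (f'DownNorm{layer}', nn.InstanceNorm2d(64)),
    (f'DownAct{layer}', nn.LeakyReLU(0.2)),
])
if use_dropout:
    # append a named Dropout entry while preserving the existing order
    enc_sub = OrderedDict(chain(enc_sub.items(), [(f'DownDropout{layer}', nn.Dropout(0.2))]))

model = nn.Sequential(enc_sub)
print(model)  # submodules show up as DownConv1, DownNorm1, DownAct1, DownDropout1
```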
@@ -109,6 +107,8 @@ def __init__(self, input_nc, output_nc, nf=64,
        self.encoder = nn.ModuleList(encoder_layers)
        self.decoder = nn.ModuleList(decoder_layers)

        self.apply(weights_init)

    def forward(self, x, return_hidden=False):
        xencs = []

@@ -132,3 +132,24 @@ def forward(self, x, return_hidden=False):
            return x, hidden
        else:
            return x


# custom weights initialization called on generator and discriminator
# scaling here means std
def weights_init(net, init_type='normal', scaling=0.02):
    """Initialize network weights.
    Parameters:
        net (network)   -- network to be initialized
        init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        scaling (float) -- scaling factor (std) for the normal distribution.
    'normal' was used in the original pix2pix and CycleGAN paper, but xavier and kaiming might
    work better for some applications. Feel free to try yourself.
    """
    def init_func(m):  # define the initialization function
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and classname.find('Conv') != -1:
            torch.nn.init.xavier_uniform_(m.weight.data)
        # InstanceNorm weights are not matrices (and are None when affine=False); only a normal distribution applies.
        elif classname.find('InstanceNorm') != -1 and m.weight is not None:
            torch.nn.init.normal_(m.weight.data, 1.0, scaling)
            torch.nn.init.constant_(m.bias.data, 0.0)

    net.apply(init_func)  # apply the initialization function to net and all of its submodules

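Because `weights_init` is handed directly to `nn.Module.apply` in both `Discriminator` and the generator, it has to behave sensibly when called one module at a time. A quick standalone check of the initialization (this assumes the package is importable as `patchgan`; the toy network below is illustrative, not from the repository):

```python
from torch import nn

from patchgan.unet import weights_init

toy = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.InstanceNorm2d(8, affine=True),  # affine=True so the norm layer has weights to initialize
    nn.ReLU(),
)
toy.apply(weights_init)  # same call pattern as in disc.py and unet.py

print(toy[0].weight.std())   # conv weights: spread set by the Xavier-uniform init
print(toy[1].weight.mean())  # InstanceNorm weights: drawn around 1.0
```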