diff --git a/patchgan/disc.py b/patchgan/disc.py
index 83f48ae..c00b1b9 100644
--- a/patchgan/disc.py
+++ b/patchgan/disc.py
@@ -1,5 +1,6 @@
 from torch import nn
 from .transfer import Transferable
+from .unet import weights_init
 
 
 class Discriminator(nn.Module, Transferable):
@@ -46,6 +47,8 @@ def __init__(self, input_nc, ndf=64, n_layers=3, norm=False, norm_layer=nn.Insta
                      stride=1, padding=padw), nn.Sigmoid()]
         self.model = nn.Sequential(*sequence)
 
+        self.apply(weights_init)
+
     def forward(self, input):
         """Standard forward."""
         return self.model(input)
diff --git a/patchgan/unet.py b/patchgan/unet.py
index 4805ce1..7df92c9 100755
--- a/patchgan/unet.py
+++ b/patchgan/unet.py
@@ -24,8 +24,7 @@ def __init__(self, input_filt, output_filt, activation, norm_layer, layer, use_d
                                (f'DownAct{layer}', activation),
                                ])
         if use_dropout:
-            enc_sub = OrderedDict(chain(enc_sub.items(),
-                                        [(f'DownDropout{layer}', nn.Dropout(0.2))]))
+            enc_sub = OrderedDict(chain(enc_sub.items(), [(f'DownDropout{layer}', nn.Dropout(0.2))]))
 
         self.model = nn.Sequential(enc_sub)
 
@@ -61,8 +60,7 @@ def __init__(self, input_filt, output_filt, activation, norm_layer, layer, batch
         dec_sub = OrderedDict([(f'UpConv{layer}', upconv),
                                (f'UpAct{layer}', activation)])
         if use_dropout:
-            dec_sub = OrderedDict(chain(dec_sub.items(),
-                                        [(f'UpDropout{layer}', nn.Dropout(0.2))]))
+            dec_sub = OrderedDict(chain(dec_sub.items(), [(f'UpDropout{layer}', nn.Dropout(0.2))]))
 
         self.model = nn.Sequential(dec_sub)
 
@@ -109,6 +107,8 @@ def __init__(self, input_nc, output_nc, nf=64,
         self.encoder = nn.ModuleList(encoder_layers)
         self.decoder = nn.ModuleList(decoder_layers)
 
+        self.apply(weights_init)
+
     def forward(self, x, return_hidden=False):
         xencs = []
 
@@ -132,3 +132,27 @@ def forward(self, x, return_hidden=False):
             return x, hidden
         else:
             return x
+
+
+# custom weights initialization called on the generator and discriminator
+# `scaling` here is the std of the normal distribution used for norm-layer weights
+def weights_init(net, init_type='normal', scaling=0.02):
+    """Initialize network weights.
+    Parameters:
+        net (network)   -- network to be initialized
+        init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
+                           (currently unused; conv layers always get xavier initialization)
+        scaling (float) -- std of the normal distribution used for norm-layer weights.
+    The original pix2pix and CycleGAN paper uses 'normal', but xavier and kaiming might
+    work better for some applications. Feel free to try yourself.
+    """
+    def init_func(m):  # define the initialization function
+        classname = m.__class__.__name__
+        if hasattr(m, 'weight') and classname.find('Conv') != -1:
+            nn.init.xavier_uniform_(m.weight.data)
+        # a norm layer's weight is not a matrix; only the normal distribution applies
+        elif classname.find('InstanceNorm') != -1 and m.weight is not None:
+            nn.init.normal_(m.weight.data, 1.0, scaling)
+            nn.init.constant_(m.bias.data, 0.0)
+
+    net.apply(init_func)  # apply the initialization function to every submodule
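
Usage note: with this patch both networks initialize themselves in their constructors via self.apply(weights_init), so no extra call is needed. The sketch below shows the initializer applied to a stand-alone module instead; the small Conv stack and the variable name `model` are illustrative only and not part of the repo.

    from torch import nn
    from patchgan.unet import weights_init

    # illustrative stand-in model, not from the repo
    model = nn.Sequential(
        nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
        nn.LeakyReLU(0.2),
        nn.Conv2d(64, 1, kernel_size=4, stride=1, padding=1),
        nn.Sigmoid(),
    )

    # Conv weights -> Xavier-uniform; norm affine weights (if any) -> N(1, scaling)
    weights_init(model)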