diff --git a/discriminator.py b/discriminator.py index 691fd18..4506b5d 100644 --- a/discriminator.py +++ b/discriminator.py @@ -31,9 +31,14 @@ def disc_block(block_idx, input_, kwidth, nfmaps, bnorm, activation, print('D block {} input shape: {}' ''.format(block_idx, input_.get_shape()), end=' *** ') + bias_init = None + if self.bias_D_conv: + if not reuse: + print('biasing D conv', end=' *** ') + bias_init = tf.constant_initializer(0.) downconv_init = tf.truncated_normal_initializer(stddev=0.02) hi_a = downconv(input_, nfmaps, kwidth=kwidth, pool=pooling, - init=downconv_init) + init=downconv_init, bias_init=bias_init) if not reuse: print('downconved shape: {} ' ''.format(hi_a.get_shape()), end=' *** ') diff --git a/generator.py b/generator.py index 377724f..b2cef4f 100644 --- a/generator.py +++ b/generator.py @@ -198,14 +198,14 @@ def make_z(shape, mean=0., std=1., name='z'): h_i_dim = h_i.get_shape().as_list() out_shape = [h_i_dim[0], h_i_dim[1] * 2, layer_depth] bias_init = None - if segan.bias_deconv: - if is_ref: - print('Biasing deconv in G') - bias_init = tf.constant_initializer(0.) # deconv if segan.deconv_type == 'deconv': if is_ref: print('-- Transposed deconvolution type --') + if segan.bias_deconv: + print('Biasing deconv in G') + if segan.bias_deconv: + bias_init = tf.constant_initializer(0.) h_i_dcv = deconv(h_i, out_shape, kwidth=kwidth, dilation=2, init=tf.truncated_normal_initializer(stddev=0.02), bias_init=bias_init, @@ -213,7 +213,11 @@ def make_z(shape, mean=0., std=1., name='z'): elif segan.deconv_type == 'nn_deconv': if is_ref: print('-- NN interpolated deconvolution type --') - h_i_dcv = nn_deconv(h_i, kwdith=kwidth, dilation=2, + if segan.bias_deconv: + print('Biasing deconv in G') + if segan.bias_deconv: + bias_init = 0. + h_i_dcv = nn_deconv(h_i, kwidth=kwidth, dilation=2, init=tf.truncated_normal_initializer(stddev=0.02), bias_init=bias_init, name='dec_{}'.format(layer_idx)) diff --git a/ops.py b/ops.py index d36e62e..a970985 100644 --- a/ops.py +++ b/ops.py @@ -243,10 +243,10 @@ def repeat_elements(x, rep, axis): 'Typically you need to pass a fully-defined ' '`input_shape` argument to your first layer.') # slices along the repeat axis - splits = tf.split(value=x, num_or_size_splits=x_shape[axis], axis=axis) + splits = tf.split(split_dim=axis, num_split=x_shape[axis], value=x) # repeat each slice the given number of reps x_rep = [s for s in splits for _ in range(rep)] - return concatenate(x_rep, axis) + return tf.concat(axis, x_rep) def nn_deconv(x, kwidth=5, dilation=2, init=None, uniform=False, bias_init=None, name='nn_deconv1d'):