Skip to content

Commit

Permalink
Merge branch 'feat/preemph'
Browse files Browse the repository at this point in the history
  • Loading branch information
santi-pdp committed Apr 23, 2017
2 parents f010bdb + 15ee357 commit ea12768
Show file tree
Hide file tree
Showing 5 changed files with 163 additions and 24 deletions.
21 changes: 20 additions & 1 deletion data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,22 @@
import numpy as np


def read_and_decode(filename_queue, canvas_size):
def pre_emph(x, coeff=0.95):
x0 = tf.reshape(x[0], [1,])
diff = x[1:] - coeff * x[:-1]
concat = tf.concat(0, [x0, diff])
return concat

def de_emph(y, coeff=0.95):
if coeff <= 0:
return y
x = np.zeros(y.shape[0], dtype=np.float32)
x[0] = y[0]
for n in range(1, y.shape[0], 1):
x[n] = coeff * x[n - 1] + y[n]
return x

def read_and_decode(filename_queue, canvas_size, preemph=0.):
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(
Expand All @@ -20,4 +35,8 @@ def read_and_decode(filename_queue, canvas_size):
noisy.set_shape(canvas_size)
noisy = (2./65535.) * tf.cast((noisy - 32767), tf.float32) + 1.

if preemph > 0:
wave = tf.cast(pre_emph(wave, preemph), tf.float32)
noisy = tf.cast(pre_emph(noisy, preemph), tf.float32)

return wave, noisy
30 changes: 27 additions & 3 deletions generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,14 @@ def make_z(shape, mean=0., std=1., name='z'):
# dec ~ [8x2048, 16x1024, 32x512, 64x512, 8x256, 256x256, 512x128, 1024x128, 2048x64, 4096x64, 8192x32, 16384x1]
#FIRST ENCODER
for layer_idx, layer_depth in enumerate(segan.g_enc_depths):
bias_init = None
if segan.bias_downconv:
if is_ref:
print('Biasing downconv in G')
bias_init = tf.constant_initializer(0.)
h_i_dwn = downconv(h_i, layer_depth, kwidth=kwidth,
init=tf.truncated_normal_initializer(stddev=0.02),
bias_init=bias_init,
name='enc_{}'.format(layer_idx))
if is_ref:
print('Downconv {} -> {}'.format(h_i.get_shape(),
Expand Down Expand Up @@ -191,10 +197,28 @@ def make_z(shape, mean=0., std=1., name='z'):
for layer_idx, layer_depth in enumerate(g_dec_depths):
h_i_dim = h_i.get_shape().as_list()
out_shape = [h_i_dim[0], h_i_dim[1] * 2, layer_depth]
bias_init = None
if segan.bias_deconv:
if is_ref:
print('Biasing deconv in G')
bias_init = tf.constant_initializer(0.)
# deconv
h_i_dcv = deconv(h_i, out_shape, kwidth=kwidth, dilation=2,
init=tf.truncated_normal_initializer(stddev=0.02),
name='dec_{}'.format(layer_idx))
if segan.deconv_type == 'deconv':
if is_ref:
print('-- Transposed deconvolution type --')
h_i_dcv = deconv(h_i, out_shape, kwidth=kwidth, dilation=2,
init=tf.truncated_normal_initializer(stddev=0.02),
bias_init=bias_init,
name='dec_{}'.format(layer_idx))
elif segan.deconv_type == 'nn_deconv':
if is_ref:
print('-- NN interpolated deconvolution type --')
h_i_dcv = nn_deconv(h_i, kwdith=kwidth, dilation=2,
init=tf.truncated_normal_initializer(stddev=0.02),
bias_init=bias_init,
name='dec_{}'.format(layer_idx))
else:
raise ValueError('Unknown deconv type {}'.format(segan.deconv_type))
if is_ref:
print('Deconv {} -> {}'.format(h_i.get_shape(),
h_i_dcv.get_shape()))
Expand Down
18 changes: 18 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import os
from tensorflow.python.client import device_lib
from scipy.io import wavfile
from data_loader import pre_emph


devices = device_lib.list_local_devices()
Expand All @@ -19,6 +20,12 @@
"removed (Def: 5).")
flags.DEFINE_integer("l1_remove_epoch", 150, "Epoch where L1 in G is "
"removed (Def: 150).")
flags.DEFINE_boolean("bias_deconv", False,
"Flag to specify if we bias deconvs (Def: False)")
flags.DEFINE_boolean("bias_downconv", False,
"flag to specify if we bias downconvs (def: false)")
flags.DEFINE_boolean("bias_D_conv", False,
"flag to specify if we bias D_convs (def: false)")
# TODO: noise decay is under check
flags.DEFINE_float("denoise_lbound", 0.01, "Min noise std to be still alive (Def: 0.001)")
flags.DEFINE_float("noise_decay", 0.7, "Decay rate of noise std (Def: 0.7)")
Expand All @@ -32,10 +39,13 @@
").")
flags.DEFINE_string("g_nl", "leaky", "Type of nonlinearity in G: leaky or prelu. (Def: leaky).")
flags.DEFINE_string("model", "gan", "Type of model to train: gan or ae. (Def: gan).")
flags.DEFINE_string("deconv_type", "deconv", "Type of deconv method: deconv or "
"nn_deconv (Def: deconv).")
flags.DEFINE_string("g_type", "ae", "Type of G to use: ae or dwave. (Def: ae).")
flags.DEFINE_float("g_learning_rate", 0.0002, "G learning_rate (Def: 0.0002)")
flags.DEFINE_float("d_learning_rate", 0.0002, "D learning_rate (Def: 0.0002)")
flags.DEFINE_float("beta_1", 0.5, "Adam beta 1 (Def: 0.5)")
flags.DEFINE_float("preemph", 0.95, "Pre-emph factor (Def: 0.95)")
flags.DEFINE_string("synthesis_path", "dwavegan_samples", "Path to save output"
" generated samples."
" (Def: dwavegan_sam"
Expand All @@ -48,6 +58,10 @@
flags.DEFINE_string("weights", None, "Weights file")
FLAGS = flags.FLAGS

def pre_emph_test(coeff, canvas_size):
x_ = tf.placeholder(tf.float32, shape=[canvas_size,])
x_preemph = pre_emph(x_, coeff)
return x_, x_preemph

def main(_):
print('Parsed arguments: ', FLAGS.__flags)
Expand Down Expand Up @@ -90,6 +104,10 @@ def main(_):
if fm != 16000:
raise ValueError('16kHz required! Test file is different')
wave = (2./65535.) * (wav_data.astype(np.float32) - 32767) + 1.
if FLAGS.preemph > 0:
print('preemph test wave with {}'.format(FLAGS.preemph))
x_pholder, preemph_op = pre_emph_test(FLAGS.preemph, wave.shape[0])
wave = sess.run(preemph_op, feed_dict={x_pholder:wave})
print('test wave shape: ', wave.shape)
print('test wave min:{} max:{}'.format(np.min(wave), np.max(wave)))
c_wave = se_model.clean(wave)
Expand Down
79 changes: 60 additions & 19 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from generator import *
from discriminator import *
import numpy as np
from data_loader import read_and_decode
from data_loader import read_and_decode, de_emph
from bnorm import VBN
from ops import *
import timeit
Expand Down Expand Up @@ -67,6 +67,12 @@ def __init__(self, sess, args, devices, infer=False, name='SEGAN'):
self.devices = devices
self.z_dim = args.z_dim
self.z_depth = args.z_depth
# type of deconv
self.deconv_type = deconv_type
# specify if use biases or not
self.bias_downconv = args.bias_downconv
self.bias_deconv = args.bias_deconv
self.bias_D_conv = args.bias_D_conv
# clip D values
self.d_clip_weights = False
# apply VBN or regular BN?
Expand All @@ -75,6 +81,12 @@ def __init__(self, sess, args, devices, infer=False, name='SEGAN'):
# num of updates to be applied to D before G
# this is k in original GAN paper (https://arxiv.org/abs/1406.2661)
self.disc_updates = 1
# set preemph factor
self.preemph = args.preemph
if self.preemph > 0:
print('*** Applying pre-emphasis of {} ***'.format(self.preemph))
else:
print('--- No pre-emphasis applied ---')
# canvas size
self.canvas_size = args.canvas_size
self.deactivated_noise = False
Expand Down Expand Up @@ -138,7 +150,8 @@ def build_model_single_gpu(self, gpu_idx):
# create the nodes to load for input pipeline
filename_queue = tf.train.string_input_producer([self.e2e_dataset])
self.get_wav, self.get_noisy = read_and_decode(filename_queue,
2 ** 14)
self.canvas_size,
self.preemph)
# load the data to input pipeline
wavbatch, \
noisybatch = tf.train.shuffle_batch([self.get_wav,
Expand Down Expand Up @@ -454,17 +467,43 @@ def train(self, config, devices):
swaves = sample_wav
sample_dif = sample_wav - sample_noisy
for m in range(min(20, canvas_w.shape[0])):
print('w{} max: {} min: {}'.format(m, np.max(canvas_w[m]), np.min(canvas_w[m])))
wavfile.write(os.path.join(save_path, 'sample_{}-{}.wav'.format(counter, m)), 16e3, canvas_w[m])
if not os.path.exists(os.path.join(save_path, 'gtruth_{}.wav'.format(m))):
wavfile.write(os.path.join(save_path, 'gtruth_{}.wav'.format(m)), 16e3, swaves[m])
wavfile.write(os.path.join(save_path, 'noisy_{}.wav'.format(m)), 16e3, sample_noisy[m])
wavfile.write(os.path.join(save_path, 'dif_{}.wav'.format(m)), 16e3, sample_dif[m])
np.savetxt(os.path.join(save_path, 'd_rl_losses.txt'), d_rl_losses)
np.savetxt(os.path.join(save_path, 'd_fk_losses.txt'), d_fk_losses)
#np.savetxt(os.path.join(save_path, 'd_nfk_losses.txt'), d_nfk_losses)
np.savetxt(os.path.join(save_path, 'g_adv_losses.txt'), g_adv_losses)
np.savetxt(os.path.join(save_path, 'g_l1_losses.txt'), g_l1_losses)
print('w{} max: {} min: {}'.format(m,
np.max(canvas_w[m]),
np.min(canvas_w[m])))
wavfile.write(os.path.join(save_path,
'sample_{}-'
'{}.wav'.format(counter, m)),
16e3,
de_emph(canvas_w[m],
self.preemph))
m_gtruth_path = os.path.join(save_path, 'gtruth_{}.'
'wav'.format(m))
if not os.path.exists(m_gtruth_path):
wavfile.write(os.path.join(save_path,
'gtruth_{}.'
'wav'.format(m)),
16e3,
de_emph(swaves[m],
self.preemph))
wavfile.write(os.path.join(save_path,
'noisy_{}.'
'wav'.format(m)),
16e3,
de_emph(sample_noisy[m],
self.preemph))
wavfile.write(os.path.join(save_path,
'dif_{}.wav'.format(m)),
16e3,
de_emph(sample_dif[m],
self.preemph))
np.savetxt(os.path.join(save_path, 'd_rl_losses.txt'),
d_rl_losses)
np.savetxt(os.path.join(save_path, 'd_fk_losses.txt'),
d_fk_losses)
np.savetxt(os.path.join(save_path, 'g_adv_losses.txt'),
g_adv_losses)
np.savetxt(os.path.join(save_path, 'g_l1_losses.txt'),
g_l1_losses)

if batch_idx >= num_batches:
curr_epoch += 1
Expand Down Expand Up @@ -508,14 +547,14 @@ def clean(self, x):
x: numpy array containing the normalized noisy waveform
"""
c_res = None
for beg_i in range(0, x.shape[0], 2 ** 14):
if x.shape[0] - beg_i < 2 ** 14:
for beg_i in range(0, x.shape[0], self.canvas_size):
if x.shape[0] - beg_i < self.canvas_size:
length = x.shape[0] - beg_i
pad = (2 ** 14) - length
pad = (self.canvas_size) - length
else:
length = 2 ** 14
length = self.canvas_size
pad = 0
x_ = np.zeros((self.batch_size, 2 ** 14))
x_ = np.zeros((self.batch_size, self.canvas_size))
if pad > 0:
x_[0] = np.concatenate((x[beg_i:beg_i + length], np.zeros(pad)))
else:
Expand All @@ -524,7 +563,7 @@ def clean(self, x):
fdict = {self.gtruth_noisy[0]:x_}
canvas_w = self.sess.run(self.Gs[0],
feed_dict=fdict)[0]
canvas_w = canvas_w.reshape((2 ** 14))
canvas_w = canvas_w.reshape((self.canvas_size))
print('canvas w shape: ', canvas_w.shape)
if pad > 0:
print('Removing padding of {} samples'.format(pad))
Expand All @@ -534,6 +573,8 @@ def clean(self, x):
c_res = canvas_w
else:
c_res = np.concatenate((c_res, canvas_w))
# deemphasize
c_res = de_emph(c_res, self.preemph)
return c_res


Expand Down
39 changes: 38 additions & 1 deletion ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,44 @@ def residual_block(input_, dilation, kwidth, num_kernels=1,
return res


# Code from keras backend
# https://github.com/fchollet/keras/blob/master/keras/backend/tensorflow_backend.py
def repeat_elements(x, rep, axis):
"""Repeats the elements of a tensor along an axis, like `np.repeat`.
If `x` has shape `(s1, s2, s3)` and `axis` is `1`, the output
will have shape `(s1, s2 * rep, s3)`.
# Arguments
x: Tensor or variable.
rep: Python integer, number of times to repeat.
axis: Axis along which to repeat.
# Raises
ValueError: In case `x.shape[axis]` is undefined.
# Returns
A tensor.
"""
x_shape = x.get_shape().as_list()
if x_shape[axis] is None:
raise ValueError('Axis ' + str(axis) + ' of input tensor '
'should have a defined dimension, but is None. '
'Full tensor shape: ' + str(tuple(x_shape)) + '. '
'Typically you need to pass a fully-defined '
'`input_shape` argument to your first layer.')
# slices along the repeat axis
splits = tf.split(value=x, num_or_size_splits=x_shape[axis], axis=axis)
# repeat each slice the given number of reps
x_rep = [s for s in splits for _ in range(rep)]
return concatenate(x_rep, axis)

def nn_deconv(x, kwidth=5, dilation=2, init=None, uniform=False,
bias_init=None, name='nn_deconv1d'):
# first compute nearest neighbour interpolated x
interp_x = repeat_elements(x, dilation, 1)
# run a convolution over the interpolated fmap
dec = conv1d(interp_x, kwidth=5, num_kernels=1, init=init, uniform=uniform,
bias_init=bias_init, name=name, padding='SAME')
return dec


def deconv(x, output_shape, kwidth=5, dilation=2, init=None, uniform=False,
bias_init=None, name='deconv1d'):
input_shape = x.get_shape()
Expand Down Expand Up @@ -255,7 +293,6 @@ def deconv(x, output_shape, kwidth=5, dilation=2, init=None, uniform=False,
deconv = tf.reshape(deconv, output_shape)
return deconv


def conv2d(input_, output_dim, k_h, k_w, stddev=0.05, name="conv2d", with_w=False):
with tf.variable_scope(name):
w = tf.get_variable('w', [k_h, k_w, input_.get_shape()[-1], output_dim],
Expand Down

0 comments on commit ea12768

Please sign in to comment.