-
Notifications
You must be signed in to change notification settings - Fork 150
/
model.py
81 lines (65 loc) · 3.09 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
# -*- coding: utf-8 -*-
# !/usr/bin/env python
'''
By Dabi Ahn. [email protected].
https://www.github.com/andabi
'''
from __future__ import division
import tensorflow as tf
from tensorflow.contrib.rnn import GRUCell, MultiRNNCell
from config import ModelConfig
import os
from utils import shape
import numpy as np
class Model:
    '''
    Two-source separation network: stacked GRU layers, two dense heads and a
    time-frequency masking layer that splits ``x_mixed`` into two estimates.

    The graph is built once in ``__init__`` through ``tf.make_template``, so
    later calls (e.g. from ``loss``) reuse the same variables.
    '''

    def __init__(self, n_rnn_layer=3, hidden_size=256):
        '''
        Create the placeholders and build the network graph.

        :param n_rnn_layer: number of GRU cells stacked in the MultiRNNCell.
        :param hidden_size: number of units in each GRU cell.
        '''
        # Input, Output
        # All placeholders share the shape (None, None, L_FRAME // 2 + 1);
        # dynamic_rnn (time_major=False) reads this as (batch, time, features).
        n_freq = ModelConfig.L_FRAME // 2 + 1
        self.x_mixed = tf.placeholder(tf.float32, shape=(None, None, n_freq), name='x_mixed')
        self.y_src1 = tf.placeholder(tf.float32, shape=(None, None, n_freq), name='y_src1')
        self.y_src2 = tf.placeholder(tf.float32, shape=(None, None, n_freq), name='y_src2')

        # Network
        self.hidden_size = hidden_size
        self.n_layer = n_rnn_layer
        self.net = tf.make_template('net', self._net)
        # First call creates the variables inside the 'net' template.
        self()

    def __call__(self):
        '''Return the pair of masked estimates ``(y_tilde_src1, y_tilde_src2)``.'''
        return self.net()

    def _net(self):
        '''
        Build the forward graph on ``self.x_mixed``.

        :return: tuple ``(y_tilde_src1, y_tilde_src2)``, each the input scaled
                 by that source's share of the two ReLU head outputs.
        '''
        # RNN and dense layers
        rnn_layer = MultiRNNCell([GRUCell(self.hidden_size) for _ in range(self.n_layer)])
        output_rnn, _ = tf.nn.dynamic_rnn(rnn_layer, self.x_mixed, dtype=tf.float32)
        # Each head projects back to the size of the input's last axis.
        # NOTE(review): assumes utils.shape returns the per-axis sizes - confirm.
        input_size = shape(self.x_mixed)[2]
        y_hat_src1 = tf.layers.dense(inputs=output_rnn, units=input_size,
                                     activation=tf.nn.relu, name='y_hat_src1')
        y_hat_src2 = tf.layers.dense(inputs=output_rnn, units=input_size,
                                     activation=tf.nn.relu, name='y_hat_src2')

        # time-freq masking layer
        # The shared denominator is computed once; eps avoids division by zero
        # when both ReLU outputs are zero.
        denominator = y_hat_src1 + y_hat_src2 + np.finfo(float).eps
        y_tilde_src1 = y_hat_src1 / denominator * self.x_mixed
        y_tilde_src2 = y_hat_src2 / denominator * self.x_mixed
        return y_tilde_src1, y_tilde_src2

    def loss(self):
        '''
        Return the mean over all elements of the summed squared errors of both
        source estimates against ``y_src1`` / ``y_src2`` (op name ``'loss'``).
        '''
        pred_y_src1, pred_y_src2 = self()
        squared_error = tf.square(self.y_src1 - pred_y_src1) + tf.square(self.y_src2 - pred_y_src2)
        return tf.reduce_mean(squared_error, name='loss')

    @staticmethod
    def spec_to_batch(src):
        '''
        Cut arrays into fixed-length sequences for the network.

        shape = (batch_size, n_freq, n_frames) => (batch_size, n_frames, n_freq)

        The frame axis is zero-padded at the end up to a multiple of
        ``ModelConfig.SEQ_LEN`` before being split.

        :param src: 3-D numpy array, (num_wavs, n_freq, n_frames).
        :return: tuple ``(batch, padded_src)`` where ``batch`` has shape
                 (-1, SEQ_LEN, n_freq) and ``padded_src`` is the zero-padded
                 input in its original axis order.
        '''
        _, freq, n_frames = src.shape

        # Padding
        # Frames needed to reach the next multiple of SEQ_LEN (0 if aligned).
        pad_len = (-n_frames) % ModelConfig.SEQ_LEN
        pad_width = ((0, 0), (0, 0), (0, pad_len))
        padded_src = np.pad(src, pad_width=pad_width, mode='constant', constant_values=0)

        assert(padded_src.shape[-1] % ModelConfig.SEQ_LEN == 0)

        batch = np.reshape(padded_src.transpose(0, 2, 1), (-1, ModelConfig.SEQ_LEN, freq))
        return batch, padded_src

    @staticmethod
    def batch_to_spec(src, num_wav):
        '''
        Merge fixed-length sequences back into one array per wav.

        shape = (batch_size, n_frames, n_freq) => (batch_size, n_freq, n_frames)

        :param src: 3-D numpy array, (batch_size, seq_len, n_freq).
        :param num_wav: number of wavs the batch rows are regrouped into;
                        consecutive rows are concatenated along the frame axis.
        :return: numpy array of shape (num_wav, n_freq, -1).
        '''
        # The 3-way unpack also rejects input that is not 3-D.
        _, _, freq = src.shape
        src = np.reshape(src, (num_wav, -1, freq))
        src = src.transpose(0, 2, 1)
        return src

    @staticmethod
    def load_state(sess, ckpt_path):
        '''
        Restore variables into ``sess`` from the checkpoint recorded in the
        directory ``ckpt_path``; does nothing when no checkpoint state is found.

        :param sess: TensorFlow session to restore into.
        :param ckpt_path: directory containing the ``checkpoint`` state file.
        '''
        # dirname('<ckpt_path>/checkpoint') yields the checkpoint directory.
        ckpt = tf.train.get_checkpoint_state(os.path.dirname(ckpt_path + '/checkpoint'))
        if ckpt and ckpt.model_checkpoint_path:
            tf.train.Saver().restore(sess, ckpt.model_checkpoint_path)