-
Notifications
You must be signed in to change notification settings - Fork 6
/
main.py
94 lines (76 loc) · 3.29 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import os,sys
import json
sys.path.append("./")
import tensorflow as tf
import numpy as np
#from model.vae import VAWGAN
from analyzer import read, Tanhize
from analyzer import read_whole_features, SPEAKERS, pw2wav
from datetime import datetime
from util.wrapper import get_default_output, convert_f0, nh_to_nchw
from util.wrapper import save, validate_log_dirs #, load, configure_gpu_settings, restore_global_step
#from trainer.vae import GANTrainer
from importlib import import_module
# Command-line options. `args` is the tf.app.flags.FLAGS object; every flag
# below is registered on it, in the order listed, before it is first read.
args = tf.app.flags.FLAGS

_DEFINE_STR = tf.app.flags.DEFINE_string
_DEFINE_INT = tf.app.flags.DEFINE_integer

for _define, _flag, _default, _help in (
    # Speaker pair and data location.
    (_DEFINE_STR, 'src', 'SF1', 'source speaker [SF1 - SM2]'),
    (_DEFINE_STR, 'trg', 'TM3', 'target speaker [SF1 - TM3]'),
    (_DEFINE_STR, 'output_dir', './logdir', 'root of output dir'),
    (_DEFINE_STR, 'file_pattern',
     './dataset/vcc2016/bin/Testing Set/{}/*.bin', 'file pattern'),
    # Log / checkpoint directories and run-time settings.
    (_DEFINE_STR, 'logdir_root', None, 'root of log dir'),
    (_DEFINE_STR, 'logdir', None, 'log dir'),
    (_DEFINE_STR, 'restore_from', None, 'restore from dir (not from *.ckpt)'),
    (_DEFINE_STR, 'gpu_cfg', None, 'GPU configuration'),
    (_DEFINE_INT, 'summary_freq', 1000, 'Update summary'),
    # TODO: selecting one ckpt among several in restore_from.
    (_DEFINE_STR, 'ckpt', None,
     'specify the ckpt in restore_from (if there are multiple ckpts)'),
    # Network description and the classes (resolved by name) that use it.
    (_DEFINE_STR, 'architecture', 'architecture-vawgan-vcc2016.json',
     'network architecture'),
    (_DEFINE_STR, 'model_module', 'model.vae', 'Model module'),
    (_DEFINE_STR, 'model', None, 'Model: ConvVAE, VAWGAN'),
    (_DEFINE_STR, 'trainer_module', 'trainer.vae', 'Trainer module'),
    (_DEFINE_STR, 'trainer', None, 'Trainer: VAETrainer, VAWGANTrainer'),
    (_DEFINE_STR, 'load_model', None, 'load checkpoint'),
):
    _define(_flag, _default, _help)

# Remove the registration helpers so only `args` remains in module scope.
del _DEFINE_STR, _DEFINE_INT, _define, _flag, _default, _help
# `--model` and `--trainer` have no defaults; refuse to continue without them.
if any(flag is None for flag in (args.model, args.trainer)):
    raise ValueError(
        '\n Both `model` and `trainer` should be assigned.'
        '\n Use `python main.py --help` to see applicable options.'
    )

# Resolve the model class, then the trainer class, by name from the modules
# given on the command line (model module is imported first).
MODEL = getattr(import_module(args.model_module, package=None), args.model)
module = import_module(args.trainer_module, package=None)
TRAINER = getattr(module, args.trainer)
def main():
    ''' Build the input pipeline and model from the architecture JSON, then train.

    NOTE: The input is rescaled to [-1, 1] (presumably by `Tanhize`, which is
    given the bounds loaded from ./etc/xmax.npf and ./etc/xmin.npf).

    Side effects:
      - For a fresh run (no `--restore_from`), creates `dirs['logdir']`.
      - Saves a pretty-printed copy of the architecture JSON into the log dir
        when that directory exists; the input JSON itself is never modified.
      - Prints the image/label tensors and runs `TRAINER.train`.
    '''
    dirs = validate_log_dirs(args)

    with open(args.architecture) as f:
        arch = json.load(f)

    if args.restore_from is None:
        tf.gfile.MakeDirs(dirs['logdir'])

    # Snapshot the architecture next to the run's outputs. Write into the log
    # dir (not back over `args.architecture`, which would clobber the input
    # config), and only if the dir exists, since it is not created on restore.
    if tf.gfile.IsDirectory(dirs['logdir']):
        arch_copy = os.path.join(
            dirs['logdir'], os.path.basename(args.architecture))
        with open(arch_copy, 'w') as f:
            json.dump(arch, f, indent=4)

    normalizer = Tanhize(
        xmax=np.fromfile('./etc/xmax.npf'),
        xmin=np.fromfile('./etc/xmin.npf'),
    )

    # NOTE(review): the --file_pattern flag is not used here; training data
    # comes from arch['training']['datadir'].
    image, label = read(
        file_pattern=arch['training']['datadir'],
        batch_size=arch['training']['batch_size'],
        capacity=2048,
        min_after_dequeue=1024,
        normalizer=normalizer,
    )  # image format NHWC (per the original author's note)

    # Single-argument print(...) gives the same output on Python 2 and 3.
    print('image shape: {}'.format(image))
    print('label shape: {}'.format(label))

    machine = MODEL(arch)
    loss = machine.loss(image, label)

    trainer = TRAINER(loss, arch, args, dirs)
    trainer.train(
        nIter=arch['training']['max_iter'],
        n_unroll=arch['training']['n_unroll'],
        machine=machine,
    )


if __name__ == '__main__':
    main()