-
Notifications
You must be signed in to change notification settings - Fork 4
/
train.py
39 lines (29 loc) · 830 Bytes
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
# 3rd party:
import fire
import tensorflow as tf
# different category:
from style_vae.model import StyleVae, Config
# same category:
from style_vae.train.style_vae_trainer import StyleVaeTrainer
from style_vae.train.vae_trainer_config import VaeTrainerConfig
def train(load: bool):
    """Build a trainer and run its training routine.

    :param load: when truthy, ``load()`` is called on the trainer before
        ``train()`` is invoked.
    """
    vae_trainer = _build_trainer()
    if load:
        vae_trainer.load()
    vae_trainer.train()
def _build_trainer() -> StyleVaeTrainer:
    """Assemble the model, session and trainer, then initialize variables.

    Both configuration objects are built with their default constructors
    and printed before use.

    :return: a ``StyleVaeTrainer`` bound to a new ``tf.Session`` on which
        the global-variables initializer has already been run.
    """
    # Model: default configuration, echoed to stdout.
    vae_config = Config()
    print(vae_config)
    vae = StyleVae(vae_config)

    # Trainer: default configuration, echoed to stdout, plus a new session.
    training_config = VaeTrainerConfig()
    print(training_config)
    session = tf.Session()
    vae_trainer = StyleVaeTrainer(vae, training_config, session)

    # The initializer op is built only after the trainer exists
    # (presumably so variables created by the trainer are covered too —
    # TODO confirm against StyleVaeTrainer).
    session.run(tf.global_variables_initializer())
    return vae_trainer
if __name__ == '__main__':
    # Run only when executed as a script: hand `train` to python-fire,
    # which turns it into the command-line entry point.
    fire.Fire(train)