-
-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathoptuna_search.py
59 lines (43 loc) · 1.64 KB
/
optuna_search.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import os
from absl import app, flags
from avsr import utils
import optuna
from avsr import AVSR
import logging
# absl's global flag-value object. Flags are presumably registered by
# utils.avsr_flags() in the __main__ guard; main() and init_logging() read
# them, and main() also overwrites some (architecture, wordloss_weight).
FLAGS = flags.FLAGS
def main(argv):
    """Run an Optuna search over the word-loss weight of a transformer AVSR model.

    Entry point handed to ``absl.app.run``. Each of the 100 trials samples
    ``FLAGS.wordloss_weight`` in [0, 1], builds a fresh ``AVSR`` instance on
    the fixed train/test TFRecord files under ``./data/`` and trains it for
    20 epochs at learning rate 0.001.

    Args:
        argv: Positional command-line arguments left over by absl; unused.
    """
    del argv  # Unused: all configuration is read from FLAGS.

    # Restrict CUDA to the device(s) selected by --gpu_id.
    os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu_id
    # The search always uses the transformer architecture; this overrides
    # whatever --architecture was given on the command line.
    FLAGS.architecture = 'transformer'

    records_path = './data/'
    labels_train_record = os.path.join(records_path, 'characters_train_success.tfrecord')
    labels_test_record = os.path.join(records_path, 'characters_test_success.tfrecord')
    audio_train_records = (
        os.path.join(records_path, 'logmel_train_success_clean.tfrecord'),
    )
    audio_test_records = (
        os.path.join(records_path, 'logmel_test_success_clean.tfrecord'),
    )

    def experiment(trial):
        """Optuna objective: train one model with a sampled weight and score it.

        Args:
            trial: The Optuna trial that supplies the sampled hyper-parameter.

        Returns:
            The value reported by ``AVSR.optuna_train``; Optuna uses it as
            this trial's objective value.
        """
        # suggest_uniform is deprecated since Optuna 3.0; suggest_float without
        # step/log samples the same continuous uniform range.
        # The weight is stored on FLAGS, presumably read by AVSR from there.
        FLAGS.wordloss_weight = trial.suggest_float("Wloss_weight", 0.0, 1.0)
        inst = AVSR(
            audio_train_record=audio_train_records[0],
            audio_test_record=audio_test_records[0],
            labels_train_record=labels_train_record,
            labels_test_record=labels_test_record,
        )
        acc = inst.optuna_train(target_epoch=20, learning_rate=0.001)
        return acc

    init_logging()
    # 'minimize' is Optuna's default; it is spelled out so the choice is
    # visible. NOTE(review): the objective value is named `acc` -- if
    # optuna_train returns an accuracy (higher is better) this should be
    # 'maximize'. Confirm against AVSR.optuna_train.
    study = optuna.create_study(study_name='wloss', direction='minimize')
    logging.info('Start of Study')
    study.optimize(experiment, n_trials=100)
    logging.info('End of Study')
def init_logging():
    """Route root-logger output (including Optuna's) to ``./logs/<FLAGS.logfile>.log``.

    The root logger is set to INFO and gets a file handler that truncates
    the log file on every run (``mode="w"``). Optuna's records are
    propagated to the root logger, and Optuna's default stderr handler is
    disabled.
    """
    log_dir = './logs'
    # FileHandler opens its file immediately and raises FileNotFoundError
    # when the directory is missing, so make sure it exists first.
    os.makedirs(log_dir, exist_ok=True)

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)  # Setup the root logger.
    logger.addHandler(
        logging.FileHandler(os.path.join(log_dir, FLAGS.logfile + ".log"), mode="w")
    )
    optuna.logging.enable_propagation()  # Propagate logs to the root logger.
    optuna.logging.disable_default_handler()  # Stop showing logs in sys.stderr.
if __name__ == '__main__':
    # Flag definitions must exist before absl's app.run parses argv and
    # invokes main(); utils.avsr_flags() presumably registers them.
    utils.avsr_flags()
    app.run(main)