-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
120 lines (87 loc) · 3.29 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import logging
import os
import random
import shutil
from glob import glob
import click
import dill
import numpy as np
import pandas as pd
from natsort import natsorted
from constants import TRAIN_TEST_RATIO
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
def find_files(directory, ext='wav'):
    """Recursively list the files below ``directory`` that end in ``.ext``.

    The search uses the glob pattern ``<directory>/**/*.<ext>`` with
    recursive matching enabled.

    Args:
        directory: Root directory of the search, as a string.
        ext: File extension to match, without the leading dot.

    Returns:
        The matching paths as a list sorted in ascending order.
    """
    pattern = directory + f'/**/*.{ext}'
    matches = glob(pattern, recursive=True)
    matches.sort()
    return matches
def init_pandas():
    """Configure pandas display options for wide, untruncated output.

    Floats are rendered with three decimals, the row and column limits are
    set to ``None`` and the display width is set to 1000.
    """

    def three_decimals(value):
        # Formatter pandas calls for every float it renders.
        return '%.3f' % value

    display_options = (
        ('display.float_format', three_decimals),
        ('display.max_rows', None),
        ('display.max_columns', None),
        ('display.width', 1000),
    )
    for option_name, option_value in display_options:
        pd.set_option(option_name, option_value)
def create_new_empty_dir(directory: str):
    """Reset ``directory`` so that it exists and is empty.

    If the path already exists it is removed with ``shutil.rmtree`` first;
    the directory (including missing parents) is then created afresh.

    Args:
        directory: Path of the directory to (re)create.
    """
    already_present = os.path.exists(directory)
    if already_present:
        # Drop whatever a previous run left behind.
        shutil.rmtree(directory)
    os.makedirs(directory)
def ensure_dir_for_filename(filename: str):
    """Pass the directory part of ``filename`` to ``ensures_dir``.

    Args:
        filename: Path of a file; only its ``os.path.dirname`` is used.
    """
    parent_dir = os.path.dirname(filename)
    ensures_dir(parent_dir)
def ensures_dir(directory: str):
    """Create ``directory`` (and any missing parents) unless it already exists.

    An empty path is ignored, and a path that already exists is left
    untouched.

    Args:
        directory: Directory path to create; may be empty.
    """
    if not directory or os.path.exists(directory):
        return
    # exist_ok=True closes the window between the existence check above and
    # the creation: if another process or thread creates the directory in
    # between, makedirs no longer raises FileExistsError.
    os.makedirs(directory, exist_ok=True)
class ClickType:
    """Factories for pre-configured ``click.Path`` parameter types."""

    @staticmethod
    def _path(exists, is_dir, writable):
        """Build a readable, path-resolving ``click.Path``.

        ``is_dir`` selects whether the path must be a directory
        (``dir_okay`` only) or a file (``file_okay`` only).
        """
        return click.Path(
            exists=exists,
            file_okay=not is_dir,
            dir_okay=is_dir,
            writable=writable,
            readable=True,
            resolve_path=True,
        )

    @staticmethod
    def input_file(writable=False):
        """Return a type for a file declared with ``exists=True``."""
        return ClickType._path(exists=True, is_dir=False, writable=writable)

    @staticmethod
    def input_dir(writable=False):
        """Return a type for a directory declared with ``exists=True``."""
        return ClickType._path(exists=True, is_dir=True, writable=writable)

    @staticmethod
    def output_file():
        """Return a type for a writable file declared with ``exists=False``."""
        return ClickType._path(exists=False, is_dir=False, writable=True)

    @staticmethod
    def output_dir():
        """Return a type for a writable directory declared with ``exists=False``."""
        return ClickType._path(exists=False, is_dir=True, writable=True)
def parallel_function(f, sequence, num_threads=None):
    """Apply ``f`` to every item of ``sequence`` using a pool of processes.

    Args:
        f: Callable invoked with one item of ``sequence`` at a time.
        sequence: Iterable of inputs handed to ``Pool.map``.
        num_threads: Number of worker processes (despite the name, these
            are processes, not threads). ``None`` lets ``multiprocessing``
            choose its default.

    Returns:
        The results of ``f`` with every ``None`` result dropped.
    """
    from multiprocessing import Pool

    pool = Pool(processes=num_threads)
    try:
        result = pool.map(f, sequence)
    finally:
        # Shut the pool down even when f raises; previously the worker
        # processes were left behind in that case.
        pool.close()
        pool.join()
    return [x for x in result if x is not None]
def load_best_checkpoint(checkpoint_dir):
    """Return the last ``*.h5`` file of ``checkpoint_dir`` in natural sort order.

    NOTE(review): "best" presumably relies on checkpoint file names sorting
    by training progress or quality -- confirm against the code that writes
    them.

    Args:
        checkpoint_dir: Directory searched (non-recursively) for ``*.h5``.

    Returns:
        Path of the last checkpoint, or ``None`` when there is none.
    """
    pattern = os.path.join(checkpoint_dir, '*.h5')
    candidates = natsorted(glob(pattern))
    return candidates[-1] if candidates else None
def delete_older_checkpoints(checkpoint_dir, max_to_keep=5):
    """Delete all but the last ``max_to_keep`` ``*.h5`` files of a directory.

    Files are ordered with ``natsorted``; the trailing ``max_to_keep``
    entries are kept and every earlier one is removed.

    Args:
        checkpoint_dir: Directory searched (non-recursively) for ``*.h5``.
        max_to_keep: Number of checkpoints to keep; must be positive.

    Raises:
        ValueError: If ``max_to_keep`` is not positive.
    """
    # Explicit check instead of ``assert`` so that it still runs under -O.
    if max_to_keep <= 0:
        raise ValueError(f'max_to_keep must be positive, got {max_to_keep}.')
    checkpoints = natsorted(glob(os.path.join(checkpoint_dir, '*.h5')))
    # Everything before the trailing max_to_keep entries is stale; slicing
    # avoids a list membership test per checkpoint.
    for checkpoint in checkpoints[:-max_to_keep]:
        os.remove(checkpoint)
def enable_deterministic(seed=123):
    """Seed NumPy's and Python's global random number generators.

    Only ``np.random`` and ``random`` are seeded here; a notice is printed
    to stdout first.

    Args:
        seed: Seed applied to both generators. Defaults to 123, the value
            that was previously hard-coded.
    """
    print('Deterministic mode enabled.')
    np.random.seed(seed)
    random.seed(seed)
def load_pickle(file):
    """Load the object stored in ``file`` with ``dill``.

    Args:
        file: Path of the serialized file.

    Returns:
        The deserialized object, or ``None`` when ``file`` does not exist.
    """
    if os.path.exists(file):
        logger.info(f'Loading PKL file: {file}.')
        # NOTE(review): dill, like pickle, can run arbitrary code while
        # loading -- only use this on files from a trusted source.
        with open(file, 'rb') as stream:
            return dill.load(stream)
    return None
def load_npy(file):
    """Load ``file`` with ``np.load``.

    Args:
        file: Path of the file to load.

    Returns:
        Whatever ``np.load`` returns, or ``None`` when ``file`` does not
        exist.
    """
    if os.path.exists(file):
        logger.info(f'Loading NPY file: {file}.')
        return np.load(file)
    return None
def train_test_sp_to_utt(audio, is_test):
    """Split each speaker's utterances into a train or a test share.

    For every speaker the utterance values are sorted and cut at
    ``int(len(values) * TRAIN_TEST_RATIO)``: the part before the cut is the
    train share, the part from the cut onwards is the test share.

    Args:
        audio: Object exposing ``speakers_to_utterances``, a mapping from
            speaker id to a mapping whose values are sortable (presumably
            utterance file paths -- confirm against the caller).
        is_test: When true return the test share, otherwise the train share.

    Returns:
        Dict mapping each speaker id to the selected list of values.
    """
    selected = {}
    for speaker_id, utterances in audio.speakers_to_utterances.items():
        ordered = sorted(utterances.values())
        cut = int(len(ordered) * TRAIN_TEST_RATIO)
        if is_test:
            selected[speaker_id] = ordered[cut:]
        else:
            selected[speaker_id] = ordered[:cut]
    return selected