forked from espnet/espnet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
whisper_export_vocabulary.py
163 lines (141 loc) · 4.72 KB
/
whisper_export_vocabulary.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
#!/usr/bin/env python3
import argparse
import logging
import os
import sys
from pathlib import Path
from typeguard import check_argument_types
from espnet2.text.whisper_tokenizer import LANGUAGES_CODE_MAPPING
from espnet2.utils.types import str2bool
from espnet.utils.cli_utils import get_commandline_args
# Directory part of this script's path (as given by ``__file__``).
# NOTE(review): not referenced anywhere else in the visible code -- confirm
# whether anything imports it before removing.
dirname = os.path.dirname(__file__)
def export_vocabulary(
    output: str,
    whisper_model: str,
    whisper_language: str = "en",
    whisper_task: str = "transcribe",
    log_level: str = "INFO",
    add_token_file_name: str = "none",
    sot_asr: bool = False,
    speaker_change_symbol: str = "<sc>",
):
    """Write the Whisper tokenizer vocabulary as text, one token per line.

    The output consists of, in order: the tokens known to the wrapped
    tokenizer (including tokens added from ``add_token_file_name``), the
    timestamp tokens ``<|0.00|>``, ``<|0.02|>``, ... that pad the list up to
    the full Whisper vocabulary size, and finally ``speaker_change_symbol``
    when ``sot_asr`` is true.

    Args:
        output: Destination file path, or "-" to write to ``sys.stdout``.
            Missing parent directories are created.
        whisper_model: Either "whisper_en" or "whisper_multilingual".
        whisper_language: Key looked up in ``LANGUAGES_CODE_MAPPING``. It is
            validated for both model types but only passed to the
            multilingual tokenizer.
        whisper_task: "transcribe" or "translate".
        log_level: Level handed to ``logging.basicConfig``.
        add_token_file_name: Path of a text file holding one extra token per
            line, or "none" to add nothing. Only read for
            "whisper_multilingual".
        sot_asr: If true, ``speaker_change_symbol`` is written as the last
            line of the vocabulary.
        speaker_change_symbol: Token written last when ``sot_asr`` is true.

    Raises:
        ValueError: If the language, task or model type is not supported.
        Exception: Whatever importing ``whisper.tokenizer`` raised, after
            printing installation instructions.
    """
    try:
        import whisper.tokenizer
    except Exception:
        print("Error: whisper is not properly installed.")
        print(
            "Please install whisper with: cd ${MAIN_ROOT}/tools && "
            "./installers/install_whisper.sh"
        )
        raise

    assert check_argument_types()

    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )

    # Validate every argument and build the tokenizer BEFORE opening the
    # output, so that a bad argument does not leave an empty file behind.
    whisper_language = LANGUAGES_CODE_MAPPING.get(whisper_language)
    if whisper_language is None:
        raise ValueError("language unsupported for Whisper model")

    if whisper_task not in ("transcribe", "translate"):
        raise ValueError(f"task: {whisper_task} unsupported for Whisper model")

    if whisper_model == "whisper_en":
        tokenizer = whisper.tokenizer.get_tokenizer(multilingual=False)
    elif whisper_model == "whisper_multilingual":
        tokenizer = whisper.tokenizer.get_tokenizer(
            multilingual=True, language=whisper_language, task=whisper_task
        )
        if add_token_file_name != "none":
            # Read as UTF-8 to match the encoding used for the output file.
            with open(add_token_file_name, encoding="utf-8") as f:
                _added_tokens = [line.rstrip() for line in f]
            tokenizer.tokenizer.add_tokens(_added_tokens)
    else:
        raise ValueError(f"tokenizer unsupported: {whisper_model}")

    wrapped = tokenizer.tokenizer
    vocab_size = wrapped.vocab_size + len(wrapped.get_added_vocab())
    if whisper_model == "whisper_en":
        # NOTE(review): the English-only tokenizer is treated as having one
        # token fewer than it reports -- presumably the missing language
        # token; confirm against the whisper tokenizer.
        vocab_size -= 1

    # NOTE (Shih-Lun): extra tokens (for timestamped ASR) not
    #                  stored in the wrapped tokenizer
    # NOTE(review): tokens added from add_token_file_name increase vocab_size
    # and therefore reduce the number of timestamp tokens written -- confirm
    # this is intended.
    full_vocab_size = 51865 if whisper_model == "whisper_multilingual" else 51864

    if output == "-":
        fout = sys.stdout
    else:
        p = Path(output)
        p.parent.mkdir(parents=True, exist_ok=True)
        fout = p.open("w", encoding="utf-8")

    try:
        for i in range(vocab_size):
            # take care of special char for <space>
            tkn = wrapped.convert_ids_to_tokens(i).replace("Ġ", " ")
            fout.write(tkn + "\n")

        for i in range(full_vocab_size - vocab_size):
            fout.write(f"<|{i*0.02:.2f}|>" + "\n")

        if sot_asr:
            fout.write(speaker_change_symbol + "\n")
    finally:
        # Close the file we opened; never close the caller's stdout.
        if fout is not sys.stdout:
            fout.close()
def get_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the vocabulary export script.

    Every option's destination name matches a keyword argument of
    ``export_vocabulary``, so the parsed namespace can be passed straight to
    it with ``**vars(args)``.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(
        description="Export Whisper vocabulary",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--log_level",
        # Upper-case the value so that e.g. "info" is accepted by `choices`.
        type=lambda x: x.upper(),
        default="INFO",
        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
        help="The verbose level of logging",
    )
    parser.add_argument(
        "--output", "-o", required=True, help="Output text. - indicates sys.stdout"
    )
    parser.add_argument(
        "--whisper_model",
        type=str,
        required=True,
        help="Whisper model type",
    )
    parser.add_argument(
        "--add_token_file_name",
        type=str,
        default="none",
        help="File name for added tokens",
    )
    parser.add_argument(
        "--whisper_language",
        type=str,
        default="en",
        help="Language for Whisper multilingual tokenizer",
    )
    parser.add_argument(
        "--whisper_task",
        type=str,
        default="transcribe",
        help="Task for Whisper multilingual tokenizer",
    )
    parser.add_argument(
        "--sot_asr",
        type=str2bool,
        default=False,
        required=False,
        help="Whether SOT-style training is used in Whisper",
    )
    parser.add_argument(
        "--speaker_change_symbol",
        type=str,
        default="<sc>",
        required=False,
        # The previous help text was a copy of the --sot_asr help.
        help="Speaker change symbol appended to the vocabulary "
        "when --sot_asr is true",
    )
    return parser
def main(cmd=None):
    """Script entry point: parse the arguments and export the vocabulary.

    Args:
        cmd: Argument list handed to ``parse_args``; ``None`` makes argparse
            read ``sys.argv``.
    """
    # Echo the invoking command line to stderr before doing any work.
    print(get_commandline_args(), file=sys.stderr)
    namespace = get_parser().parse_args(cmd)
    export_vocabulary(**vars(namespace))


if __name__ == "__main__":
    main()