# thanks to https://github.com/openai/whisper for a good chunk of MIT licensed code
import sys
import base64
import multiprocessing
import numpy as np
from typing import Optional, Union, Literal, List
from tinygrad.engine.jit import TinyJit
from tinygrad.nn.state import torch_load, load_state_dict
from tinygrad.helpers import getenv, DEBUG, fetch
import tinygrad.nn as nn
from tinygrad.shape.symbolic import Variable
from tinygrad.tensor import Tensor
import itertools
import librosa
class MultiHeadAttention:
  def __init__(self, n_state, n_head, kv_caching: Optional[Literal['cross', 'self']]=None, max_self_attn_cache_len=None):
self.n_head = n_head
self.query = nn.Linear(n_state, n_state)
self.key = nn.Linear(n_state, n_state, bias=False)
self.value = nn.Linear(n_state, n_state)
self.out = nn.Linear(n_state, n_state)
self.kv_caching = kv_caching
self.max_self_attn_cache_len = max_self_attn_cache_len
  def __call__(self, x:Tensor, xa:Optional[Tensor]=None, mask:Optional[Tensor]=None, len: Optional[Union[Variable,int]]=None):
if self.kv_caching == 'cross':
if xa is not None:
k, v = self.key(xa), self.value(xa)
if not hasattr(self, 'cache_k'):
self.cache_k, self.cache_v = k, v
else:
# see test_jitted_read_assign in test_jit.py. more context https://github.com/tinygrad/tinygrad/pull/2360#issuecomment-1817989994
self.cache_k.assign(k+1-1).realize()
self.cache_v.assign(v+1-1).realize()
else:
k, v = self.cache_k, self.cache_v
else:
k, v = self.key(x), self.value(x)
if self.kv_caching == 'self':
if not hasattr(self, 'cache_k'):
self.cache_k = Tensor.zeros(x.shape[0], self.max_self_attn_cache_len, x.shape[2])
self.cache_v = Tensor.zeros(x.shape[0], self.max_self_attn_cache_len, x.shape[2])
k = self.cache_k.shrink((None, (0, len), None)).cat(k, dim=1)
v = self.cache_v.shrink((None, (0, len), None)).cat(v, dim=1)
padding = self.max_self_attn_cache_len-len-x.shape[1]
self.cache_k.assign(k.pad((None, (0, padding), None)).contiguous()).realize()
self.cache_v.assign(v.pad((None, (0, padding), None)).contiguous()).realize()
q = self.query(x)
n_ctx = q.shape[1]
assert(q.shape[-1] == k.shape[-1] == v.shape[-1])
head_dim = q.shape[-1] // self.n_head
q = q.reshape(*q.shape[:2], self.n_head, head_dim).permute(0, 2, 1, 3)
k = k.reshape(*k.shape[:2], self.n_head, head_dim).permute(0, 2, 1, 3)
v = v.reshape(*v.shape[:2], self.n_head, head_dim).permute(0, 2, 1, 3)
attn = Tensor.scaled_dot_product_attention(q, k, v, mask[:n_ctx,:n_ctx] if mask is not None else None)
wv = attn.permute(0, 2, 1, 3).flatten(start_dim=2)
return self.out(wv)
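# A rough sketch of how the two kv-cache modes above behave during decoding (shapes are
# illustrative, assuming batch=1 and a full 30s segment, i.e. 1500 encoded audio positions):
#   'cross': k/v are computed once per segment from the encoder output xa
#            (cache_k, cache_v ~ (1, 1500, n_state)) and reused on later calls where xa is None.
#   'self':  each call appends the new tokens' k/v after the first `len` cached positions,
#            attends over those (1, len + seqlen, n_state) entries, then writes the concatenated
#            k/v back, padded out to max_self_attn_cache_len so the cache tensors keep a fixed shape.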
class ResidualAttentionBlock:
def __init__(self, n_state, n_head, is_decoder_block=False, max_self_attn_cache_len=None):
self.attn = MultiHeadAttention(n_state, n_head, kv_caching='self' if is_decoder_block else None, max_self_attn_cache_len=max_self_attn_cache_len)
self.attn_ln = nn.LayerNorm(n_state)
self.cross_attn = MultiHeadAttention(n_state, n_head, kv_caching='cross') if is_decoder_block else None
self.cross_attn_ln = nn.LayerNorm(n_state) if is_decoder_block else None
self.mlp = [nn.Linear(n_state, n_state*4), Tensor.gelu, nn.Linear(n_state*4, n_state)]
self.mlp_ln = nn.LayerNorm(n_state)
  def __call__(self, x, xa=None, mask=None, len: Optional[Union[Variable, int]]=None):
x = x + self.attn(self.attn_ln(x), mask=mask, len=len)
if self.cross_attn: x = x + self.cross_attn(self.cross_attn_ln(x), xa)
x = x + self.mlp_ln(x).sequential(self.mlp)
return x.realize()
class AudioEncoder:
def __init__(self, n_mels, n_audio_ctx, n_audio_state, n_audio_head, n_audio_layer, **_):
self.conv1 = nn.Conv1d(n_mels, n_audio_state, kernel_size=3, padding=1)
self.conv2 = nn.Conv1d(n_audio_state, n_audio_state, kernel_size=3, stride=2, padding=1)
self.blocks = [ResidualAttentionBlock(n_audio_state, n_audio_head) for _ in range(n_audio_layer)]
self.ln_post = nn.LayerNorm(n_audio_state)
self.positional_embedding = Tensor.empty(n_audio_ctx, n_audio_state)
self.encode = TinyJit(self.__call__)
def __call__(self, x):
x = self.conv1(x).gelu()
x = self.conv2(x).gelu()
x = x.permute(0, 2, 1)
x = x + self.positional_embedding[:x.shape[1]]
x = x.sequential(self.blocks)
x = self.ln_post(x)
return x.realize()
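# For a full 30s segment the encoder maps (batch, N_MELS, 3000) mel frames to
# (batch, 1500, n_audio_state): conv2's stride of 2 halves the 3000 frames to the
# 1500 positions covered by positional_embedding (n_audio_ctx).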
class TextDecoder:
def __init__(self, n_vocab, n_text_ctx, n_text_state, n_text_head, n_text_layer, **_):
self.max_tokens_to_sample = n_text_ctx // 2
self.max_self_attn_cache_len = self.max_tokens_to_sample * 2 + 5 # roughly prompt + start toks + max_tokens_to_sample
self.token_embedding = nn.Embedding(n_vocab, n_text_state)
self.positional_embedding = Tensor.empty(n_text_ctx, n_text_state)
self.blocks = [ResidualAttentionBlock(n_text_state, n_text_head, is_decoder_block=True, max_self_attn_cache_len=self.max_self_attn_cache_len) for _ in range(n_text_layer)]
self.ln = nn.LayerNorm(n_text_state)
self.mask = Tensor.full((n_text_ctx, n_text_ctx), -np.inf).triu(1).realize()
self.blocks_start_tok = [TinyJit(block.__call__) for block in self.blocks]
self.blocks_after_start_tok = [TinyJit(block.__call__) for block in self.blocks]
self.start_output_tok = TinyJit(self.output_tok)
self.after_start_output_tok = TinyJit(self.output_tok)
# if layernorm supported symbolic shapes, we wouldn't need this hacky 'streaming' param (which should be called something more descriptive like 'x_is_start_toks_only')
def __call__(self, x: Tensor, pos: int, encoded_audio: Tensor, streaming=False):
seqlen = x.shape[-1]
x = self.token_embedding(x) + self.positional_embedding[pos:pos+seqlen]
if pos == 0:
for block in (self.blocks if streaming else self.blocks_start_tok):
x = block(x, xa=encoded_audio, mask=self.mask, len=0) # pass xa for cross attn kv caching
return self.output_tok(x) if streaming else self.start_output_tok(x)
else:
for block in self.blocks_after_start_tok:
len_v = Variable("self_attn_cache_len", 1, self.max_self_attn_cache_len).bind(pos)
x = block(x, mask=self.mask, len=len_v)
return self.after_start_output_tok(x)
def output_tok(self, x):
return (self.ln(x) @ self.token_embedding.weight.T).realize()
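# Call pattern assumed by the TinyJit wrappers above (a minimal sketch; token variables
# are illustrative):
#   logits = decoder(Tensor(start_toks), pos=0, encoded_audio)            # primes both kv-caches
#   logits = decoder(Tensor(last_tok), pos=tokens_so_far, encoded_audio)  # one call per new token
# Separate TinyJit instances are kept for the two phases (blocks_start_tok vs
# blocks_after_start_tok) because each JIT captures a fixed input shape; only the
# per-token phase binds the symbolic Variable for the growing self-attention cache.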
class Whisper:
def __init__(self, dims, batch_size=1):
self.encoder = AudioEncoder(**dims)
self.decoder = TextDecoder(**dims)
self.is_multilingual = dims["n_vocab"] == 51865
self.batch_size = batch_size
RATE = 16000
SEGMENT_SECONDS=30
SAMPLES_PER_SEGMENT = RATE * SEGMENT_SECONDS # 480000
N_FFT = 400
HOP_LENGTH = 160
N_MELS = 80
FRAMES_PER_SEGMENT = SAMPLES_PER_SEGMENT // HOP_LENGTH # 3000
def prep_audio(waveforms: List[np.ndarray], batch_size: int, truncate=False) -> np.ndarray:
"""
  :param waveforms: A list of 16000Hz audio waveforms, possibly of different lengths
:param batch_size: The batch_size associated with the Whisper model being used to transcribe the audio.
Used to prevent JIT mismatch errors since the encoder does not accept symbolic shapes
:param truncate: If true, truncates (or pads) audio to exactly 30s for a single encoder pass
:return: mel spectrogram of the given waveforms
"""
def pad_or_trim(arr, target_len):
curr_len = len(arr)
if curr_len == target_len:
return arr
elif curr_len < target_len:
return np.pad(arr, (0, target_len - curr_len), 'constant')
else:
return arr[:target_len]
max_len = SAMPLES_PER_SEGMENT if truncate else max(len(wav) for wav in waveforms)
if (r := max_len % SAMPLES_PER_SEGMENT) > 0: max_len += SAMPLES_PER_SEGMENT - r
waveforms = np.array(list(map(lambda w: pad_or_trim(w, max_len), waveforms)))
assert waveforms.shape[0] <= batch_size
if waveforms.shape[0] < batch_size:
# we could have a symbolic batch_size dim instead of manually padding here if conv/layernorm supported symbolic shapes
waveforms = np.pad(waveforms, pad_width=((0, batch_size - waveforms.shape[0]), (0, 0)))
stft = librosa.stft(waveforms, n_fft=N_FFT, hop_length=HOP_LENGTH, window='hann', dtype=np.csingle)
magnitudes = np.absolute(stft[..., :-1]) ** 2
mel_spec = librosa.filters.mel(sr=RATE, n_fft=N_FFT, n_mels=N_MELS) @ magnitudes
log_spec = np.log10(np.clip(mel_spec, 1e-10, None))
log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
log_spec = (log_spec + 4.0) / 4.0
return log_spec
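# Minimal usage sketch (illustrative): 5 seconds of silence pads up to a single 30s
# segment, i.e. 3000 mel frames:
#   spec = prep_audio([np.zeros(5 * RATE, dtype=np.float32)], batch_size=1)
#   assert spec.shape == (1, N_MELS, FRAMES_PER_SEGMENT)  # (1, 80, 3000)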
LANGUAGES = {
"en": "english", "zh": "chinese", "de": "german", "es": "spanish", "ru": "russian", "ko": "korean", "fr": "french", "ja": "japanese", "pt": "portuguese", "tr": "turkish",
"pl": "polish", "ca": "catalan", "nl": "dutch", "ar": "arabic", "sv": "swedish", "it": "italian", "id": "indonesian", "hi": "hindi", "fi": "finnish", "vi": "vietnamese",
"he": "hebrew", "uk": "ukrainian", "el": "greek", "ms": "malay", "cs": "czech", "ro": "romanian", "da": "danish", "hu": "hungarian", "ta": "tamil", "no": "norwegian",
"th": "thai", "ur": "urdu", "hr": "croatian", "bg": "bulgarian", "lt": "lithuanian", "la": "latin", "mi": "maori", "ml": "malayalam", "cy": "welsh", "sk": "slovak", "te": "telugu",
"fa": "persian", "lv": "latvian", "bn": "bengali", "sr": "serbian", "az": "azerbaijani", "sl": "slovenian", "kn": "kannada", "et": "estonian", "mk": "macedonian",
"br": "breton", "eu": "basque", "is": "icelandic", "hy": "armenian", "ne": "nepali", "mn": "mongolian", "bs": "bosnian", "kk": "kazakh", "sq": "albanian", "sw": "swahili",
"gl": "galician", "mr": "marathi", "pa": "punjabi", "si": "sinhala", "km": "khmer", "sn": "shona", "yo": "yoruba", "so": "somali", "af": "afrikaans", "oc": "occitan", "ka": "georgian",
"be": "belarusian", "tg": "tajik", "sd": "sindhi", "gu": "gujarati", "am": "amharic", "yi": "yiddish", "lo": "lao", "uz": "uzbek", "fo": "faroese", "ht": "haitian creole",
"ps": "pashto", "tk": "turkmen", "nn": "nynorsk", "mt": "maltese", "sa": "sanskrit", "lb": "luxembourgish", "my": "myanmar", "bo": "tibetan", "tl": "tagalog", "mg": "malagasy",
"as": "assamese", "tt": "tatar", "haw": "hawaiian", "ln": "lingala", "ha": "hausa", "ba": "bashkir", "jw": "javanese", "su": "sundanese",
}
def get_encoding(encoding_name):
with fetch(f"https://raw.githubusercontent.com/openai/whisper/main/whisper/assets/{encoding_name}.tiktoken").open() as f:
ranks = {base64.b64decode(token): int(rank) for token, rank in (line.split() for line in f if line)}
n_vocab = len(ranks)
specials = [
"<|endoftext|>",
"<|startoftranscript|>",
*[f"<|{lang}|>" for lang in LANGUAGES.keys()],
"<|translate|>",
"<|transcribe|>",
"<|startoflm|>",
"<|startofprev|>",
"<|nospeech|>",
"<|notimestamps|>",
*[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
]
special_tokens = dict(zip(specials, itertools.count(n_vocab)))
n_vocab += len(specials)
import tiktoken
return tiktoken.Encoding(
name=encoding_name,
explicit_n_vocab=n_vocab,
pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
mergeable_ranks=ranks,
special_tokens=special_tokens)
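# Sketch of how the returned Encoding is used below (token ids shown are not exact):
#   enc = get_encoding("gpt2")                           # English-only checkpoints
#   sot = enc._special_tokens["<|startoftranscript|>"]   # look up special-token ids
#   enc.decode(enc.encode("hello world"))                # round-trips ordinary text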
MODEL_URLS = {
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
"tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
"base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
"base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
"small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
"small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
"medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
"medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
"large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
"large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
"large": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
}
def init_whisper(model_name="tiny.en", batch_size=1):
  assert model_name in MODEL_URLS, f"unknown model name: {model_name}"
filename = fetch(MODEL_URLS[model_name])
state = torch_load(filename)
model = Whisper(state['dims'], batch_size)
load_state_dict(model, state['model_state_dict'], strict=False)
enc = get_encoding("multilingual" if model.is_multilingual else "gpt2")
return model, enc
def load_file_waveform(filename):
waveform, _ = librosa.load(filename, sr=RATE)
return waveform
def transcribe_file(model, enc, filename):
return transcribe_waveform(model, enc, [load_file_waveform(filename)])
def transcribe_waveform(model, enc, waveforms, truncate=False):
"""
Expects an array of shape (N,S) where N is the number waveforms to transcribe in parallel and S is number of 16000Hz samples
Returns the transcribed text if a single waveform is provided, or an array of transcriptions if multiple are provided
"""
N_audio = len(waveforms)
log_spec = prep_audio(waveforms, model.batch_size, truncate)
if log_spec.shape[-1] > FRAMES_PER_SEGMENT and N_audio > 1:
# we don't support multi-segment batching because the size of the prompt tokens would be different for each item in the batch
# if we really want this feature, we can consider padding or trimming prompt tokens of varying lengths to make them consistent
raise Exception("Multi-segment transcription not supported with batch audio input")
start_tokens = [enc._special_tokens["<|startoftranscript|>"]]
if model.is_multilingual:
# TODO detect language
language_token = enc._special_tokens["<|startoftranscript|>"] + 1 + tuple(LANGUAGES.keys()).index("en")
start_tokens.append(language_token)
start_tokens.append(enc._special_tokens["<|transcribe|>"])
start_tokens.append(enc._special_tokens["<|notimestamps|>"])
transcription_start_index = len(start_tokens)
eot = enc._special_tokens["<|endoftext|>"]
transcription_tokens = [np.array([], dtype=np.int32)] * log_spec.shape[0]
for curr_frame in range(0, log_spec.shape[-1], FRAMES_PER_SEGMENT):
encoded_audio = model.encoder.encode(Tensor(log_spec[:, :, curr_frame:curr_frame + FRAMES_PER_SEGMENT]))
pos = 0
curr_segment_tokens = np.tile(start_tokens, (log_spec.shape[0], 1))
if curr_frame > 0:
# pass the previously inferred tokens as 'prompt' - https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
prompt = np.concatenate((
[enc._special_tokens["<|startofprev|>"]],
transcription_tokens[0][-model.decoder.max_tokens_to_sample+1:],
start_tokens))
curr_segment_tokens = np.tile(prompt, (log_spec.shape[0], 1))
transcription_start_index = len(curr_segment_tokens[0])
for i in range(model.decoder.max_tokens_to_sample):
out = model.decoder(Tensor(curr_segment_tokens if i == 0 else curr_segment_tokens[:, -1:]), pos, encoded_audio, streaming=curr_frame > 0)
next_tokens = out[:, -1].argmax(axis=-1).numpy().astype(np.int32)
next_tokens[curr_segment_tokens[:, -1] == eot] = eot
curr_segment_tokens = np.concatenate((curr_segment_tokens, next_tokens.reshape(-1, 1)), axis=1)
pos = curr_segment_tokens.shape[-1] - 1
if DEBUG >= 1: print(i, list(map(lambda tokens: enc.decode(tokens), curr_segment_tokens)))
if (curr_segment_tokens[:, -1] == eot).all():
break
for i, t in enumerate(curr_segment_tokens):
eot_index = np.where(t == eot)[0]
eot_index = None if len(eot_index) == 0 else eot_index[0]
transcription_tokens[i] = np.concatenate((transcription_tokens[i], t[transcription_start_index:eot_index]))
transcriptions = list(map(lambda tokens: enc.decode(tokens).strip(), transcription_tokens))
return transcriptions[:N_audio] if N_audio > 1 else transcriptions[0]
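# Typical usage (file path and waveform variables are hypothetical):
#   model, enc = init_whisper("tiny.en", batch_size=1)
#   print(transcribe_file(model, enc, "recording.wav"))
# Batched usage (multi-segment batching is unsupported, so clips are truncated to one 30s segment):
#   model, enc = init_whisper("tiny.en", batch_size=2)
#   texts = transcribe_waveform(model, enc, [wav_a, wav_b], truncate=True)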
CHUNK = 1600
RECORD_SECONDS = 10
def listener(q):
import pyaudio
p = pyaudio.PyAudio()
stream = p.open(format=pyaudio.paInt16, channels=1, rate=RATE, input=True, frames_per_buffer=CHUNK)
print("listening")
for _ in range(0, int(RATE / CHUNK * RECORD_SECONDS)):
data = stream.read(CHUNK)
    waveform = ((np.frombuffer(data, np.int16)/32768).astype(np.float32)*3)  # scale int16 samples to [-1, 1) float32, with a 3x gain boost
q.put(waveform)
print("done listening")
if __name__ == "__main__":
model, enc = init_whisper("small.en" if getenv("SMALL") else "tiny.en", batch_size=1)
if len(sys.argv) > 1:
print(transcribe_file(model, enc, sys.argv[1]))
else:
# online
q = multiprocessing.Queue()
p = multiprocessing.Process(target=listener, args=(q,))
p.daemon = True
p.start()
lst = [enc._special_tokens["<|startoftranscript|>"], enc._special_tokens["<|notimestamps|>"]]
total = None
did_read = False
for i in range(0, int(RATE / CHUNK * RECORD_SECONDS)):
while not q.empty() or total is None:
waveform = q.get()
if total is None: total = waveform
else: total = np.concatenate([total, waveform])
did_read = True
if did_read:
log_spec = prep_audio(total.reshape(1, -1), model.batch_size, truncate=True)
encoded_audio = model.encoder.encode(Tensor(log_spec))
# pass the previously inferred tokens as 'prefix' - https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
out = model.decoder(Tensor([lst]), 0, encoded_audio, streaming=True).realize()
idx = int(out[0,-1].argmax().numpy().item())
lst.append(idx)
dec = enc.decode(lst)
print(dec) # DO NOT REMOVE PRINT. IT'S VERY IMPORTANT
if dec.endswith("<|endoftext|>"):
          lst.pop()  # drop the trailing <|endoftext|> so decoding continues as more audio arrives