-
Notifications
You must be signed in to change notification settings - Fork 29
/
main.py
116 lines (91 loc) · 3.16 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import asyncio
import logging
import os
import tempfile
from typing import Iterator
import leapfrogai_sdk as lfai
from faster_whisper import WhisperModel
logger = logging.getLogger(__name__)

# Filesystem path to the whisper model weights; overridable via env var.
model_path = os.environ.get("LFAI_MODEL_PATH", ".model")

# True when at least one GPU was requested via the GPU_REQUEST env var.
# `or 0` guards against GPU_REQUEST being set to an empty string, which
# would make int() raise ValueError at import time.
GPU_ENABLED = int(os.environ.get("GPU_REQUEST") or 0) > 0
# Cache the loaded model (keyed by device) so repeated requests do not
# re-read the weights from disk on every call.
_MODEL_CACHE = {}


def _load_model():
    """Return a cached WhisperModel for the active device, loading on first use."""
    device = "cuda" if GPU_ENABLED else "cpu"
    if device not in _MODEL_CACHE:
        _MODEL_CACHE[device] = WhisperModel(
            model_path, device=device, compute_type="float32"
        )
    return _MODEL_CACHE[device]


def make_transcribe_request(filename, task, language, temperature, prompt):
    """Run faster-whisper on an audio file and return ``{"text": <output>}``.

    Args:
        filename: path to the audio file to process.
        task: "transcribe" or "translate"; any other non-empty value is
            rejected with an empty result.
        language: input language code; ignored (with an error log) when the
            model does not support it.
        temperature: sampling temperature forwarded to the model when truthy.
        prompt: initial prompt forwarded to the model when non-empty.

    Returns:
        dict with a single "text" key; empty text on unsupported task or
        transcription failure.
    """
    model = _load_model()

    # Build kwargs from only the caller-supplied (truthy) options.
    kwargs = {}
    if task:
        if task in ("transcribe", "translate"):
            kwargs["task"] = task
        else:
            logger.error("Task %s is not supported", task)
            return {"text": ""}
    if language:
        if language in model.supported_languages:
            kwargs["language"] = language
        else:
            # NOTE: unsupported language is logged but processing continues
            # without a language hint (original behavior preserved).
            logger.error("Language %s is not supported", language)
    # NOTE(review): a temperature of 0.0 is falsy and falls through to the
    # model default; preserved as-is since callers may rely on it.
    if temperature:
        kwargs["temperature"] = temperature
    if prompt:
        kwargs["initial_prompt"] = prompt

    try:
        # Call transcribe with only the non-default parameters gathered above.
        segments, _info = model.transcribe(filename, beam_size=5, **kwargs)
    except Exception as e:
        logger.error("Error transcribing audio: %s", e)
        return {"text": ""}

    # Join segment texts in one pass instead of quadratic string +=.
    output = "".join(segment.text for segment in segments)
    logger.info("Completed %s", filename)
    return {"text": output}
def call_whisper(
    request_iterator: Iterator[lfai.AudioRequest], task: str
) -> lfai.AudioResponse:
    """Assemble a streamed audio request, run whisper, and return the text.

    Messages carrying metadata (prompt, temperature, input language) set the
    transcription options; all other messages contribute raw audio bytes.

    Args:
        request_iterator: stream of AudioRequest messages.
        task: "transcribe" or "translate", forwarded to the model.

    Returns:
        AudioResponse containing the transcribed/translated text.
    """
    data = bytearray()
    prompt = ""
    temperature = 0.0
    inputLanguage = "en"

    for request in request_iterator:
        metadata = request.metadata
        # Honor each metadata field independently; the previous `and`-chain
        # silently ignored all metadata unless every field was set at once.
        if metadata.prompt or metadata.temperature or metadata.inputlanguage:
            if metadata.prompt:
                prompt = metadata.prompt
            if metadata.temperature:
                temperature = metadata.temperature
            if metadata.inputlanguage:
                inputLanguage = metadata.inputlanguage
            continue
        data.extend(request.chunk_data)

    with tempfile.NamedTemporaryFile("wb") as f:
        f.write(data)
        # Flush Python's write buffer so the bytes are on disk before the
        # transcriber opens the file by name.
        f.flush()
        result = make_transcribe_request(
            f.name, task, inputLanguage, temperature, prompt
        )
    text = str(result["text"])

    if task == "transcribe":
        logger.info("Transcription complete!")
    elif task == "translate":
        logger.info("Translation complete!")
    return lfai.AudioResponse(text=text)
class Whisper(lfai.AudioServicer):
    """gRPC audio servicer that delegates each RPC to ``call_whisper``."""

    def Translate(self, request_iterator: Iterator[lfai.AudioRequest], context: lfai.GrpcContext):
        """Handle a Translate RPC over the streamed audio request."""
        return call_whisper(request_iterator, "translate")

    def Transcribe(self, request_iterator: Iterator[lfai.AudioRequest], context: lfai.GrpcContext):
        """Handle a Transcribe RPC over the streamed audio request."""
        return call_whisper(request_iterator, "transcribe")

    def Name(self, request, context):
        """Report this backend's model name to the framework."""
        return lfai.NameResponse(name="whisper")
async def main():
    """Configure logging and serve the Whisper gRPC servicer until shutdown."""
    logging.basicConfig(level=logging.INFO)
    # Lazy %-style args defer formatting to the logging framework.
    logger.info("GPU_ENABLED = %s", GPU_ENABLED)
    await lfai.serve(Whisper())
# Script entry point: run the async gRPC server on the default event loop.
if __name__ == "__main__":
    asyncio.run(main())