# main.py (94 lines) — forked from defenseunicorns/leapfrogai.
# (GitHub page chrome and line-number gutter removed; web-scrape artifact.)
import asyncio
import logging
import os
import tempfile
from typing import Iterator
import leapfrogai_sdk as lfai
from faster_whisper import WhisperModel
# Module-level logger for this backend.
logger = logging.getLogger(__name__)

# Directory holding the downloaded Whisper model weights.
model_path = ".model"

# Enable GPU inference when the deployment requested at least one GPU.
# `GPU_REQUEST` is expected to hold an integer count (default: 0 = CPU).
# `or 0` guards against the variable being set but empty, which would
# previously have raised ValueError in int(); the comparison itself
# already yields a bool, so the redundant `True if ... else False` is gone.
GPU_ENABLED = int(os.environ.get("GPU_REQUEST") or 0) > 0
def make_transcribe_request(filename, task, language, temperature, prompt):
    """Run faster-whisper over an audio file and return its text.

    Args:
        filename: Path to the audio file to process.
        task: Either "transcribe" or "translate".
        language: ISO 639-1 code of the source audio (e.g. "en").
        temperature: Sampling temperature for decoding.
        prompt: Optional text to prime the decoder with ("" for none).

    Returns:
        dict with a single "text" key holding the concatenated segment text.
    """
    device = "cuda" if GPU_ENABLED else "cpu"
    # NOTE(review): the model is re-loaded from disk on every request;
    # consider caching a module-level instance if request volume warrants it.
    model = WhisperModel(model_path, device=device, compute_type="float32")
    # Bug fix: language/temperature/prompt were accepted but never forwarded
    # to the model, so caller-supplied settings were silently ignored.
    segments, _info = model.transcribe(
        filename,
        task=task,
        beam_size=5,
        language=language or None,  # None lets the model auto-detect
        temperature=temperature,
        initial_prompt=prompt or None,
    )
    # Join at C speed instead of quadratic `+=` concatenation.
    output = "".join(segment.text for segment in segments)
    logger.info("Completed %s", filename)
    return {"text": output}
def call_whisper(
    request_iterator: Iterator[lfai.AudioRequest], task: str
) -> lfai.AudioResponse:
    """Collect streamed audio chunks and run a whisper task over them.

    The stream may begin with a metadata-only request carrying the prompt,
    temperature, and input language; every other request contributes raw
    audio bytes via ``chunk_data``.

    Args:
        request_iterator: Stream of AudioRequest messages.
        task: Either "transcribe" or "translate".

    Returns:
        AudioResponse containing the resulting text.
    """
    data = bytearray()
    prompt = ""
    temperature = 0.0
    input_language = "en"

    for request in request_iterator:
        # Bug fix: the original required ALL THREE metadata fields to be
        # truthy (`and`), so e.g. a prompt sent with the proto-default
        # temperature of 0.0 was silently dropped. Treat a request with ANY
        # metadata field set as a metadata request, applying fields
        # individually so unset ones keep their defaults.
        if (
            request.metadata.prompt
            or request.metadata.temperature
            or request.metadata.inputlanguage
        ):
            if request.metadata.prompt:
                prompt = request.metadata.prompt
            if request.metadata.temperature:
                temperature = request.metadata.temperature
            if request.metadata.inputlanguage:
                input_language = request.metadata.inputlanguage
            continue
        data.extend(request.chunk_data)

    with tempfile.NamedTemporaryFile("wb") as f:
        f.write(data)
        # Bug fix: flush so the full audio is on disk before the model
        # re-opens the file by name (NamedTemporaryFile writes are buffered).
        f.flush()
        result = make_transcribe_request(
            f.name, task, input_language, temperature, prompt
        )
    text = str(result["text"])

    if task == "transcribe":
        logger.info("Transcription complete!")
    elif task == "translate":
        logger.info("Translation complete!")
    return lfai.AudioResponse(text=text)
class Whisper(lfai.AudioServicer):
    """gRPC audio servicer that delegates both tasks to the whisper pipeline."""

    def _run(self, request_iterator, task):
        # Single delegation point shared by Translate and Transcribe.
        return call_whisper(request_iterator, task)

    def Translate(
        self,
        request_iterator: Iterator[lfai.AudioRequest],
        context: lfai.GrpcContext,
    ):
        """Translate streamed audio into text."""
        return self._run(request_iterator, "translate")

    def Transcribe(
        self,
        request_iterator: Iterator[lfai.AudioRequest],
        context: lfai.GrpcContext,
    ):
        """Transcribe streamed audio into text."""
        return self._run(request_iterator, "transcribe")

    def Name(self, request, context):
        """Report the backend's name for service discovery."""
        return lfai.NameResponse(name="whisper")
async def main():
    """Entry point: configure logging and serve the Whisper backend."""
    logging.basicConfig(level=logging.INFO)
    logger.info(f"GPU_ENABLED = {GPU_ENABLED}")
    servicer = Whisper()
    await lfai.serve(servicer)


if __name__ == "__main__":
    asyncio.run(main())