# server.py, forked from alesaccoia/VoiceStreamAI
"""
VoiceStreamAI Server: Real-time audio transcription using self-hosted Whisper and WebSocket
Contributors:
- Alessandro Saccoia - [email protected]
"""
import asyncio
import websockets
import uuid
import json
import wave
import os
import time
import logging
from transformers import pipeline
from pyannote.core import Segment
from pyannote.audio import Model
from pyannote.audio.pipelines import VoiceActivityDetection
HOST = 'localhost'
PORT = 8765
SAMPLING_RATE = 16000
AUDIO_CHANNELS = 1
SAMPLES_WIDTH = 2 # int16
DEBUG = True
VAD_AUTH_TOKEN = "FILL ME" # get your key here -> https://huggingface.co/pyannote/segmentation
DEFAULT_CLIENT_CONFIG = {
    "language": None,              # None = multilingual (let Whisper detect the language)
    "chunk_length_seconds": 5,     # seconds of audio to buffer before running VAD + transcription
    "chunk_offset_seconds": 1,     # trailing seconds that must be silent before a chunk is transcribed
}
audio_dir = "audio_files"
os.makedirs(audio_dir, exist_ok=True)
## ---------- INSTANTIATES VAD --------
model = Model.from_pretrained("pyannote/segmentation", use_auth_token=VAD_AUTH_TOKEN)
vad_pipeline = VoiceActivityDetection(segmentation=model)
vad_pipeline.instantiate({"onset": 0.5, "offset": 0.5, "min_duration_on": 0.3, "min_duration_off": 0.3})
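# onset/offset are the activation thresholds at which a speech region starts/ends;
# min_duration_on drops speech regions shorter than that many seconds, and
# min_duration_off fills non-speech gaps shorter than that many seconds.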
## ---------- INSTANTIATES SPEECH --------
recognition_pipeline = pipeline("automatic-speech-recognition", model="openai/whisper-large-v3")
connected_clients = {}
client_buffers = {}
client_temp_buffers = {}
client_configs = {}
# Counter for each client to keep track of file numbers
file_counters = {}
async def transcribe_and_send(client_id, websocket, new_audio_data):
    global file_counters
    if DEBUG: print(f"Client ID {client_id}: new_audio_data length in seconds at transcribe_and_send: {float(len(new_audio_data)) / float(SAMPLING_RATE * SAMPLES_WIDTH)}")

    # Initialize temporary buffer for new clients
    if client_id not in client_temp_buffers:
        client_temp_buffers[client_id] = bytearray()
    if DEBUG: print(f"Client ID {client_id}: client_temp_buffers[client_id] length in seconds at transcribe_and_send: {float(len(client_temp_buffers[client_id])) / float(SAMPLING_RATE * SAMPLES_WIDTH)}")

    # Prepend audio left over from the previous chunk to the new audio data
    old_audio_data = bytes(client_temp_buffers[client_id])
    if DEBUG: print(f"Client ID {client_id}: old_audio_data length in seconds at transcribe_and_send: {float(len(old_audio_data)) / float(SAMPLING_RATE * SAMPLES_WIDTH)}")
    audio_data = old_audio_data + new_audio_data
    if DEBUG: print(f"Client ID {client_id}: audio_data length in seconds at transcribe_and_send: {float(len(audio_data)) / float(SAMPLING_RATE * SAMPLES_WIDTH)}")

    # Initialize file counter for new clients
    if client_id not in file_counters:
        file_counters[client_id] = 0

    # File path
    file_name = f"{audio_dir}/{client_id}_{file_counters[client_id]}.wav"
    if DEBUG: print(f"Client ID {client_id}: Filename : {file_name}")
    file_counters[client_id] += 1

    # Save the audio data to a temporary WAV file
    with wave.open(file_name, 'wb') as wav_file:
        wav_file.setnchannels(AUDIO_CHANNELS)
        wav_file.setsampwidth(SAMPLES_WIDTH)
        wav_file.setframerate(SAMPLING_RATE)
        wav_file.writeframes(audio_data)

    # Measure VAD time
    start_time_vad = time.time()
    result = vad_pipeline(file_name)
    vad_time = time.time() - start_time_vad

    # Logging after VAD
    if DEBUG: print(f"Client ID {client_id}: VAD result segments count: {len(result)}")
    print(f"Client ID {client_id}: VAD inference time: {vad_time:.2f}")

    if len(result) == 0:  # no speech detected; this should only happen when there is no old audio data
        os.remove(file_name)
        client_temp_buffers[client_id].clear()
        return

    # Get the last recognized speech segment
    last_segment = None
    for segment in result.itersegments():
        last_segment = segment
    if DEBUG: print(f"Client ID {client_id}: VAD last segment end: {last_segment.end}")

    # If the detected speech ends more than chunk_offset_seconds before the end of the chunk, transcribe it all
    if last_segment.end < (len(audio_data) / (SAMPLES_WIDTH * SAMPLING_RATE)) - int(client_configs[client_id]['chunk_offset_seconds']):
        start_time_transcription = time.time()
        if client_configs[client_id]['language'] is not None:
            result = recognition_pipeline(file_name, generate_kwargs={"language": client_configs[client_id]['language']})
        else:
            result = recognition_pipeline(file_name)
        transcription_time = time.time() - start_time_transcription
        if DEBUG: print(f"Transcription time: {transcription_time:.2f} seconds")
        print(f"Client ID {client_id}: Transcribed : {result['text']}")
        if result['text']:
            await websocket.send(result['text'])
        client_temp_buffers[client_id].clear()  # Clear temp buffer after processing
    else:
        # Speech may still be in progress: keep the whole chunk around for the next round
        client_temp_buffers[client_id].clear()
        client_temp_buffers[client_id].extend(audio_data)
        if DEBUG: print(f"Skipping because {last_segment.end} falls after {(len(audio_data) / (SAMPLES_WIDTH * SAMPLING_RATE)) - int(client_configs[client_id]['chunk_offset_seconds'])}")

    os.remove(file_name)  # always delete the temporary file at the end
async def receive_audio(websocket, path):  # legacy websockets handler signature: (connection, request path)
    client_id = str(uuid.uuid4())
    connected_clients[client_id] = websocket
    client_buffers[client_id] = bytearray()
    client_configs[client_id] = dict(DEFAULT_CLIENT_CONFIG)  # per-client copy so the defaults stay untouched
    print(f"Client {client_id} connected")

    try:
        async for message in websocket:
            if isinstance(message, bytes):
                client_buffers[client_id].extend(message)
            elif isinstance(message, str):
                config = json.loads(message)
                if config.get('type') == 'config':
                    client_configs[client_id] = config['data']
                    print(f"Config for {client_id}: {client_configs[client_id]}")
                    continue
            else:
                print(f"Unexpected message type from {client_id}")

            # Process audio once enough data has been received
            if len(client_buffers[client_id]) > int(client_configs[client_id]['chunk_length_seconds']) * SAMPLING_RATE * SAMPLES_WIDTH:
                if DEBUG: print(f"Client ID {client_id}: receive_audio calling transcribe_and_send with length: {len(client_buffers[client_id])}")
                await transcribe_and_send(client_id, websocket, client_buffers[client_id])
                client_buffers[client_id].clear()
    except websockets.ConnectionClosed as e:
        print(f"Connection with {client_id} closed: {e}")
    finally:
        del connected_clients[client_id]
        del client_buffers[client_id]
async def main():
    async with websockets.serve(receive_audio, HOST, PORT):
        print(f"WebSocket server started on ws://{HOST}:{PORT}")
        await asyncio.Future()  # run forever


if __name__ == "__main__":
    asyncio.run(main())
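# To run (a sketch, assuming the libraries imported above are installed and
# VAD_AUTH_TOKEN is replaced with a valid Hugging Face token for pyannote/segmentation):
#
#     python server.py
#
# and then point a client at ws://localhost:8765 as outlined near the top of this file.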