Skip to content

Commit

Permalink
improve vosk
Browse files Browse the repository at this point in the history
parse and doc all result types
  • Loading branch information
MarkParker5 committed Sep 22, 2023
1 parent 7a6d2ff commit 5ce3af5
Showing 1 changed file with 50 additions and 25 deletions.
75 changes: 50 additions & 25 deletions stark/interfaces/vosk.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,37 @@

import sounddevice
import vosk
from pydantic import BaseModel, Field, ValidationError

from .protocols import SpeechRecognizer, SpeechRecognizerDelegate


vosk.SetLogLevel(-1)

class KaldiTranscriptionWord(BaseModel):
word: str
start: float
end: float
conf: float | None = None # only for KaldiMBR

class KaldiMBR(BaseModel):
text: str
result: list[KaldiTranscriptionWord] = Field(default_factory = list)
spk: list[float]
spk_frames: int

@property
def confidence(self):
return sum(word.conf for word in self.result) / len(self.result)

class KaldiTranscription(BaseModel):
text: str
result: list[KaldiTranscriptionWord] = Field(default_factory = list)
confidence: float

class KaldiResult(BaseModel):
alternatives: list[KaldiTranscription]

class VoskSpeechRecognizer(SpeechRecognizer):

_delegate: SpeechRecognizerDelegate | None = None
Expand Down Expand Up @@ -67,6 +92,8 @@ def __init__(self, model_url: str, speaker_model_url: str | None = None):
vosk_model = vosk.Model(model_path)
speaker_model = vosk.SpkModel(speaker_model_path) if speaker_model_path else None
self.kaldiRecognizer = vosk.KaldiRecognizer(vosk_model, self.samplerate)
self.kaldiRecognizer.SetMaxAlternatives(0) # 0 (default) returns KaldiMBR; 1+ returns KaldiResult (with bad confidence implementation)
self.kaldiRecognizer.SetWords(True) # needs to calculate MBR average confidence; (default: False)
if speaker_model_url:
self.kaldiRecognizer.SetSpkModel(speaker_model)

Expand Down Expand Up @@ -115,41 +142,39 @@ async def _transcribe(self, data):

if self.kaldiRecognizer.AcceptWaveform(data):
self.last_partial_update_time = None
result = json.loads(self.kaldiRecognizer.Result())
raw_json = self.kaldiRecognizer.Result()
text: str | None = None

if (string := result.get('text')):
self.last_result = string
await delegate.speech_recognizer_did_receive_final_result(string)
try:
result = KaldiMBR.parse_raw(raw_json)
text = result.text
# print('\nConfidence:', result.confidence)
except ValidationError:
try:
result = KaldiResult.parse_raw(raw_json)
transcription = result.alternatives[0]
text = transcription.text
except ValidationError:
text = json.loads(raw_json).get('text')

if text:
# if result.spk:
# speaker, similarity = self._get_speaker(result.spk)
# print(f'\nSpeaker: {speaker} ({similarity * 100:.2f}%)\n')

self.last_result = text
await delegate.speech_recognizer_did_receive_final_result(text)
else:
self.last_result = None
await delegate.speech_recognizer_did_receive_empty_result()

# if spk := result.get('spk'):
# speaker, similarity = self._get_speaker(spk)
# print(f'\nSpeaker: {speaker} ({similarity * 100:.2f}%)\n')

else:
result = json.loads(self.kaldiRecognizer.PartialResult())

# partial always returns {"partial": "..."}
if (string := result.get('partial')) and string != self.last_partial_result:
self.last_partial_result = string
self.last_partial_update_time = datetime.now()
await delegate.speech_recognizer_did_receive_partial_result(string)

# Check for partial results timeout
# TODO: (was bug with stucked partial results, need to check again)
# TODO: check last string didn't change all timeout
# if not self.last_partial_update_time:
# return

# if datetime.now() - self.last_partial_update_time > timedelta(seconds = 1):
# print('\nPartial timeout')
# self.last_partial_update_time = None

# if self.last_partial_result:
# await delegate.speech_recognizer_did_receive_final_result(self.last_partial_result)
# self.kaldiRecognizer.Reset() # avoid duplicate results
# self.last_partial_result = ''

def _audio_input_callback(self, indata, frames, time, status):
if not self.is_recognizing: return
Expand All @@ -170,7 +195,7 @@ def _get_speaker(self, vector: list[int]) -> tuple[int, float]:
if not best_similarity or best_similarity < self._speaker_trashold:
matched_speaker_id = len(self._stored_speakers)
self._stored_speakers[matched_speaker_id] = vector
print(f'New speaker: {matched_speaker_id}, similarity: {best_similarity * 100:.2f}%')
# print(f'New speaker: {matched_speaker_id}, similarity: {best_similarity * 100:.2f}%')
best_similarity = 1

return cast(int, matched_speaker_id), best_similarity
Expand Down

0 comments on commit 5ce3af5

Please sign in to comment.