From 5ce3af50af56a5ecbe91789f58a45d1869390738 Mon Sep 17 00:00:00 2001 From: Mark Parker Date: Fri, 22 Sep 2023 16:30:32 +0200 Subject: [PATCH] improve vosk parse and doc all result types --- stark/interfaces/vosk.py | 75 ++++++++++++++++++++++++++-------------- 1 file changed, 50 insertions(+), 25 deletions(-) diff --git a/stark/interfaces/vosk.py b/stark/interfaces/vosk.py index c9c9da2..60cb2f9 100644 --- a/stark/interfaces/vosk.py +++ b/stark/interfaces/vosk.py @@ -9,12 +9,37 @@ import sounddevice import vosk +from pydantic import BaseModel, Field, ValidationError from .protocols import SpeechRecognizer, SpeechRecognizerDelegate vosk.SetLogLevel(-1) +class KaldiTranscriptionWord(BaseModel): + word: str + start: float + end: float + conf: float | None = None # only for KaldiMBR + +class KaldiMBR(BaseModel): + text: str + result: list[KaldiTranscriptionWord] = Field(default_factory = list) + spk: list[float] + spk_frames: int + + @property + def confidence(self): + return sum(word.conf for word in self.result) / len(self.result) + +class KaldiTranscription(BaseModel): + text: str + result: list[KaldiTranscriptionWord] = Field(default_factory = list) + confidence: float + +class KaldiResult(BaseModel): + alternatives: list[KaldiTranscription] + class VoskSpeechRecognizer(SpeechRecognizer): _delegate: SpeechRecognizerDelegate | None = None @@ -67,6 +92,8 @@ def __init__(self, model_url: str, speaker_model_url: str | None = None): vosk_model = vosk.Model(model_path) speaker_model = vosk.SpkModel(speaker_model_path) if speaker_model_path else None self.kaldiRecognizer = vosk.KaldiRecognizer(vosk_model, self.samplerate) + self.kaldiRecognizer.SetMaxAlternatives(0) # 0 (default) returns KaldiMBR; 1+ returns KaldiResult (with bad confidence implementation) + self.kaldiRecognizer.SetWords(True) # needs to calculate MBR average confidence; (default: False) if speaker_model_url: self.kaldiRecognizer.SetSpkModel(speaker_model) @@ -115,41 +142,39 @@ async def _transcribe(self, data): if self.kaldiRecognizer.AcceptWaveform(data): self.last_partial_update_time = None - result = json.loads(self.kaldiRecognizer.Result()) + raw_json = self.kaldiRecognizer.Result() + text: str | None = None - if (string := result.get('text')): - self.last_result = string - await delegate.speech_recognizer_did_receive_final_result(string) + try: + result = KaldiMBR.parse_raw(raw_json) + text = result.text + # print('\nConfidence:', result.confidence) + except ValidationError: + try: + result = KaldiResult.parse_raw(raw_json) + transcription = result.alternatives[0] + text = transcription.text + except ValidationError: + text = json.loads(raw_json).get('text') + + if text: + # if result.spk: + # speaker, similarity = self._get_speaker(result.spk) + # print(f'\nSpeaker: {speaker} ({similarity * 100:.2f}%)\n') + + self.last_result = text + await delegate.speech_recognizer_did_receive_final_result(text) else: self.last_result = None await delegate.speech_recognizer_did_receive_empty_result() - - # if spk := result.get('spk'): - # speaker, similarity = self._get_speaker(spk) - # print(f'\nSpeaker: {speaker} ({similarity * 100:.2f}%)\n') else: result = json.loads(self.kaldiRecognizer.PartialResult()) - + # partial always returns {"partial": "..."} if (string := result.get('partial')) and string != self.last_partial_result: self.last_partial_result = string self.last_partial_update_time = datetime.now() await delegate.speech_recognizer_did_receive_partial_result(string) - - # Check for partial results timeout - # TODO: (was bug with stucked partial results, need to check again) - # TODO: check last string didn't change all timeout - # if not self.last_partial_update_time: - # return - - # if datetime.now() - self.last_partial_update_time > timedelta(seconds = 1): - # print('\nPartial timeout') - # self.last_partial_update_time = None - - # if self.last_partial_result: - # await delegate.speech_recognizer_did_receive_final_result(self.last_partial_result) - # self.kaldiRecognizer.Reset() # avoid duplicate results - # self.last_partial_result = '' def _audio_input_callback(self, indata, frames, time, status): if not self.is_recognizing: return @@ -170,7 +195,7 @@ def _get_speaker(self, vector: list[int]) -> tuple[int, float]: if not best_similarity or best_similarity < self._speaker_trashold: matched_speaker_id = len(self._stored_speakers) self._stored_speakers[matched_speaker_id] = vector - print(f'New speaker: {matched_speaker_id}, similarity: {best_similarity * 100:.2f}%') + # print(f'New speaker: {matched_speaker_id}, similarity: {best_similarity * 100:.2f}%') best_similarity = 1 return cast(int, matched_speaker_id), best_similarity