From ab5df1a236a93222a736dbd831cf019dfe88f3ba Mon Sep 17 00:00:00 2001
From: Kwindla Hultman Kramer
Date: Sun, 22 Dec 2024 11:19:02 -0800
Subject: [PATCH] Feature-complete Gemini audio, transcription, and phrase
 endpointing demo

---
 .../22d-natural-conversation-gemini-audio.py  | 181 +++++++++++-------
 src/pipecat/services/google.py                |  13 +-
 2 files changed, 117 insertions(+), 77 deletions(-)

diff --git a/examples/foundational/22d-natural-conversation-gemini-audio.py b/examples/foundational/22d-natural-conversation-gemini-audio.py
index 78d1271e6..2c6d21f92 100644
--- a/examples/foundational/22d-natural-conversation-gemini-audio.py
+++ b/examples/foundational/22d-natural-conversation-gemini-audio.py
@@ -57,6 +57,14 @@
 logger.remove(0)
 logger.add(sys.stderr, level="DEBUG")
 
+# TRANSCRIBER_MODEL = "gemini-1.5-flash-latest"
+# CLASSIFIER_MODEL = "gemini-1.5-flash-latest"
+# CONVERSATION_MODEL = "gemini-1.5-flash-latest"
+
+TRANSCRIBER_MODEL = "gemini-2.0-flash-exp"
+CLASSIFIER_MODEL = "gemini-2.0-flash-exp"
+CONVERSATION_MODEL = "gemini-2.0-flash-exp"
+
 transcriber_system_instruction = """You are an audio transcriber. You are receiving audio from a user.
 Your job is to transcribe the input audio to text exactly as it was said by the user.
@@ -347,6 +355,11 @@
 
 
 class AudioAccumulator(FrameProcessor):
+    """Buffers user audio until the user stops speaking.
+
+    Always pushes a fresh context with a single audio message.
+    """
+
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
         self._audio_frames = []
@@ -376,14 +389,6 @@ async def process_frame(self, frame: Frame, direction: FrameDirection):
             self._user_speaking_utterance_state = True
 
         elif isinstance(frame, UserStoppedSpeakingFrame):
-            if self._audio_frames[-1]:
-                fr = self._audio_frames[-1]
-                frame_duration = len(fr.audio) / 2 * fr.num_channels / fr.sample_rate
-
-                logger.debug(
-                    f"!!! Frame duration: ({len(fr.audio)}) ({fr.num_channels}) ({fr.sample_rate}) {frame_duration}"
-                )
-
             data = b"".join(frame.audio for frame in self._audio_frames)
             logger.debug(
                 f"Processing audio buffer seconds: ({len(self._audio_frames)}) ({len(data)}) {len(data) / 2 / 16000}"
             )
@@ -415,6 +420,12 @@
 
 
 class CompletenessCheck(FrameProcessor):
+    """Checks the result of the classifier LLM to determine if the user has finished speaking.
+
+    Triggers the notifier if the user has finished speaking. Also triggers the notifier if an
+    idle timeout is reached.
+    """
+
     wait_time = 5.0
 
     def __init__(self, notifier: BaseNotifier, audio_accumulator: AudioAccumulator, **kwargs):
@@ -427,12 +438,13 @@ def __init__(self, notifier: BaseNotifier, audio_accumulator: AudioAccumulator,
     async def process_frame(self, frame: Frame, direction: FrameDirection):
         await super().process_frame(frame, direction)
 
-        if isinstance(frame, TextFrame) and frame.text.startswith("YES"):
+        if isinstance(frame, UserStartedSpeakingFrame):
+            if self._idle_task:
+                self._idle_task.cancel()
+        elif isinstance(frame, TextFrame) and frame.text.startswith("YES"):
             logger.debug("Completeness check YES")
             if self._idle_task:
-                logger.debug(f"CompletenessCheck idle wait CANCEL")
                 self._idle_task.cancel()
-                self._idle_task = None
             await self.push_frame(UserStoppedSpeakingFrame())
             await self._audio_accumulator.reset()
             await self._notifier.notify()
@@ -443,7 +455,7 @@ async def process_frame(self, frame: Frame, direction: FrameDirection):
             if self._wakeup_time:
                 self._wakeup_time = time.time() + self.wait_time
             else:
-                logger.debug("CompletenessCheck idle wait START")
+                # logger.debug("!!! CompletenessCheck idle wait START")
                 self._wakeup_time = time.time() + self.wait_time
                 self._idle_task = self.get_event_loop().create_task(self._idle_task_handler())
 
@@ -451,16 +463,87 @@ async def _idle_task_handler(self):
         try:
             while time.time() < self._wakeup_time:
                 await asyncio.sleep(0.01)
-            logger.debug(f"CompletenessCheck idle wait OVER")
+            # logger.debug(f"!!! CompletenessCheck idle wait OVER")
+            await self._audio_accumulator.reset()
             await self._notifier.notify()
         except asyncio.CancelledError:
+            # logger.debug(f"!!! CompletenessCheck idle wait CANCEL")
             pass
         except Exception as e:
             logger.error(f"CompletenessCheck idle wait error: {e}")
             raise e
+        finally:
+            # logger.debug(f"!!! CompletenessCheck idle wait FINALLY")
+            self._wakeup_time = 0
+            self._idle_task = None
+
+
+class UserAggregatorBuffer(LLMResponseAggregator):
+    """Buffers the output of the transcription LLM. Used by the bot output gate."""
+
+    def __init__(self, **kwargs):
+        super().__init__(
+            messages=None,
+            role=None,
+            start_frame=LLMFullResponseStartFrame,
+            end_frame=LLMFullResponseEndFrame,
+            accumulator_frame=TextFrame,
+            handle_interruptions=True,
+            expect_stripped_words=False,
+        )
+        self._transcription = ""
+
+    async def process_frame(self, frame: Frame, direction: FrameDirection):
+        await super().process_frame(frame, direction)
+        # parent method pushes frames
+        if isinstance(frame, UserStartedSpeakingFrame):
+            self._transcription = ""
+
+    async def _push_aggregation(self):
+        if self._aggregation:
+            self._transcription = self._aggregation
+            self._aggregation = ""
+
+            logger.debug(f"[Transcription] {self._transcription}")
+
+    async def wait_for_transcription(self):
+        while not self._transcription:
+            await asyncio.sleep(0.01)
+        tx = self._transcription
+        self._transcription = ""
+        return tx
+
+
+class ConversationAudioContextAssembler(FrameProcessor):
+    """Takes the single-message context generated by the AudioAccumulator and adds it to the conversation LLM's context."""
+
+    def __init__(self, context: OpenAILLMContext, **kwargs):
+        super().__init__(**kwargs)
+        self._context = context
+
+    async def process_frame(self, frame: Frame, direction: FrameDirection):
+        await super().process_frame(frame, direction)
+
+        # We must not block system frames.
+        if isinstance(frame, SystemFrame):
+            await self.push_frame(frame, direction)
+            return
+
+        if isinstance(frame, OpenAILLMContextFrame):
+            GoogleLLMContext.upgrade_to_google(self._context)
+            last_message = frame.context.messages[-1]
+            self._context._messages.append(last_message)
+            await self.push_frame(OpenAILLMContextFrame(context=self._context))
 
 
 class OutputGate(FrameProcessor):
+    """Buffers output frames until the notifier is triggered.
+
+    When the notifier fires, waits until a transcription is ready, then:
+    1. Replaces the last user audio message with the transcription.
+    2. Flushes the frames buffer.
+    """
+
     def __init__(
         self,
         notifier: BaseNotifier,
@@ -501,6 +584,13 @@ async def process_frame(self, frame: Frame, direction: FrameDirection):
             await self.push_frame(frame, direction)
             return
 
+        if isinstance(frame, LLMFullResponseStartFrame):
+            # Remove the audio message from the context. We will never need it again.
+            # If the completeness check fails, a new audio message will be appended to the context.
+            # If the completeness check succeeds, our notifier will fire and we will append the
+            # transcription to the context.
+            self._context._messages.pop()
+
         if self._gate_open:
             await self.push_frame(frame, direction)
             return
@@ -517,16 +607,13 @@ async def _stop(self):
 
     async def _gate_task_handler(self):
         while True:
-            # logger.debug("!!! Waiting for notifier")
             try:
                 await self._notifier.wait()
-                # logger.debug("!!! Notified")
-
-                transcription = await self._transcription_buffer.wait_for_transcription()
-
-                last_message = self._context.messages[-1]
-                if last_message.role == "user":
-                    last_message.parts = [glm.Part(text=transcription)]
+
+                transcription = await self._transcription_buffer.wait_for_transcription() or "-"
+                self._context._messages.append(
+                    glm.Content(role="user", parts=[glm.Part(text=transcription)])
+                )
 
                 self.open_gate()
                 for frame, direction in self._frames_buffer:
@@ -540,54 +627,6 @@ async def _gate_task_handler(self):
                     await self.push_frame(frame, direction)
             except asyncio.CancelledError:
                 break
 
 
-class ConversationAudioContextAssembler(FrameProcessor):
-    def __init__(self, context: OpenAILLMContext, **kwargs):
-        super().__init__(**kwargs)
-        self._context = context
-
-    async def process_frame(self, frame: Frame, direction: FrameDirection):
-        await super().process_frame(frame, direction)
-
-        # We must not block system frames.
-        if isinstance(frame, SystemFrame):
-            await self.push_frame(frame, direction)
-            return
-
-        if isinstance(frame, OpenAILLMContextFrame):
-            GoogleLLMContext.upgrade_to_google(self._context)
-            last_message = frame.context.messages[-1]
-            self._context._messages.append(last_message)
-            await self.push_frame(OpenAILLMContextFrame(context=self._context))
-
-
-class UserAggregatorBuffer(LLMResponseAggregator):
-    def __init__(self, **kwargs):
-        super().__init__(
-            messages=None,
-            role=None,
-            start_frame=LLMFullResponseStartFrame,
-            end_frame=LLMFullResponseEndFrame,
-            accumulator_frame=TextFrame,
-            handle_interruptions=True,
-            expect_stripped_words=False,
-        )
-        self._transcription = ""
-
-    async def _push_aggregation(self):
-        if self._aggregation:
-            self._transcription = self._aggregation
-            self._aggregation = ""
-
-            logger.debug(f"[Transcription] {self._transcription}")
-
-    async def wait_for_transcription(self):
-        while not self._transcription:
-            await asyncio.sleep(0.01)
-        tx = self._transcription
-        self._transcription = ""
-        return tx
-
-
 async def main():
     async with aiohttp.ClientSession() as session:
         (room_url, _) = await configure(session)
@@ -613,7 +652,7 @@
         # This is the LLM that will transcribe user speech.
         tx_llm = GoogleLLMService(
             name="Transcriber",
-            model="gemini-2.0-flash-exp",
+            model=TRANSCRIBER_MODEL,
             api_key=os.getenv("GOOGLE_API_KEY"),
             temperature=0.0,
             system_instruction=transcriber_system_instruction,
         )
@@ -622,7 +661,7 @@
         # This is the LLM that will classify user speech as complete or incomplete.
         classifier_llm = GoogleLLMService(
             name="Classifier",
-            model="gemini-2.0-flash-exp",
+            model=CLASSIFIER_MODEL,
             api_key=os.getenv("GOOGLE_API_KEY"),
             temperature=0.0,
             system_instruction=classifier_system_instruction,
         )
@@ -631,7 +670,7 @@
         # This is the regular LLM that responds conversationally.
         conversation_llm = GoogleLLMService(
             name="Conversation",
-            model="gemini-2.0-flash-exp",
+            model=CONVERSATION_MODEL,
             api_key=os.getenv("GOOGLE_API_KEY"),
             system_instruction=conversation_system_instruction,
         )

diff --git a/src/pipecat/services/google.py b/src/pipecat/services/google.py
index 091c4df06..d724d2776 100644
--- a/src/pipecat/services/google.py
+++ b/src/pipecat/services/google.py
@@ -635,7 +635,7 @@ async def _process_context(self, context: OpenAILLMContext):
             messages = context.messages
 
             if context.system_message and self._system_instruction != context.system_message:
-                # logger.debug(f"System instruction changed: {context.system_message}")
+                logger.debug(f"System instruction changed: {context.system_message}")
                 self._system_instruction = context.system_message
                 self._create_client()
 
@@ -673,15 +673,16 @@ async def _process_context(self, context: OpenAILLMContext):
             await self.stop_ttfb_metrics()
 
             if response.usage_metadata:
+                # Use only the prompt token count from the response object
                 prompt_tokens = response.usage_metadata.prompt_token_count
-                completion_tokens = response.usage_metadata.candidates_token_count
-                total_tokens = response.usage_metadata.total_token_count
+                total_tokens = prompt_tokens
 
             async for chunk in response:
                 if chunk.usage_metadata:
-                    prompt_tokens += response.usage_metadata.prompt_token_count
-                    completion_tokens += response.usage_metadata.candidates_token_count
-                    total_tokens += response.usage_metadata.total_token_count
+                    # Use only the completion_tokens from the chunks. Prompt tokens are already counted and
+                    # are repeated here.
+                    completion_tokens += chunk.usage_metadata.candidates_token_count
+                    total_tokens += chunk.usage_metadata.candidates_token_count
                 try:
                     for c in chunk.parts:
                         if c.text:
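
Note on how the pieces fit together: the example combines these processors in a pipecat ParallelPipeline so that transcription, completeness classification, and conversational inference all run on the same buffered audio, with the notifier gating the bot's output. Below is a minimal wiring sketch, assuming pipecat's EventNotifier and the llm/transport/tts objects constructed in the example's main(). The branch order and the OutputGate keyword arguments (its __init__ is only partially shown in this patch) are assumptions, not a verbatim copy of the example file.

# Hypothetical wiring sketch -- not a verbatim copy of 22d-natural-conversation-gemini-audio.py.
# EventNotifier, Pipeline, and ParallelPipeline are real pipecat imports; the branch
# layout and OutputGate keyword names are assumptions.
from pipecat.pipeline.parallel_pipeline import ParallelPipeline
from pipecat.pipeline.pipeline import Pipeline
from pipecat.sync.event_notifier import EventNotifier

notifier = EventNotifier()
audio_accumulator = AudioAccumulator()
transcription_buffer = UserAggregatorBuffer()

pipeline = Pipeline(
    [
        transport.input(),  # user audio in
        audio_accumulator,  # buffer audio until UserStoppedSpeakingFrame
        ParallelPipeline(
            # Branch 1: transcribe the audio buffer; OutputGate reads the result.
            [tx_llm, transcription_buffer],
            # Branch 2: classify completeness; a "YES" from the classifier fires the notifier.
            [
                classifier_llm,
                CompletenessCheck(notifier=notifier, audio_accumulator=audio_accumulator),
            ],
            # Branch 3: respond conversationally; frames are held until the notifier fires.
            [
                ConversationAudioContextAssembler(context=context),
                conversation_llm,
                OutputGate(
                    notifier=notifier,
                    context=context,
                    # Assumed keyword, inferred from the self._transcription_buffer
                    # attribute used in the class body above.
                    transcription_buffer=transcription_buffer,
                ),
            ],
        ),
        tts,  # TTS service from the example (unchanged by this patch)
        transport.output(),
    ]
)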
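
Note on the google.py usage-metrics change: the old loop re-read response.usage_metadata (the response-level object) on every chunk, so the prompt token count was added once per chunk; the new code counts the prompt once and accumulates only each chunk's candidates_token_count, which the patch's comments treat as a per-chunk increment. A worked example with made-up numbers, under that assumption:

# Illustration only -- numbers are invented, not from the patch.
prompt_token_count = 100  # response-level prompt count, reported once
chunk_counts = [5, 7, 8]  # per-chunk candidates_token_count values

# New accounting, as in this patch:
prompt_tokens = prompt_token_count
completion_tokens = sum(chunk_counts)             # 20
total_tokens = prompt_tokens + completion_tokens  # 120
assert (prompt_tokens, completion_tokens, total_tokens) == (100, 20, 120)

# The old accounting re-added the response-level counts on every chunk,
# e.g. prompt_tokens grew to 100 + 3 * 100 == 400 for this stream.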