From c989c9c16dff228078ed8d328dd5df24450fabd7 Mon Sep 17 00:00:00 2001 From: edgar_git Date: Tue, 10 Sep 2024 10:02:31 +0200 Subject: [PATCH 1/5] LLM user frame processor with tests --- .../processors/aggregators/llm_response.py | 127 ++++++++++++++ .../test_LLM_user_context_aggregator.py | 155 ++++++++++++++++++ 2 files changed, 282 insertions(+) create mode 100644 src/pipecat/processors/aggregators/test_LLM_user_context_aggregator.py diff --git a/src/pipecat/processors/aggregators/llm_response.py b/src/pipecat/processors/aggregators/llm_response.py index 479746471..7650d7c4f 100644 --- a/src/pipecat/processors/aggregators/llm_response.py +++ b/src/pipecat/processors/aggregators/llm_response.py @@ -318,3 +318,130 @@ def __init__(self, context: OpenAILLMContext): accumulator_frame=TranscriptionFrame, interim_accumulator_frame=InterimTranscriptionFrame, ) + # CUSTOM CODE: this variable remembers if we prompted the LLM + self.sent_aggregation_after_last_interruption = False + + # Relevant functions: + # LLMContextAggregator.async def _push_aggregation(self) + # and + # def LLMResponseAggregator._reset(self): + + # The original pipecat implementation is in: + # LLMResponseAggregator.process_frame + + # Use cases implemented: + # + # S: Start, E: End, T: Transcription, I: Interim, X: Text + # + # S E -> None + # S T E -> T + # S I T E -> T + # S I E T -> T + # S I E I T -> T + # S E T -> T + # S E I T -> T + # + # S I E T1 I T2 -> T1 + # + # and T2 would be dropped. + + # We have: + # S = UserStartedSpeakingFrame, + # E = UserStoppedSpeakingFrame, + # T = TranscriptionFrame, + # I = InterimTranscriptionFrame + + # Cases we want to handle: + # - Make sure we never delete some aggregation as it is something said by the user + # - Solves case: S T1 I E S T2 E where we lose T1 + # - Solve case: S T E Bot T (without E S) as the VAD is not activated (yeah case) + # - Solve case: S E T1 T2 where T2 is lost. (variation from above) + # For the last case we also send StartInterruptionFrame for making sure that the reprompt of the LLM does not make weird repeating messages. + + # So the cases would be: + # S E -> None + # S T E -> T + # S I T E -> T + # S I E T -> T + # S I E I T -> T + # S E T -> T + # S E I T -> T + # S T1 I E S T2 E -> (T1 T2) + # S I E T1 I T2 -> T1 Interruption T2 + # S T1 E T2 -> T1 Interruption T2 + # S E T1 B T2 -> T1 Bot Interruption T2 + # S E T1 T2 -> T1 Interruption T2 + # see the tests at test_LLM_user_context_aggregator + async def process_frame(self, frame: Frame, direction: FrameDirection): + await FrameProcessor.process_frame(self, frame, direction) + + send_aggregation = False + + if isinstance(frame, self._start_frame): + # CUSTOM CODE: dont _aggregation = "" + # self._aggregation = "" + self._aggregating = True + self._seen_start_frame = True + self._seen_end_frame = False + # CUSTOM CODE: _seen_interim_results should be updated by interimframe and accumulator frame only + # self._seen_interim_results = False + await self.push_frame(frame, direction) + elif isinstance(frame, self._end_frame): + self._seen_end_frame = True + self._seen_start_frame = False + + # We might have received the end frame but we might still be + # aggregating (i.e. we have seen interim results but not the final + # text). + self._aggregating = self._seen_interim_results or len(self._aggregation) == 0 + + # Send the aggregation if we are not aggregating anymore (i.e. no + # more interim results received). + send_aggregation = not self._aggregating + await self.push_frame(frame, direction) + elif isinstance(frame, self._accumulator_frame): + # CUSTOM CODE: send interruption without VAD + if self.sent_aggregation_after_last_interruption: + await self.push_frame(StartInterruptionFrame()) + self.sent_aggregation_after_last_interruption = False + + # CUSTOM CODE: do not require _aggregating so we do not lose frames + self._aggregation += f" {frame.text}" if self._aggregation else frame.text + # We have recevied a complete sentence, so if we have seen the + # end frame and we were still aggregating, it means we should + # send the aggregation. + # CUSTOM CODE: important thing is not see start frame and not end frame (so user is still speaking) + send_aggregation = not self._seen_start_frame + # We just got our final result, so let's reset interim results. + self._seen_interim_results = False + elif self._interim_accumulator_frame and isinstance(frame, self._interim_accumulator_frame): + # CUSTOM CODE: send interruption without VAD + if self.sent_aggregation_after_last_interruption: + await self.push_frame(StartInterruptionFrame()) + self.sent_aggregation_after_last_interruption = False + self._seen_interim_results = True + elif self._handle_interruptions and isinstance(frame, StartInterruptionFrame): + # CUSTOM CODE: manage new interruptions + self.sent_aggregation_after_last_interruption = False + await self._push_aggregation() + # Reset anyways + self._reset() + await self.push_frame(frame, direction) + elif isinstance(frame, LLMMessagesAppendFrame): + self._messages.extend(frame.messages) + messages_frame = LLMMessagesFrame(self._messages) + await self.push_frame(messages_frame) + elif isinstance(frame, LLMMessagesUpdateFrame): + # We push the frame downstream so the assistant aggregator gets + # updated as well. + await self.push_frame(frame) + # We can now reset this one. + self._reset() + self._messages = frame.messages + messages_frame = LLMMessagesFrame(self._messages) + await self.push_frame(messages_frame) + else: + await self.push_frame(frame, direction) + + if send_aggregation: + await self._push_aggregation() diff --git a/src/pipecat/processors/aggregators/test_LLM_user_context_aggregator.py b/src/pipecat/processors/aggregators/test_LLM_user_context_aggregator.py new file mode 100644 index 000000000..6a0028d54 --- /dev/null +++ b/src/pipecat/processors/aggregators/test_LLM_user_context_aggregator.py @@ -0,0 +1,155 @@ +# tests/test_custom_user_context.py + +"""Tests for CustomLLMUserContextAggregator""" + +import unittest + + +from pipecat.frames.frames import ( + Frame, + TranscriptionFrame, + InterimTranscriptionFrame, + StartInterruptionFrame, + StopInterruptionFrame, + UserStartedSpeakingFrame, + UserStoppedSpeakingFrame, +) +from pipecat.processors.aggregators.llm_response import LLMUserContextAggregator +from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext +from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame +from pipecat.processors.frame_processor import FrameDirection, FrameProcessor + +# Note that UserStartedSpeakingFrame always come with StartInterruptionFrame +# and UserStoppedSpeakingFrame always come with StopInterruptionFrame +# S E -> None +# S T E -> T +# S I T E -> T +# S I E T -> T +# S I E I T -> T +# S E T -> T +# S E I T -> T +# S T1 I E S T2 E -> (T1 T2) +# S I E T1 I T2 -> T1 Interruption T2 +# S T1 E T2 -> T1 Interruption T2 +# S E T1 B T2 -> T1 Bot Interruption T2 +# S E T1 T2 -> T1 Interruption T2 + + +class StoreFrameProcessor(FrameProcessor): + def __init__(self, storage: list[Frame]) -> None: + super().__init__() + self.storage = storage + async def process_frame(self, frame: Frame, direction: FrameDirection): + self.storage.append(frame) + +async def make_test(frames_to_send, expected_returned_frames): + context_aggregator = LLMUserContextAggregator(OpenAILLMContext( + messages=[{"role": "", "content": ""}] + )) + storage = [] + storage_processor = StoreFrameProcessor(storage) + context_aggregator.link(storage_processor) + for frame in frames_to_send: + await context_aggregator.process_frame(frame, direction=FrameDirection.DOWNSTREAM) + print("storage") + for x in storage: + print(x) + print("expected_returned_frames") + for x in expected_returned_frames: + print(x) + assert len(storage) == len(expected_returned_frames) + for expected, real in zip(expected_returned_frames, storage): + assert isinstance(real, expected) + return storage + +class TestFrameProcessing(unittest.IsolatedAsyncioTestCase): + + # S E -> + async def test_s_e(self): + """S E case""" + frames_to_send = [StartInterruptionFrame(), UserStartedSpeakingFrame(), StopInterruptionFrame(), UserStoppedSpeakingFrame()] + expected_returned_frames = [StartInterruptionFrame, UserStartedSpeakingFrame, StopInterruptionFrame, UserStoppedSpeakingFrame] + await make_test(frames_to_send, expected_returned_frames) + + # S T E -> T + async def test_s_t_e(self): + """S T E case""" + frames_to_send = [StartInterruptionFrame(), UserStartedSpeakingFrame(), TranscriptionFrame("", "", ""), StopInterruptionFrame(), UserStoppedSpeakingFrame()] + expected_returned_frames = [StartInterruptionFrame, UserStartedSpeakingFrame, StopInterruptionFrame, UserStoppedSpeakingFrame, OpenAILLMContextFrame] + await make_test(frames_to_send, expected_returned_frames) + + # S I T E -> T + async def test_s_i_t_e(self): + """S I T E case""" + frames_to_send = [StartInterruptionFrame(), UserStartedSpeakingFrame(), InterimTranscriptionFrame("", "", ""), TranscriptionFrame("", "", ""), StopInterruptionFrame(), UserStoppedSpeakingFrame()] + expected_returned_frames = [StartInterruptionFrame, UserStartedSpeakingFrame, StopInterruptionFrame, UserStoppedSpeakingFrame, OpenAILLMContextFrame] + await make_test(frames_to_send, expected_returned_frames) + + # S I E T -> T + async def test_s_i_e_t(self): + """S I E T case""" + frames_to_send = [StartInterruptionFrame(), UserStartedSpeakingFrame(), InterimTranscriptionFrame("", "", ""), StopInterruptionFrame(), UserStoppedSpeakingFrame(), TranscriptionFrame("", "", "")] + expected_returned_frames = [StartInterruptionFrame, UserStartedSpeakingFrame, StopInterruptionFrame, UserStoppedSpeakingFrame, OpenAILLMContextFrame] + await make_test(frames_to_send, expected_returned_frames) + + + # S I E I T -> T + async def test_s_i_e_i_t(self): + """S I E I T case""" + frames_to_send = [StartInterruptionFrame(), UserStartedSpeakingFrame(), InterimTranscriptionFrame("", "", ""), StopInterruptionFrame(), UserStoppedSpeakingFrame(), InterimTranscriptionFrame("", "", ""), TranscriptionFrame("", "", "")] + expected_returned_frames = [StartInterruptionFrame, UserStartedSpeakingFrame, StopInterruptionFrame, UserStoppedSpeakingFrame, OpenAILLMContextFrame] + await make_test(frames_to_send, expected_returned_frames) + + # S E T -> T + async def test_s_e_t(self): + """S E case""" + frames_to_send = [StartInterruptionFrame(), UserStartedSpeakingFrame(), StopInterruptionFrame(), UserStoppedSpeakingFrame(), TranscriptionFrame("", "", "")] + expected_returned_frames = [StartInterruptionFrame, UserStartedSpeakingFrame, StopInterruptionFrame, UserStoppedSpeakingFrame, OpenAILLMContextFrame] + await make_test(frames_to_send, expected_returned_frames) + + # S E I T -> T + async def test_s_e_i_t(self): + """S E I T case""" + frames_to_send = [StartInterruptionFrame(), UserStartedSpeakingFrame(), StopInterruptionFrame(), UserStoppedSpeakingFrame(), InterimTranscriptionFrame("", "", ""), TranscriptionFrame("", "", "")] + expected_returned_frames = [StartInterruptionFrame, UserStartedSpeakingFrame, StopInterruptionFrame, UserStoppedSpeakingFrame, OpenAILLMContextFrame] + await make_test(frames_to_send, expected_returned_frames) + + # S T1 I E S T2 E -> (T1 T2) + async def test_s_t1_i_e_s_t2_e(self): + """S T1 I E S T2 E case""" + frames_to_send = [StartInterruptionFrame(), UserStartedSpeakingFrame(), TranscriptionFrame("T1", "", ""), InterimTranscriptionFrame("", "", ""), StopInterruptionFrame(), UserStoppedSpeakingFrame(), + StartInterruptionFrame(), UserStartedSpeakingFrame(), TranscriptionFrame("T2", "", ""), StopInterruptionFrame(), UserStoppedSpeakingFrame()] + expected_returned_frames = [StartInterruptionFrame, UserStartedSpeakingFrame, StopInterruptionFrame, UserStoppedSpeakingFrame, + StartInterruptionFrame, UserStartedSpeakingFrame, StopInterruptionFrame, UserStoppedSpeakingFrame, OpenAILLMContextFrame] + result = await make_test(frames_to_send, expected_returned_frames) + assert result[-1].context.messages[-1]["content"] == " T1 T2" + + # S I E T1 I T2 -> T1 Interruption T2 + async def test_s_i_e_t1_i_t2(self): + """S I E T1 I T2 case""" + frames_to_send = [StartInterruptionFrame(), UserStartedSpeakingFrame(), InterimTranscriptionFrame("", "", ""), StopInterruptionFrame(), UserStoppedSpeakingFrame(), + TranscriptionFrame("T1", "", ""), InterimTranscriptionFrame("", "", ""), TranscriptionFrame("T2", "", ""),] + expected_returned_frames = [StartInterruptionFrame, UserStartedSpeakingFrame, StopInterruptionFrame, UserStoppedSpeakingFrame, + OpenAILLMContextFrame, StartInterruptionFrame, OpenAILLMContextFrame] + result = await make_test(frames_to_send, expected_returned_frames) + assert result[-1].context.messages[-1]["content"] == " T1 T2" + + # S T1 E T2 -> T1 Interruption T2 + async def test_s_t1_e_t2(self): + """S T1 E T2 case""" + frames_to_send = [StartInterruptionFrame(), UserStartedSpeakingFrame(), TranscriptionFrame("T1", "", ""), StopInterruptionFrame(), UserStoppedSpeakingFrame(), + TranscriptionFrame("T2", "", ""),] + expected_returned_frames = [StartInterruptionFrame, UserStartedSpeakingFrame, StopInterruptionFrame, UserStoppedSpeakingFrame, + OpenAILLMContextFrame, StartInterruptionFrame, OpenAILLMContextFrame] + result = await make_test(frames_to_send, expected_returned_frames) + assert result[-1].context.messages[-1]["content"] == " T1 T2" + + # S E T1 T2 -> T1 Interruption T2 + async def test_s_e_t1_t2(self): + """S E T1 T2 case""" + frames_to_send = [StartInterruptionFrame(), UserStartedSpeakingFrame(), StopInterruptionFrame(), UserStoppedSpeakingFrame(), + TranscriptionFrame("T1", "", ""), TranscriptionFrame("T2", "", ""),] + expected_returned_frames = [StartInterruptionFrame, UserStartedSpeakingFrame, StopInterruptionFrame, UserStoppedSpeakingFrame, + OpenAILLMContextFrame, StartInterruptionFrame, OpenAILLMContextFrame] + result = await make_test(frames_to_send, expected_returned_frames) + assert result[-1].context.messages[-1]["content"] == " T1 T2" \ No newline at end of file From 2dd56ba9924a63418094a59d2cdfcf95169509ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Tue, 10 Dec 2024 09:18:51 -0800 Subject: [PATCH 2/5] processors(llm_response): unify new use cases into base class --- CHANGELOG.md | 3 + .../processors/aggregators/llm_response.py | 333 +++++++----------- .../test_LLM_user_context_aggregator.py | 265 ++++++++++---- 3 files changed, 314 insertions(+), 287 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ea272f9a1..7d5d028b6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -75,6 +75,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +- Fixed LLM response aggregators to support more uses cases such as delayed + transcriptions. + - Fixed an issue that could cause the bot to stop talking if there was a user interruption before getting any audio from the TTS service. diff --git a/src/pipecat/processors/aggregators/llm_response.py b/src/pipecat/processors/aggregators/llm_response.py index 7650d7c4f..e62c7d0b7 100644 --- a/src/pipecat/processors/aggregators/llm_response.py +++ b/src/pipecat/processors/aggregators/llm_response.py @@ -7,6 +7,7 @@ from typing import List, Type from pipecat.frames.frames import ( + BotInterruptionFrame, Frame, InterimTranscriptionFrame, LLMFullResponseEndFrame, @@ -40,6 +41,7 @@ def __init__( interim_accumulator_frame: Type[TextFrame] | None = None, handle_interruptions: bool = False, expect_stripped_words: bool = True, # if True, need to add spaces between words + interrupt_double_accumulator: bool = True, # if True, interrupt if two or more accumulators are received ): super().__init__() @@ -51,8 +53,8 @@ def __init__( self._interim_accumulator_frame = interim_accumulator_frame self._handle_interruptions = handle_interruptions self._expect_stripped_words = expect_stripped_words + self._interrupt_double_accumulator = interrupt_double_accumulator - # Reset our accumulator state. self._reset() @property @@ -69,21 +71,20 @@ def role(self): # Use cases implemented: # - # S: Start, E: End, T: Transcription, I: Interim, X: Text + # S: Start, E: End, T: Transcription, I: Interim # - # S E -> None - # S T E -> X - # S I T E -> X - # S I E T -> X - # S I E I T -> X - # S E T -> X - # S E I T -> X - # - # The following case would not be supported: - # - # S I E T1 I T2 -> X - # - # and T2 would be dropped. + # S E -> None -> User started speaking but no transcription. + # S T E -> T -> Transcription between user started and stopped speaking. + # S E T -> T -> Transcription after user stopped speaking. + # S I T E -> T -> Transcription between user started and stopped speaking (with interims). + # S I E T -> T -> Transcription after user stopped speaking (with interims). + # S I E I T -> T -> Transcription after user stopped speaking (with interims). + # S E I T -> T -> Transcription after user stopped speaking (with interims). + # S T1 I E S T2 E -> "T1 T2" -> Merge two transcriptions if we got a first interim. + # S I E T1 I T2 -> T1 [Interruption] T2 -> Single user started/stopped, double transcription. + # S T1 E T2 -> T1 [Interruption] T2 -> Single user started/stopped, double transcription. + # S E T1 B T2 -> T1 [Interruption] T2 -> Single user started/stopped, double transcription. + # S E T1 T2 -> T1 [Interruption] T2 -> Single user started/stopped, double transcription. async def process_frame(self, frame: Frame, direction: FrameDirection): await super().process_frame(frame, direction) @@ -91,11 +92,9 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): send_aggregation = False if isinstance(frame, self._start_frame): - self._aggregation = "" self._aggregating = True self._seen_start_frame = True self._seen_end_frame = False - self._seen_interim_results = False await self.push_frame(frame, direction) elif isinstance(frame, self._end_frame): self._seen_end_frame = True @@ -109,23 +108,36 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): # Send the aggregation if we are not aggregating anymore (i.e. no # more interim results received). send_aggregation = not self._aggregating - await self.push_frame(frame, direction) elif isinstance(frame, self._accumulator_frame): - if self._aggregating: - if self._expect_stripped_words: - self._aggregation += f" {frame.text}" if self._aggregation else frame.text - else: - self._aggregation += frame.text - # We have recevied a complete sentence, so if we have seen the - # end frame and we were still aggregating, it means we should - # send the aggregation. - send_aggregation = self._seen_end_frame + if ( + self._interrupt_double_accumulator + and self._sent_aggregation_after_last_interruption + ): + await self.push_frame(BotInterruptionFrame(), FrameDirection.UPSTREAM) + self._sent_aggregation_after_last_interruption = False + + if self._expect_stripped_words: + self._aggregation += f" {frame.text}" if self._aggregation else frame.text + else: + self._aggregation += frame.text + + # If we haven't seen the start frame but we got an accumulator frame + # it means two things: it was develiver before the end frame or it + # was delivered late. In both cases so we want to send the + # aggregation. + send_aggregation = not self._seen_start_frame # We just got our final result, so let's reset interim results. self._seen_interim_results = False elif self._interim_accumulator_frame and isinstance(frame, self._interim_accumulator_frame): + if ( + self._interrupt_double_accumulator + and self._sent_aggregation_after_last_interruption + ): + await self.push_frame(BotInterruptionFrame(), FrameDirection.UPSTREAM) + self._sent_aggregation_after_last_interruption = False self._seen_interim_results = True - elif self._handle_interruptions and isinstance(frame, StartInterruptionFrame): + elif isinstance(frame, StartInterruptionFrame) and self._handle_interruptions: await self._push_aggregation() # Reset anyways self._reset() @@ -142,6 +154,9 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): if send_aggregation: await self._push_aggregation() + if isinstance(frame, self._end_frame): + await self.push_frame(frame, direction) + async def _push_aggregation(self): if len(self._aggregation) > 0: self._messages.append({"role": self._role, "content": self._aggregation}) @@ -150,6 +165,8 @@ async def _push_aggregation(self): # if the tasks gets cancelled we won't be able to clear things up. self._aggregation = "" + self._sent_aggregation_after_last_interruption = True + frame = LLMMessagesFrame(self._messages) await self.push_frame(frame) @@ -172,22 +189,11 @@ def _reset(self): self._seen_start_frame = False self._seen_end_frame = False self._seen_interim_results = False - - -class LLMAssistantResponseAggregator(LLMResponseAggregator): - def __init__(self, messages: List[dict] = []): - super().__init__( - messages=messages, - role="assistant", - start_frame=LLMFullResponseStartFrame, - end_frame=LLMFullResponseEndFrame, - accumulator_frame=TextFrame, - handle_interruptions=True, - ) + self._sent_aggregation_after_last_interruption = False class LLMUserResponseAggregator(LLMResponseAggregator): - def __init__(self, messages: List[dict] = []): + def __init__(self, messages: List[dict] = [], **kwargs): super().__init__( messages=messages, role="user", @@ -195,61 +201,21 @@ def __init__(self, messages: List[dict] = []): end_frame=UserStoppedSpeakingFrame, accumulator_frame=TranscriptionFrame, interim_accumulator_frame=InterimTranscriptionFrame, + **kwargs, ) -class LLMFullResponseAggregator(FrameProcessor): - """This class aggregates Text frames until it receives a - LLMFullResponseEndFrame, then emits the concatenated text as - a single text frame. - - given the following frames: - - TextFrame("Hello,") - TextFrame(" world.") - TextFrame(" I am") - TextFrame(" an LLM.") - LLMFullResponseEndFrame()] - - this processor will yield nothing for the first 4 frames, then - - TextFrame("Hello, world. I am an LLM.") - LLMFullResponseEndFrame() - - when passed the last frame. - - >>> async def print_frames(aggregator, frame): - ... async for frame in aggregator.process_frame(frame): - ... if isinstance(frame, TextFrame): - ... print(frame.text) - ... else: - ... print(frame.__class__.__name__) - - >>> aggregator = LLMFullResponseAggregator() - >>> asyncio.run(print_frames(aggregator, TextFrame("Hello,"))) - >>> asyncio.run(print_frames(aggregator, TextFrame(" world."))) - >>> asyncio.run(print_frames(aggregator, TextFrame(" I am"))) - >>> asyncio.run(print_frames(aggregator, TextFrame(" an LLM."))) - >>> asyncio.run(print_frames(aggregator, LLMFullResponseEndFrame())) - Hello, world. I am an LLM. - LLMFullResponseEndFrame - """ - - def __init__(self): - super().__init__() - self._aggregation = "" - - async def process_frame(self, frame: Frame, direction: FrameDirection): - await super().process_frame(frame, direction) - - if isinstance(frame, TextFrame): - self._aggregation += frame.text - elif isinstance(frame, LLMFullResponseEndFrame): - await self.push_frame(TextFrame(self._aggregation)) - await self.push_frame(frame) - self._aggregation = "" - else: - await self.push_frame(frame, direction) +class LLMAssistantResponseAggregator(LLMResponseAggregator): + def __init__(self, messages: List[dict] = [], **kwargs): + super().__init__( + messages=messages, + role="assistant", + start_frame=LLMFullResponseStartFrame, + end_frame=LLMFullResponseEndFrame, + accumulator_frame=TextFrame, + handle_interruptions=True, + **kwargs, + ) class LLMContextAggregator(LLMResponseAggregator): @@ -286,15 +252,14 @@ async def _push_aggregation(self): # if the tasks gets cancelled we won't be able to clear things up. self._aggregation = "" + self._sent_aggregation_after_last_interruption = True + frame = OpenAILLMContextFrame(self._context) await self.push_frame(frame) - # Reset our accumulator state. - self._reset() - class LLMAssistantContextAggregator(LLMContextAggregator): - def __init__(self, context: OpenAILLMContext, *, expect_stripped_words: bool = True): + def __init__(self, context: OpenAILLMContext, **kwargs): super().__init__( messages=[], context=context, @@ -303,12 +268,12 @@ def __init__(self, context: OpenAILLMContext, *, expect_stripped_words: bool = T end_frame=LLMFullResponseEndFrame, accumulator_frame=TextFrame, handle_interruptions=True, - expect_stripped_words=expect_stripped_words, + **kwargs, ) class LLMUserContextAggregator(LLMContextAggregator): - def __init__(self, context: OpenAILLMContext): + def __init__(self, context: OpenAILLMContext, **kwargs): super().__init__( messages=[], context=context, @@ -317,131 +282,69 @@ def __init__(self, context: OpenAILLMContext): end_frame=UserStoppedSpeakingFrame, accumulator_frame=TranscriptionFrame, interim_accumulator_frame=InterimTranscriptionFrame, + **kwargs, ) - # CUSTOM CODE: this variable remembers if we prompted the LLM - self.sent_aggregation_after_last_interruption = False - # Relevant functions: - # LLMContextAggregator.async def _push_aggregation(self) - # and - # def LLMResponseAggregator._reset(self): - # The original pipecat implementation is in: - # LLMResponseAggregator.process_frame +class LLMFullResponseAggregator(FrameProcessor): + """This class aggregates Text frames between LLMFullResponseStartFrame and + LLMFullResponseEndFrame, then emits the concatenated text as a single text + frame. + + given the following frames: - # Use cases implemented: - # - # S: Start, E: End, T: Transcription, I: Interim, X: Text - # - # S E -> None - # S T E -> T - # S I T E -> T - # S I E T -> T - # S I E I T -> T - # S E T -> T - # S E I T -> T - # - # S I E T1 I T2 -> T1 - # - # and T2 would be dropped. - - # We have: - # S = UserStartedSpeakingFrame, - # E = UserStoppedSpeakingFrame, - # T = TranscriptionFrame, - # I = InterimTranscriptionFrame - - # Cases we want to handle: - # - Make sure we never delete some aggregation as it is something said by the user - # - Solves case: S T1 I E S T2 E where we lose T1 - # - Solve case: S T E Bot T (without E S) as the VAD is not activated (yeah case) - # - Solve case: S E T1 T2 where T2 is lost. (variation from above) - # For the last case we also send StartInterruptionFrame for making sure that the reprompt of the LLM does not make weird repeating messages. - - # So the cases would be: - # S E -> None - # S T E -> T - # S I T E -> T - # S I E T -> T - # S I E I T -> T - # S E T -> T - # S E I T -> T - # S T1 I E S T2 E -> (T1 T2) - # S I E T1 I T2 -> T1 Interruption T2 - # S T1 E T2 -> T1 Interruption T2 - # S E T1 B T2 -> T1 Bot Interruption T2 - # S E T1 T2 -> T1 Interruption T2 - # see the tests at test_LLM_user_context_aggregator - async def process_frame(self, frame: Frame, direction: FrameDirection): - await FrameProcessor.process_frame(self, frame, direction) + LLMFullResponseStartFrame() + TextFrame("Hello,") + TextFrame(" world.") + TextFrame(" I am") + TextFrame(" an LLM.") + LLMFullResponseEndFrame() - send_aggregation = False + this processor will push, - if isinstance(frame, self._start_frame): - # CUSTOM CODE: dont _aggregation = "" - # self._aggregation = "" - self._aggregating = True - self._seen_start_frame = True - self._seen_end_frame = False - # CUSTOM CODE: _seen_interim_results should be updated by interimframe and accumulator frame only - # self._seen_interim_results = False - await self.push_frame(frame, direction) - elif isinstance(frame, self._end_frame): - self._seen_end_frame = True - self._seen_start_frame = False + LLMFullResponseStartFrame() + TextFrame("Hello, world. I am an LLM.") + LLMFullResponseEndFrame() - # We might have received the end frame but we might still be - # aggregating (i.e. we have seen interim results but not the final - # text). - self._aggregating = self._seen_interim_results or len(self._aggregation) == 0 + when passed the last frame. - # Send the aggregation if we are not aggregating anymore (i.e. no - # more interim results received). - send_aggregation = not self._aggregating - await self.push_frame(frame, direction) - elif isinstance(frame, self._accumulator_frame): - # CUSTOM CODE: send interruption without VAD - if self.sent_aggregation_after_last_interruption: - await self.push_frame(StartInterruptionFrame()) - self.sent_aggregation_after_last_interruption = False - - # CUSTOM CODE: do not require _aggregating so we do not lose frames - self._aggregation += f" {frame.text}" if self._aggregation else frame.text - # We have recevied a complete sentence, so if we have seen the - # end frame and we were still aggregating, it means we should - # send the aggregation. - # CUSTOM CODE: important thing is not see start frame and not end frame (so user is still speaking) - send_aggregation = not self._seen_start_frame - # We just got our final result, so let's reset interim results. - self._seen_interim_results = False - elif self._interim_accumulator_frame and isinstance(frame, self._interim_accumulator_frame): - # CUSTOM CODE: send interruption without VAD - if self.sent_aggregation_after_last_interruption: - await self.push_frame(StartInterruptionFrame()) - self.sent_aggregation_after_last_interruption = False - self._seen_interim_results = True - elif self._handle_interruptions and isinstance(frame, StartInterruptionFrame): - # CUSTOM CODE: manage new interruptions - self.sent_aggregation_after_last_interruption = False - await self._push_aggregation() - # Reset anyways - self._reset() + >>> async def print_frames(aggregator, frame): + ... async for frame in aggregator.process_frame(frame): + ... if isinstance(frame, TextFrame): + ... print(frame.text) + ... else: + ... print(frame.__class__.__name__) + + >>> aggregator = LLMFullResponseAggregator() + >>> asyncio.run(print_frames(aggregator, LLMFullResponseStartFrame())) + >>> asyncio.run(print_frames(aggregator, TextFrame("Hello,"))) + >>> asyncio.run(print_frames(aggregator, TextFrame(" world."))) + >>> asyncio.run(print_frames(aggregator, TextFrame(" I am"))) + >>> asyncio.run(print_frames(aggregator, TextFrame(" an LLM."))) + >>> asyncio.run(print_frames(aggregator, LLMFullResponseEndFrame())) + LLMFullResponseStartFrame + Hello, world. I am an LLM. + LLMFullResponseEndFrame + + """ + + def __init__(self): + super().__init__() + self._aggregation = "" + self._seen_start_frame = False + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + + if isinstance(frame, LLMFullResponseStartFrame): + self._seen_start_frame = True await self.push_frame(frame, direction) - elif isinstance(frame, LLMMessagesAppendFrame): - self._messages.extend(frame.messages) - messages_frame = LLMMessagesFrame(self._messages) - await self.push_frame(messages_frame) - elif isinstance(frame, LLMMessagesUpdateFrame): - # We push the frame downstream so the assistant aggregator gets - # updated as well. + elif isinstance(frame, LLMFullResponseEndFrame): + self._seen_start_frame = False + await self.push_frame(TextFrame(self._aggregation)) await self.push_frame(frame) - # We can now reset this one. - self._reset() - self._messages = frame.messages - messages_frame = LLMMessagesFrame(self._messages) - await self.push_frame(messages_frame) + self._aggregation = "" + elif isinstance(frame, TextFrame) and self._seen_start_frame: + self._aggregation += frame.text else: await self.push_frame(frame, direction) - - if send_aggregation: - await self._push_aggregation() diff --git a/src/pipecat/processors/aggregators/test_LLM_user_context_aggregator.py b/src/pipecat/processors/aggregators/test_LLM_user_context_aggregator.py index 6a0028d54..3498c1e58 100644 --- a/src/pipecat/processors/aggregators/test_LLM_user_context_aggregator.py +++ b/src/pipecat/processors/aggregators/test_LLM_user_context_aggregator.py @@ -1,16 +1,20 @@ # tests/test_custom_user_context.py -"""Tests for CustomLLMUserContextAggregator""" +"""Tests for CustomLLMUserContextAggregator""" +import asyncio import unittest +from dataclasses import dataclass +from typing import List +from pipecat.clocks.system_clock import SystemClock from pipecat.frames.frames import ( + ControlFrame, Frame, + StartFrame, TranscriptionFrame, InterimTranscriptionFrame, - StartInterruptionFrame, - StopInterruptionFrame, UserStartedSpeakingFrame, UserStoppedSpeakingFrame, ) @@ -35,121 +39,238 @@ # S E T1 T2 -> T1 Interruption T2 -class StoreFrameProcessor(FrameProcessor): - def __init__(self, storage: list[Frame]) -> None: +@dataclass +class EndTestFrame(ControlFrame): + pass + + +class QueuedFrameProcessor(FrameProcessor): + def __init__(self, queue: asyncio.Queue, ignore_start: bool = True): super().__init__() - self.storage = storage + self._queue = queue + self._ignore_start = ignore_start + async def process_frame(self, frame: Frame, direction: FrameDirection): - self.storage.append(frame) - -async def make_test(frames_to_send, expected_returned_frames): - context_aggregator = LLMUserContextAggregator(OpenAILLMContext( - messages=[{"role": "", "content": ""}] - )) - storage = [] - storage_processor = StoreFrameProcessor(storage) - context_aggregator.link(storage_processor) + await super().process_frame(frame, direction) + if self._ignore_start and isinstance(frame, StartFrame): + return + await self._queue.put(frame) + + +async def make_test( + frames_to_send: List[Frame], expected_returned_frames: List[type] +) -> List[Frame]: + context_aggregator = LLMUserContextAggregator( + OpenAILLMContext(messages=[{"role": "", "content": ""}]) + ) + + received = asyncio.Queue() + test_processor = QueuedFrameProcessor(received) + context_aggregator.link(test_processor) + + await context_aggregator.queue_frame(StartFrame(clock=SystemClock())) for frame in frames_to_send: await context_aggregator.process_frame(frame, direction=FrameDirection.DOWNSTREAM) - print("storage") - for x in storage: - print(x) - print("expected_returned_frames") - for x in expected_returned_frames: - print(x) - assert len(storage) == len(expected_returned_frames) - for expected, real in zip(expected_returned_frames, storage): + await context_aggregator.queue_frame(EndTestFrame()) + + received_frames: List[Frame] = [] + running = True + while running: + frame = await received.get() + running = not isinstance(frame, EndTestFrame) + if running: + received_frames.append(frame) + + assert len(received_frames) == len(expected_returned_frames) + for real, expected in zip(received_frames, expected_returned_frames): assert isinstance(real, expected) - return storage + return received_frames -class TestFrameProcessing(unittest.IsolatedAsyncioTestCase): - # S E -> +class TestFrameProcessing(unittest.IsolatedAsyncioTestCase): + # S E -> async def test_s_e(self): """S E case""" - frames_to_send = [StartInterruptionFrame(), UserStartedSpeakingFrame(), StopInterruptionFrame(), UserStoppedSpeakingFrame()] - expected_returned_frames = [StartInterruptionFrame, UserStartedSpeakingFrame, StopInterruptionFrame, UserStoppedSpeakingFrame] + frames_to_send = [ + UserStartedSpeakingFrame(), + UserStoppedSpeakingFrame(), + ] + expected_returned_frames = [ + UserStartedSpeakingFrame, + UserStoppedSpeakingFrame, + ] await make_test(frames_to_send, expected_returned_frames) # S T E -> T async def test_s_t_e(self): """S T E case""" - frames_to_send = [StartInterruptionFrame(), UserStartedSpeakingFrame(), TranscriptionFrame("", "", ""), StopInterruptionFrame(), UserStoppedSpeakingFrame()] - expected_returned_frames = [StartInterruptionFrame, UserStartedSpeakingFrame, StopInterruptionFrame, UserStoppedSpeakingFrame, OpenAILLMContextFrame] + frames_to_send = [ + UserStartedSpeakingFrame(), + TranscriptionFrame("Hello", "", ""), + UserStoppedSpeakingFrame(), + ] + expected_returned_frames = [ + UserStartedSpeakingFrame, + UserStoppedSpeakingFrame, + OpenAILLMContextFrame, + ] await make_test(frames_to_send, expected_returned_frames) # S I T E -> T async def test_s_i_t_e(self): """S I T E case""" - frames_to_send = [StartInterruptionFrame(), UserStartedSpeakingFrame(), InterimTranscriptionFrame("", "", ""), TranscriptionFrame("", "", ""), StopInterruptionFrame(), UserStoppedSpeakingFrame()] - expected_returned_frames = [StartInterruptionFrame, UserStartedSpeakingFrame, StopInterruptionFrame, UserStoppedSpeakingFrame, OpenAILLMContextFrame] + frames_to_send = [ + UserStartedSpeakingFrame(), + InterimTranscriptionFrame("This", "", ""), + TranscriptionFrame("This is a test", "", ""), + UserStoppedSpeakingFrame(), + ] + expected_returned_frames = [ + UserStartedSpeakingFrame, + UserStoppedSpeakingFrame, + OpenAILLMContextFrame, + ] await make_test(frames_to_send, expected_returned_frames) # S I E T -> T async def test_s_i_e_t(self): """S I E T case""" - frames_to_send = [StartInterruptionFrame(), UserStartedSpeakingFrame(), InterimTranscriptionFrame("", "", ""), StopInterruptionFrame(), UserStoppedSpeakingFrame(), TranscriptionFrame("", "", "")] - expected_returned_frames = [StartInterruptionFrame, UserStartedSpeakingFrame, StopInterruptionFrame, UserStoppedSpeakingFrame, OpenAILLMContextFrame] + frames_to_send = [ + UserStartedSpeakingFrame(), + InterimTranscriptionFrame("This", "", ""), + UserStoppedSpeakingFrame(), + TranscriptionFrame("This is a test", "", ""), + ] + expected_returned_frames = [ + UserStartedSpeakingFrame, + UserStoppedSpeakingFrame, + OpenAILLMContextFrame, + ] await make_test(frames_to_send, expected_returned_frames) - # S I E I T -> T async def test_s_i_e_i_t(self): """S I E I T case""" - frames_to_send = [StartInterruptionFrame(), UserStartedSpeakingFrame(), InterimTranscriptionFrame("", "", ""), StopInterruptionFrame(), UserStoppedSpeakingFrame(), InterimTranscriptionFrame("", "", ""), TranscriptionFrame("", "", "")] - expected_returned_frames = [StartInterruptionFrame, UserStartedSpeakingFrame, StopInterruptionFrame, UserStoppedSpeakingFrame, OpenAILLMContextFrame] + frames_to_send = [ + UserStartedSpeakingFrame(), + InterimTranscriptionFrame("This", "", ""), + UserStoppedSpeakingFrame(), + InterimTranscriptionFrame("This is", "", ""), + TranscriptionFrame("This is a test", "", ""), + ] + expected_returned_frames = [ + UserStartedSpeakingFrame, + UserStoppedSpeakingFrame, + OpenAILLMContextFrame, + ] await make_test(frames_to_send, expected_returned_frames) # S E T -> T async def test_s_e_t(self): """S E case""" - frames_to_send = [StartInterruptionFrame(), UserStartedSpeakingFrame(), StopInterruptionFrame(), UserStoppedSpeakingFrame(), TranscriptionFrame("", "", "")] - expected_returned_frames = [StartInterruptionFrame, UserStartedSpeakingFrame, StopInterruptionFrame, UserStoppedSpeakingFrame, OpenAILLMContextFrame] + frames_to_send = [ + UserStartedSpeakingFrame(), + UserStoppedSpeakingFrame(), + TranscriptionFrame("This is a test", "", ""), + ] + expected_returned_frames = [ + UserStartedSpeakingFrame, + UserStoppedSpeakingFrame, + OpenAILLMContextFrame, + ] await make_test(frames_to_send, expected_returned_frames) # S E I T -> T async def test_s_e_i_t(self): """S E I T case""" - frames_to_send = [StartInterruptionFrame(), UserStartedSpeakingFrame(), StopInterruptionFrame(), UserStoppedSpeakingFrame(), InterimTranscriptionFrame("", "", ""), TranscriptionFrame("", "", "")] - expected_returned_frames = [StartInterruptionFrame, UserStartedSpeakingFrame, StopInterruptionFrame, UserStoppedSpeakingFrame, OpenAILLMContextFrame] + frames_to_send = [ + UserStartedSpeakingFrame(), + UserStoppedSpeakingFrame(), + InterimTranscriptionFrame("This", "", ""), + TranscriptionFrame("This is a test", "", ""), + ] + expected_returned_frames = [ + UserStartedSpeakingFrame, + UserStoppedSpeakingFrame, + OpenAILLMContextFrame, + ] await make_test(frames_to_send, expected_returned_frames) - # S T1 I E S T2 E -> (T1 T2) + # S T1 I E S T2 E -> "T1 T2" async def test_s_t1_i_e_s_t2_e(self): """S T1 I E S T2 E case""" - frames_to_send = [StartInterruptionFrame(), UserStartedSpeakingFrame(), TranscriptionFrame("T1", "", ""), InterimTranscriptionFrame("", "", ""), StopInterruptionFrame(), UserStoppedSpeakingFrame(), - StartInterruptionFrame(), UserStartedSpeakingFrame(), TranscriptionFrame("T2", "", ""), StopInterruptionFrame(), UserStoppedSpeakingFrame()] - expected_returned_frames = [StartInterruptionFrame, UserStartedSpeakingFrame, StopInterruptionFrame, UserStoppedSpeakingFrame, - StartInterruptionFrame, UserStartedSpeakingFrame, StopInterruptionFrame, UserStoppedSpeakingFrame, OpenAILLMContextFrame] + frames_to_send = [ + UserStartedSpeakingFrame(), + TranscriptionFrame("T1", "", ""), + InterimTranscriptionFrame("", "", ""), + UserStoppedSpeakingFrame(), + UserStartedSpeakingFrame(), + TranscriptionFrame("T2", "", ""), + UserStoppedSpeakingFrame(), + ] + expected_returned_frames = [ + UserStartedSpeakingFrame, + UserStoppedSpeakingFrame, + UserStartedSpeakingFrame, + UserStoppedSpeakingFrame, + OpenAILLMContextFrame, + ] result = await make_test(frames_to_send, expected_returned_frames) - assert result[-1].context.messages[-1]["content"] == " T1 T2" + assert result[-1].context.messages[-1]["content"] == "T1 T2" # S I E T1 I T2 -> T1 Interruption T2 async def test_s_i_e_t1_i_t2(self): """S I E T1 I T2 case""" - frames_to_send = [StartInterruptionFrame(), UserStartedSpeakingFrame(), InterimTranscriptionFrame("", "", ""), StopInterruptionFrame(), UserStoppedSpeakingFrame(), - TranscriptionFrame("T1", "", ""), InterimTranscriptionFrame("", "", ""), TranscriptionFrame("T2", "", ""),] - expected_returned_frames = [StartInterruptionFrame, UserStartedSpeakingFrame, StopInterruptionFrame, UserStoppedSpeakingFrame, - OpenAILLMContextFrame, StartInterruptionFrame, OpenAILLMContextFrame] + frames_to_send = [ + UserStartedSpeakingFrame(), + InterimTranscriptionFrame("", "", ""), + UserStoppedSpeakingFrame(), + TranscriptionFrame("T1", "", ""), + InterimTranscriptionFrame("", "", ""), + TranscriptionFrame("T2", "", ""), + ] + expected_returned_frames = [ + UserStartedSpeakingFrame, + UserStoppedSpeakingFrame, + OpenAILLMContextFrame, + OpenAILLMContextFrame, + ] result = await make_test(frames_to_send, expected_returned_frames) - assert result[-1].context.messages[-1]["content"] == " T1 T2" - - # S T1 E T2 -> T1 Interruption T2 - async def test_s_t1_e_t2(self): - """S T1 E T2 case""" - frames_to_send = [StartInterruptionFrame(), UserStartedSpeakingFrame(), TranscriptionFrame("T1", "", ""), StopInterruptionFrame(), UserStoppedSpeakingFrame(), - TranscriptionFrame("T2", "", ""),] - expected_returned_frames = [StartInterruptionFrame, UserStartedSpeakingFrame, StopInterruptionFrame, UserStoppedSpeakingFrame, - OpenAILLMContextFrame, StartInterruptionFrame, OpenAILLMContextFrame] - result = await make_test(frames_to_send, expected_returned_frames) - assert result[-1].context.messages[-1]["content"] == " T1 T2" - - # S E T1 T2 -> T1 Interruption T2 - async def test_s_e_t1_t2(self): - """S E T1 T2 case""" - frames_to_send = [StartInterruptionFrame(), UserStartedSpeakingFrame(), StopInterruptionFrame(), UserStoppedSpeakingFrame(), - TranscriptionFrame("T1", "", ""), TranscriptionFrame("T2", "", ""),] - expected_returned_frames = [StartInterruptionFrame, UserStartedSpeakingFrame, StopInterruptionFrame, UserStoppedSpeakingFrame, - OpenAILLMContextFrame, StartInterruptionFrame, OpenAILLMContextFrame] - result = await make_test(frames_to_send, expected_returned_frames) - assert result[-1].context.messages[-1]["content"] == " T1 T2" \ No newline at end of file + assert result[-2].context.messages[-2]["content"] == "T1" + assert result[-1].context.messages[-1]["content"] == "T2" + + # # S T1 E T2 -> T1 Interruption T2 + # async def test_s_t1_e_t2(self): + # """S T1 E T2 case""" + # frames_to_send = [ + # UserStartedSpeakingFrame(), + # TranscriptionFrame("T1", "", ""), + # UserStoppedSpeakingFrame(), + # TranscriptionFrame("T2", "", ""), + # ] + # expected_returned_frames = [ + # UserStartedSpeakingFrame, + # UserStoppedSpeakingFrame, + # OpenAILLMContextFrame, + # OpenAILLMContextFrame, + # ] + # result = await make_test(frames_to_send, expected_returned_frames) + # assert result[-1].context.messages[-1]["content"] == " T1 T2" + + # # S E T1 T2 -> T1 Interruption T2 + # async def test_s_e_t1_t2(self): + # """S E T1 T2 case""" + # frames_to_send = [ + # UserStartedSpeakingFrame(), + # UserStoppedSpeakingFrame(), + # TranscriptionFrame("T1", "", ""), + # TranscriptionFrame("T2", "", ""), + # ] + # expected_returned_frames = [ + # UserStartedSpeakingFrame, + # UserStoppedSpeakingFrame, + # OpenAILLMContextFrame, + # OpenAILLMContextFrame, + # ] + # result = await make_test(frames_to_send, expected_returned_frames) + # assert result[-1].context.messages[-1]["content"] == " T1 T2" From ba6e9ed9ad7ddcee9b1ee574b435c6d86b4261f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Mon, 23 Dec 2024 17:57:45 -0800 Subject: [PATCH 3/5] processors(frame_processors): add a try/except when cancelling tasks This seems necessary because of how pytest works. If a task is cancelled, pytest will know the task has been cancelled even if # `asyncio.CancelledError` is handled internally in the task. --- src/pipecat/processors/frame_processor.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/pipecat/processors/frame_processor.py b/src/pipecat/processors/frame_processor.py index e3dfe0bce..521918b69 100644 --- a/src/pipecat/processors/frame_processor.py +++ b/src/pipecat/processors/frame_processor.py @@ -311,8 +311,15 @@ def __create_push_task(self): self.__push_frame_task = self.get_event_loop().create_task(self.__push_frame_task_handler()) async def __cancel_push_task(self): - self.__push_frame_task.cancel() - await self.__push_frame_task + try: + self.__push_frame_task.cancel() + await self.__push_frame_task + except asyncio.CancelledError: + # TODO(aleix: Investigate why this is really needed. So far, this is + # necessary because of how pytest works. If a task is cancelled, + # pytest will know the task has been cancelled even if + # `asyncio.CancelledError` is handled internally in the task. + pass async def __push_frame_task_handler(self): running = True From dfdd536b2008d3f21966303a1fba5f9b4126d4d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Tue, 10 Dec 2024 09:17:57 -0800 Subject: [PATCH 4/5] improved unit tests / add a `run_test` function to test processors --- .github/workflows/tests.yaml | 2 +- pyproject.toml | 3 + .../test_LLM_user_context_aggregator.py | 276 ------------- tests/__init__.py | 0 tests/processors/__init__.py | 0 tests/processors/aggregators/__init__.py | 0 .../aggregators/test_llm_response.py | 370 ++++++++++++++++++ tests/processors/frameworks/__init__.py | 0 .../frameworks}/test_langchain.py | 0 tests/services/__init__.py | 0 tests/{ => services}/test_ai_services.py | 0 tests/{ => skipped}/test_aggregators.py | 0 .../test_daily_transport_service.py | 0 tests/{ => skipped}/test_openai_tts.py | 0 tests/{ => skipped}/test_pipeline.py | 0 .../{ => skipped}/test_protobuf_serializer.py | 0 .../{ => skipped}/test_websocket_transport.py | 0 tests/utils.py | 96 +++++ 18 files changed, 470 insertions(+), 277 deletions(-) delete mode 100644 src/pipecat/processors/aggregators/test_LLM_user_context_aggregator.py create mode 100644 tests/__init__.py create mode 100644 tests/processors/__init__.py create mode 100644 tests/processors/aggregators/__init__.py create mode 100644 tests/processors/aggregators/test_llm_response.py create mode 100644 tests/processors/frameworks/__init__.py rename tests/{ => processors/frameworks}/test_langchain.py (100%) create mode 100644 tests/services/__init__.py rename tests/{ => services}/test_ai_services.py (100%) rename tests/{ => skipped}/test_aggregators.py (100%) rename tests/{ => skipped}/test_daily_transport_service.py (100%) rename tests/{ => skipped}/test_openai_tts.py (100%) rename tests/{ => skipped}/test_pipeline.py (100%) rename tests/{ => skipped}/test_protobuf_serializer.py (100%) rename tests/{ => skipped}/test_websocket_transport.py (100%) create mode 100644 tests/utils.py diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index b806efad4..628f7369b 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -49,4 +49,4 @@ jobs: - name: Test with pytest run: | source .venv/bin/activate - pytest --ignore-glob="*to_be_updated*" --ignore-glob=*pipeline_source* src tests + pytest diff --git a/pyproject.toml b/pyproject.toml index 1d8b84679..a67c9a8e3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -82,7 +82,10 @@ whisper = [ "faster-whisper~=1.1.0" ] where = ["src"] [tool.pytest.ini_options] +addopts = "--verbose --disable-warnings" +testpaths = ["tests"] pythonpath = ["src"] +asyncio_default_fixture_loop_scope = "function" [tool.setuptools_scm] local_scheme = "no-local-version" diff --git a/src/pipecat/processors/aggregators/test_LLM_user_context_aggregator.py b/src/pipecat/processors/aggregators/test_LLM_user_context_aggregator.py deleted file mode 100644 index 3498c1e58..000000000 --- a/src/pipecat/processors/aggregators/test_LLM_user_context_aggregator.py +++ /dev/null @@ -1,276 +0,0 @@ -# tests/test_custom_user_context.py - -"""Tests for CustomLLMUserContextAggregator""" - -import asyncio -import unittest - -from dataclasses import dataclass -from typing import List - -from pipecat.clocks.system_clock import SystemClock -from pipecat.frames.frames import ( - ControlFrame, - Frame, - StartFrame, - TranscriptionFrame, - InterimTranscriptionFrame, - UserStartedSpeakingFrame, - UserStoppedSpeakingFrame, -) -from pipecat.processors.aggregators.llm_response import LLMUserContextAggregator -from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext -from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame -from pipecat.processors.frame_processor import FrameDirection, FrameProcessor - -# Note that UserStartedSpeakingFrame always come with StartInterruptionFrame -# and UserStoppedSpeakingFrame always come with StopInterruptionFrame -# S E -> None -# S T E -> T -# S I T E -> T -# S I E T -> T -# S I E I T -> T -# S E T -> T -# S E I T -> T -# S T1 I E S T2 E -> (T1 T2) -# S I E T1 I T2 -> T1 Interruption T2 -# S T1 E T2 -> T1 Interruption T2 -# S E T1 B T2 -> T1 Bot Interruption T2 -# S E T1 T2 -> T1 Interruption T2 - - -@dataclass -class EndTestFrame(ControlFrame): - pass - - -class QueuedFrameProcessor(FrameProcessor): - def __init__(self, queue: asyncio.Queue, ignore_start: bool = True): - super().__init__() - self._queue = queue - self._ignore_start = ignore_start - - async def process_frame(self, frame: Frame, direction: FrameDirection): - await super().process_frame(frame, direction) - if self._ignore_start and isinstance(frame, StartFrame): - return - await self._queue.put(frame) - - -async def make_test( - frames_to_send: List[Frame], expected_returned_frames: List[type] -) -> List[Frame]: - context_aggregator = LLMUserContextAggregator( - OpenAILLMContext(messages=[{"role": "", "content": ""}]) - ) - - received = asyncio.Queue() - test_processor = QueuedFrameProcessor(received) - context_aggregator.link(test_processor) - - await context_aggregator.queue_frame(StartFrame(clock=SystemClock())) - for frame in frames_to_send: - await context_aggregator.process_frame(frame, direction=FrameDirection.DOWNSTREAM) - await context_aggregator.queue_frame(EndTestFrame()) - - received_frames: List[Frame] = [] - running = True - while running: - frame = await received.get() - running = not isinstance(frame, EndTestFrame) - if running: - received_frames.append(frame) - - assert len(received_frames) == len(expected_returned_frames) - for real, expected in zip(received_frames, expected_returned_frames): - assert isinstance(real, expected) - return received_frames - - -class TestFrameProcessing(unittest.IsolatedAsyncioTestCase): - # S E -> - async def test_s_e(self): - """S E case""" - frames_to_send = [ - UserStartedSpeakingFrame(), - UserStoppedSpeakingFrame(), - ] - expected_returned_frames = [ - UserStartedSpeakingFrame, - UserStoppedSpeakingFrame, - ] - await make_test(frames_to_send, expected_returned_frames) - - # S T E -> T - async def test_s_t_e(self): - """S T E case""" - frames_to_send = [ - UserStartedSpeakingFrame(), - TranscriptionFrame("Hello", "", ""), - UserStoppedSpeakingFrame(), - ] - expected_returned_frames = [ - UserStartedSpeakingFrame, - UserStoppedSpeakingFrame, - OpenAILLMContextFrame, - ] - await make_test(frames_to_send, expected_returned_frames) - - # S I T E -> T - async def test_s_i_t_e(self): - """S I T E case""" - frames_to_send = [ - UserStartedSpeakingFrame(), - InterimTranscriptionFrame("This", "", ""), - TranscriptionFrame("This is a test", "", ""), - UserStoppedSpeakingFrame(), - ] - expected_returned_frames = [ - UserStartedSpeakingFrame, - UserStoppedSpeakingFrame, - OpenAILLMContextFrame, - ] - await make_test(frames_to_send, expected_returned_frames) - - # S I E T -> T - async def test_s_i_e_t(self): - """S I E T case""" - frames_to_send = [ - UserStartedSpeakingFrame(), - InterimTranscriptionFrame("This", "", ""), - UserStoppedSpeakingFrame(), - TranscriptionFrame("This is a test", "", ""), - ] - expected_returned_frames = [ - UserStartedSpeakingFrame, - UserStoppedSpeakingFrame, - OpenAILLMContextFrame, - ] - await make_test(frames_to_send, expected_returned_frames) - - # S I E I T -> T - async def test_s_i_e_i_t(self): - """S I E I T case""" - frames_to_send = [ - UserStartedSpeakingFrame(), - InterimTranscriptionFrame("This", "", ""), - UserStoppedSpeakingFrame(), - InterimTranscriptionFrame("This is", "", ""), - TranscriptionFrame("This is a test", "", ""), - ] - expected_returned_frames = [ - UserStartedSpeakingFrame, - UserStoppedSpeakingFrame, - OpenAILLMContextFrame, - ] - await make_test(frames_to_send, expected_returned_frames) - - # S E T -> T - async def test_s_e_t(self): - """S E case""" - frames_to_send = [ - UserStartedSpeakingFrame(), - UserStoppedSpeakingFrame(), - TranscriptionFrame("This is a test", "", ""), - ] - expected_returned_frames = [ - UserStartedSpeakingFrame, - UserStoppedSpeakingFrame, - OpenAILLMContextFrame, - ] - await make_test(frames_to_send, expected_returned_frames) - - # S E I T -> T - async def test_s_e_i_t(self): - """S E I T case""" - frames_to_send = [ - UserStartedSpeakingFrame(), - UserStoppedSpeakingFrame(), - InterimTranscriptionFrame("This", "", ""), - TranscriptionFrame("This is a test", "", ""), - ] - expected_returned_frames = [ - UserStartedSpeakingFrame, - UserStoppedSpeakingFrame, - OpenAILLMContextFrame, - ] - await make_test(frames_to_send, expected_returned_frames) - - # S T1 I E S T2 E -> "T1 T2" - async def test_s_t1_i_e_s_t2_e(self): - """S T1 I E S T2 E case""" - frames_to_send = [ - UserStartedSpeakingFrame(), - TranscriptionFrame("T1", "", ""), - InterimTranscriptionFrame("", "", ""), - UserStoppedSpeakingFrame(), - UserStartedSpeakingFrame(), - TranscriptionFrame("T2", "", ""), - UserStoppedSpeakingFrame(), - ] - expected_returned_frames = [ - UserStartedSpeakingFrame, - UserStoppedSpeakingFrame, - UserStartedSpeakingFrame, - UserStoppedSpeakingFrame, - OpenAILLMContextFrame, - ] - result = await make_test(frames_to_send, expected_returned_frames) - assert result[-1].context.messages[-1]["content"] == "T1 T2" - - # S I E T1 I T2 -> T1 Interruption T2 - async def test_s_i_e_t1_i_t2(self): - """S I E T1 I T2 case""" - frames_to_send = [ - UserStartedSpeakingFrame(), - InterimTranscriptionFrame("", "", ""), - UserStoppedSpeakingFrame(), - TranscriptionFrame("T1", "", ""), - InterimTranscriptionFrame("", "", ""), - TranscriptionFrame("T2", "", ""), - ] - expected_returned_frames = [ - UserStartedSpeakingFrame, - UserStoppedSpeakingFrame, - OpenAILLMContextFrame, - OpenAILLMContextFrame, - ] - result = await make_test(frames_to_send, expected_returned_frames) - assert result[-2].context.messages[-2]["content"] == "T1" - assert result[-1].context.messages[-1]["content"] == "T2" - - # # S T1 E T2 -> T1 Interruption T2 - # async def test_s_t1_e_t2(self): - # """S T1 E T2 case""" - # frames_to_send = [ - # UserStartedSpeakingFrame(), - # TranscriptionFrame("T1", "", ""), - # UserStoppedSpeakingFrame(), - # TranscriptionFrame("T2", "", ""), - # ] - # expected_returned_frames = [ - # UserStartedSpeakingFrame, - # UserStoppedSpeakingFrame, - # OpenAILLMContextFrame, - # OpenAILLMContextFrame, - # ] - # result = await make_test(frames_to_send, expected_returned_frames) - # assert result[-1].context.messages[-1]["content"] == " T1 T2" - - # # S E T1 T2 -> T1 Interruption T2 - # async def test_s_e_t1_t2(self): - # """S E T1 T2 case""" - # frames_to_send = [ - # UserStartedSpeakingFrame(), - # UserStoppedSpeakingFrame(), - # TranscriptionFrame("T1", "", ""), - # TranscriptionFrame("T2", "", ""), - # ] - # expected_returned_frames = [ - # UserStartedSpeakingFrame, - # UserStoppedSpeakingFrame, - # OpenAILLMContextFrame, - # OpenAILLMContextFrame, - # ] - # result = await make_test(frames_to_send, expected_returned_frames) - # assert result[-1].context.messages[-1]["content"] == " T1 T2" diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/processors/__init__.py b/tests/processors/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/processors/aggregators/__init__.py b/tests/processors/aggregators/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/processors/aggregators/test_llm_response.py b/tests/processors/aggregators/test_llm_response.py new file mode 100644 index 000000000..ba161329b --- /dev/null +++ b/tests/processors/aggregators/test_llm_response.py @@ -0,0 +1,370 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import unittest + +from pipecat.frames.frames import ( + BotInterruptionFrame, + InterimTranscriptionFrame, + LLMFullResponseEndFrame, + LLMFullResponseStartFrame, + StartInterruptionFrame, + StopInterruptionFrame, + TextFrame, + TranscriptionFrame, + UserStartedSpeakingFrame, + UserStoppedSpeakingFrame, +) +from pipecat.processors.aggregators.llm_response import ( + LLMAssistantContextAggregator, + LLMFullResponseAggregator, + LLMUserContextAggregator, +) +from pipecat.processors.aggregators.openai_llm_context import ( + OpenAILLMContext, + OpenAILLMContextFrame, +) +from tests.utils import run_test + + +class TestLLMUserContextAggregator(unittest.IsolatedAsyncioTestCase): + # S E -> + async def test_s_e(self): + """S E case""" + context_aggregator = LLMUserContextAggregator( + OpenAILLMContext(messages=[{"role": "", "content": ""}]) + ) + frames_to_send = [ + StartInterruptionFrame(), + UserStartedSpeakingFrame(), + StopInterruptionFrame(), + UserStoppedSpeakingFrame(), + ] + expected_returned_frames = [ + StartInterruptionFrame, + UserStartedSpeakingFrame, + StopInterruptionFrame, + UserStoppedSpeakingFrame, + ] + await run_test(context_aggregator, frames_to_send, expected_returned_frames) + + # S T E -> T + async def test_s_t_e(self): + """S T E case""" + context_aggregator = LLMUserContextAggregator( + OpenAILLMContext(messages=[{"role": "", "content": ""}]) + ) + frames_to_send = [ + StartInterruptionFrame(), + UserStartedSpeakingFrame(), + TranscriptionFrame("Hello", "", ""), + StopInterruptionFrame(), + UserStoppedSpeakingFrame(), + ] + expected_returned_frames = [ + StartInterruptionFrame, + UserStartedSpeakingFrame, + StopInterruptionFrame, + UserStoppedSpeakingFrame, + OpenAILLMContextFrame, + ] + await run_test(context_aggregator, frames_to_send, expected_returned_frames) + + # S I T E -> T + async def test_s_i_t_e(self): + """S I T E case""" + context_aggregator = LLMUserContextAggregator( + OpenAILLMContext(messages=[{"role": "", "content": ""}]) + ) + frames_to_send = [ + StartInterruptionFrame(), + UserStartedSpeakingFrame(), + InterimTranscriptionFrame("This", "", ""), + TranscriptionFrame("This is a test", "", ""), + StopInterruptionFrame(), + UserStoppedSpeakingFrame(), + ] + expected_returned_frames = [ + StartInterruptionFrame, + UserStartedSpeakingFrame, + StopInterruptionFrame, + UserStoppedSpeakingFrame, + OpenAILLMContextFrame, + ] + await run_test(context_aggregator, frames_to_send, expected_returned_frames) + + # S I E T -> T + async def test_s_i_e_t(self): + """S I E T case""" + context_aggregator = LLMUserContextAggregator( + OpenAILLMContext(messages=[{"role": "", "content": ""}]) + ) + frames_to_send = [ + StartInterruptionFrame(), + UserStartedSpeakingFrame(), + InterimTranscriptionFrame("This", "", ""), + StopInterruptionFrame(), + UserStoppedSpeakingFrame(), + TranscriptionFrame("This is a test", "", ""), + ] + expected_returned_frames = [ + StartInterruptionFrame, + UserStartedSpeakingFrame, + StopInterruptionFrame, + UserStoppedSpeakingFrame, + OpenAILLMContextFrame, + ] + await run_test(context_aggregator, frames_to_send, expected_returned_frames) + + # S I E I T -> T + async def test_s_i_e_i_t(self): + """S I E I T case""" + context_aggregator = LLMUserContextAggregator( + OpenAILLMContext(messages=[{"role": "", "content": ""}]) + ) + frames_to_send = [ + StartInterruptionFrame(), + UserStartedSpeakingFrame(), + InterimTranscriptionFrame("This", "", ""), + StopInterruptionFrame(), + UserStoppedSpeakingFrame(), + InterimTranscriptionFrame("This is", "", ""), + TranscriptionFrame("This is a test", "", ""), + ] + expected_returned_frames = [ + StartInterruptionFrame, + UserStartedSpeakingFrame, + StopInterruptionFrame, + UserStoppedSpeakingFrame, + OpenAILLMContextFrame, + ] + await run_test(context_aggregator, frames_to_send, expected_returned_frames) + + # S E T -> T + async def test_s_e_t(self): + """S E case""" + context_aggregator = LLMUserContextAggregator( + OpenAILLMContext(messages=[{"role": "", "content": ""}]) + ) + frames_to_send = [ + StartInterruptionFrame(), + UserStartedSpeakingFrame(), + StopInterruptionFrame(), + UserStoppedSpeakingFrame(), + TranscriptionFrame("This is a test", "", ""), + ] + expected_returned_frames = [ + StartInterruptionFrame, + UserStartedSpeakingFrame, + StopInterruptionFrame, + UserStoppedSpeakingFrame, + OpenAILLMContextFrame, + ] + await run_test(context_aggregator, frames_to_send, expected_returned_frames) + + # S E I T -> T + async def test_s_e_i_t(self): + """S E I T case""" + context_aggregator = LLMUserContextAggregator( + OpenAILLMContext(messages=[{"role": "", "content": ""}]) + ) + frames_to_send = [ + StartInterruptionFrame(), + UserStartedSpeakingFrame(), + StopInterruptionFrame(), + UserStoppedSpeakingFrame(), + InterimTranscriptionFrame("This", "", ""), + TranscriptionFrame("This is a test", "", ""), + ] + expected_returned_frames = [ + StartInterruptionFrame, + UserStartedSpeakingFrame, + StopInterruptionFrame, + UserStoppedSpeakingFrame, + OpenAILLMContextFrame, + ] + await run_test(context_aggregator, frames_to_send, expected_returned_frames) + + # S T1 I E S T2 E -> "T1 T2" + async def test_s_t1_i_e_s_t2_e(self): + """S T1 I E S T2 E case""" + context_aggregator = LLMUserContextAggregator( + OpenAILLMContext(messages=[{"role": "", "content": ""}]) + ) + frames_to_send = [ + StartInterruptionFrame(), + UserStartedSpeakingFrame(), + TranscriptionFrame("T1", "", ""), + InterimTranscriptionFrame("", "", ""), + StopInterruptionFrame(), + UserStoppedSpeakingFrame(), + StartInterruptionFrame(), + UserStartedSpeakingFrame(), + TranscriptionFrame("T2", "", ""), + StopInterruptionFrame(), + UserStoppedSpeakingFrame(), + ] + expected_returned_frames = [ + StartInterruptionFrame, + UserStartedSpeakingFrame, + StopInterruptionFrame, + UserStoppedSpeakingFrame, + StartInterruptionFrame, + UserStartedSpeakingFrame, + StopInterruptionFrame, + UserStoppedSpeakingFrame, + OpenAILLMContextFrame, + ] + (received_down, _) = await run_test( + context_aggregator, frames_to_send, expected_returned_frames + ) + assert received_down[-1].context.messages[-1]["content"] == "T1 T2" + + # S I E T1 I T2 -> T1 Interruption T2 + async def test_s_i_e_t1_i_t2(self): + """S I E T1 I T2 case""" + context_aggregator = LLMUserContextAggregator( + OpenAILLMContext(messages=[{"role": "", "content": ""}]) + ) + frames_to_send = [ + StartInterruptionFrame(), + UserStartedSpeakingFrame(), + InterimTranscriptionFrame("", "", ""), + StopInterruptionFrame(), + UserStoppedSpeakingFrame(), + TranscriptionFrame("T1", "", ""), + InterimTranscriptionFrame("", "", ""), + TranscriptionFrame("T2", "", ""), + ] + expected_down_frames = [ + StartInterruptionFrame, + UserStartedSpeakingFrame, + StopInterruptionFrame, + UserStoppedSpeakingFrame, + OpenAILLMContextFrame, + OpenAILLMContextFrame, + ] + expected_up_frames = [ + BotInterruptionFrame, + ] + (received_down, _) = await run_test( + context_aggregator, frames_to_send, expected_down_frames, expected_up_frames + ) + assert received_down[-1].context.messages[-2]["content"] == "T1" + assert received_down[-1].context.messages[-1]["content"] == "T2" + + # S T1 E T2 -> T1 Interruption T2 + async def test_s_t1_e_t2(self): + """S T1 E T2 case""" + context_aggregator = LLMUserContextAggregator( + OpenAILLMContext(messages=[{"role": "", "content": ""}]) + ) + frames_to_send = [ + StartInterruptionFrame(), + UserStartedSpeakingFrame(), + TranscriptionFrame("T1", "", ""), + StopInterruptionFrame(), + UserStoppedSpeakingFrame(), + TranscriptionFrame("T2", "", ""), + ] + expected_down_frames = [ + StartInterruptionFrame, + UserStartedSpeakingFrame, + StopInterruptionFrame, + UserStoppedSpeakingFrame, + OpenAILLMContextFrame, + OpenAILLMContextFrame, + ] + expected_up_frames = [ + BotInterruptionFrame, + ] + (received_down, _) = await run_test( + context_aggregator, frames_to_send, expected_down_frames, expected_up_frames + ) + assert received_down[-1].context.messages[-2]["content"] == "T1" + assert received_down[-1].context.messages[-1]["content"] == "T2" + + # S E T1 T2 -> T1 Interruption T2 + async def test_s_e_t1_t2(self): + """S E T1 T2 case""" + context_aggregator = LLMUserContextAggregator( + OpenAILLMContext(messages=[{"role": "", "content": ""}]) + ) + frames_to_send = [ + StartInterruptionFrame(), + UserStartedSpeakingFrame(), + StopInterruptionFrame(), + UserStoppedSpeakingFrame(), + TranscriptionFrame("T1", "", ""), + TranscriptionFrame("T2", "", ""), + ] + expected_down_frames = [ + StartInterruptionFrame, + UserStartedSpeakingFrame, + StopInterruptionFrame, + UserStoppedSpeakingFrame, + OpenAILLMContextFrame, + OpenAILLMContextFrame, + ] + expected_up_frames = [ + BotInterruptionFrame, + ] + (received_down, _) = await run_test( + context_aggregator, frames_to_send, expected_down_frames, expected_up_frames + ) + assert received_down[-1].context.messages[-2]["content"] == "T1" + assert received_down[-1].context.messages[-1]["content"] == "T2" + + +class TestLLMAssistantContextAggregator(unittest.IsolatedAsyncioTestCase): + # S T E -> T + async def test_s_t_e(self): + """S T E case""" + context_aggregator = LLMAssistantContextAggregator( + OpenAILLMContext(messages=[{"role": "", "content": ""}]) + ) + frames_to_send = [ + LLMFullResponseStartFrame(), + TextFrame("Hello this is Pipecat speaking!"), + TextFrame("How are you?"), + LLMFullResponseEndFrame(), + ] + expected_returned_frames = [ + LLMFullResponseStartFrame, + OpenAILLMContextFrame, + LLMFullResponseEndFrame, + ] + (received_down, _) = await run_test( + context_aggregator, frames_to_send, expected_returned_frames + ) + assert ( + received_down[-2].context.messages[-1]["content"] + == "Hello this is Pipecat speaking! How are you?" + ) + + +class TestLLMFullResponseAggregator(unittest.IsolatedAsyncioTestCase): + # S T E -> T + async def test_s_t_e(self): + """S T E case""" + response_aggregator = LLMFullResponseAggregator() + frames_to_send = [ + LLMFullResponseStartFrame(), + TextFrame("Hello "), + TextFrame("this "), + TextFrame("is "), + TextFrame("Pipecat!"), + LLMFullResponseEndFrame(), + ] + expected_returned_frames = [ + LLMFullResponseStartFrame, + TextFrame, + LLMFullResponseEndFrame, + ] + (received_down, _) = await run_test( + response_aggregator, frames_to_send, expected_returned_frames + ) + assert received_down[-2].text == "Hello this is Pipecat!" diff --git a/tests/processors/frameworks/__init__.py b/tests/processors/frameworks/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_langchain.py b/tests/processors/frameworks/test_langchain.py similarity index 100% rename from tests/test_langchain.py rename to tests/processors/frameworks/test_langchain.py diff --git a/tests/services/__init__.py b/tests/services/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_ai_services.py b/tests/services/test_ai_services.py similarity index 100% rename from tests/test_ai_services.py rename to tests/services/test_ai_services.py diff --git a/tests/test_aggregators.py b/tests/skipped/test_aggregators.py similarity index 100% rename from tests/test_aggregators.py rename to tests/skipped/test_aggregators.py diff --git a/tests/test_daily_transport_service.py b/tests/skipped/test_daily_transport_service.py similarity index 100% rename from tests/test_daily_transport_service.py rename to tests/skipped/test_daily_transport_service.py diff --git a/tests/test_openai_tts.py b/tests/skipped/test_openai_tts.py similarity index 100% rename from tests/test_openai_tts.py rename to tests/skipped/test_openai_tts.py diff --git a/tests/test_pipeline.py b/tests/skipped/test_pipeline.py similarity index 100% rename from tests/test_pipeline.py rename to tests/skipped/test_pipeline.py diff --git a/tests/test_protobuf_serializer.py b/tests/skipped/test_protobuf_serializer.py similarity index 100% rename from tests/test_protobuf_serializer.py rename to tests/skipped/test_protobuf_serializer.py diff --git a/tests/test_websocket_transport.py b/tests/skipped/test_websocket_transport.py similarity index 100% rename from tests/test_websocket_transport.py rename to tests/skipped/test_websocket_transport.py diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 000000000..cb3df0c8c --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,96 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import asyncio +from dataclasses import dataclass +from typing import List, Tuple + +from pipecat.clocks.system_clock import SystemClock +from pipecat.frames.frames import ( + ControlFrame, + Frame, + StartFrame, +) +from pipecat.processors.frame_processor import FrameDirection, FrameProcessor + + +@dataclass +class EndTestFrame(ControlFrame): + pass + + +class QueuedFrameProcessor(FrameProcessor): + def __init__(self, queue: asyncio.Queue, ignore_start: bool = True): + super().__init__() + self._queue = queue + self._ignore_start = ignore_start + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + if self._ignore_start and isinstance(frame, StartFrame): + return + await self._queue.put(frame) + + +async def run_test( + processor: FrameProcessor, + frames_to_send: List[Frame], + expected_down_frames: List[type], + expected_up_frames: List[type] = [], +) -> Tuple[List[Frame], List[Frame]]: + received_up = asyncio.Queue() + received_down = asyncio.Queue() + up_processor = QueuedFrameProcessor(received_up) + down_processor = QueuedFrameProcessor(received_down) + + up_processor.link(processor) + processor.link(down_processor) + + await processor.queue_frame(StartFrame(clock=SystemClock())) + + for frame in frames_to_send: + await processor.process_frame(frame, FrameDirection.DOWNSTREAM) + + await processor.queue_frame(EndTestFrame()) + await processor.queue_frame(EndTestFrame(), FrameDirection.UPSTREAM) + + # + # Down frames + # + received_down_frames: List[Frame] = [] + running = True + while running: + frame = await received_down.get() + running = not isinstance(frame, EndTestFrame) + if running: + received_down_frames.append(frame) + + print("received DOWN frames =", received_down_frames) + + assert len(received_down_frames) == len(expected_down_frames) + + for real, expected in zip(received_down_frames, expected_down_frames): + assert isinstance(real, expected) + + # + # Up frames + # + received_up_frames: List[Frame] = [] + running = True + while running: + frame = await received_up.get() + running = not isinstance(frame, EndTestFrame) + if running: + received_up_frames.append(frame) + + print("received UP frames =", received_up_frames) + + assert len(received_up_frames) == len(expected_up_frames) + + for real, expected in zip(received_up_frames, expected_up_frames): + assert isinstance(real, expected) + + return (received_down_frames, received_up_frames) From eef3f320b14493fef19d0cc1c721826bc7f5f567 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Mon, 23 Dec 2024 18:03:01 -0800 Subject: [PATCH 5/5] update README with KoalaFilter --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 2821e8f32..f3f3ff2e0 100644 --- a/README.md +++ b/README.md @@ -64,7 +64,7 @@ Available options include: | Transport | [Daily (WebRTC)](https://docs.pipecat.ai/server/services/transport/daily), WebSocket, Local | `pip install "pipecat-ai[daily]"` | | Video | [Tavus](https://docs.pipecat.ai/server/services/video/tavus), [Simli](https://docs.pipecat.ai/server/services/video/simli) | `pip install "pipecat-ai[tavus,simli]"` | | Vision & Image | [Moondream](https://docs.pipecat.ai/server/services/vision/moondream), [fal](https://docs.pipecat.ai/server/services/image-generation/fal) | `pip install "pipecat-ai[moondream]"` | -| Audio Processing | [Silero VAD](https://docs.pipecat.ai/server/utilities/audio/silero-vad-analyzer), [Krisp](https://docs.pipecat.ai/server/utilities/audio/krisp-filter), [Noisereduce](https://docs.pipecat.ai/server/utilities/audio/noisereduce-filter) | `pip install "pipecat-ai[silero]"` | +| Audio Processing | [Silero VAD](https://docs.pipecat.ai/server/utilities/audio/silero-vad-analyzer), [Krisp](https://docs.pipecat.ai/server/utilities/audio/krisp-filter), [Koala](https://docs.pipecat.ai/server/utilities/audio/koala-filter) | `pip install "pipecat-ai[silero]"` | | Analytics & Metrics | [Canonical AI](https://docs.pipecat.ai/server/services/analytics/canonical), [Sentry](https://docs.pipecat.ai/server/services/analytics/sentry) | `pip install "pipecat-ai[canonical]"` | 📚 [View full services documentation →](https://docs.pipecat.ai/server/services/supported-services)