LLM user frame processor with tests #703

Draft: wants to merge 5 commits into base: main
2 changes: 1 addition & 1 deletion .github/workflows/tests.yaml
@@ -49,4 +49,4 @@ jobs:
- name: Test with pytest
run: |
source .venv/bin/activate
pytest --ignore-glob="*to_be_updated*" --ignore-glob=*pipeline_source* src tests
pytest
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -75,6 +75,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

- Fixed LLM response aggregators to support more use cases such as delayed
transcriptions.

- Fixed an issue that could cause the bot to stop talking if there was a user
interruption before getting any audio from the TTS service.

2 changes: 1 addition & 1 deletion README.md
@@ -64,7 +64,7 @@ Available options include:
| Transport | [Daily (WebRTC)](https://docs.pipecat.ai/server/services/transport/daily), WebSocket, Local | `pip install "pipecat-ai[daily]"` |
| Video | [Tavus](https://docs.pipecat.ai/server/services/video/tavus), [Simli](https://docs.pipecat.ai/server/services/video/simli) | `pip install "pipecat-ai[tavus,simli]"` |
| Vision & Image | [Moondream](https://docs.pipecat.ai/server/services/vision/moondream), [fal](https://docs.pipecat.ai/server/services/image-generation/fal) | `pip install "pipecat-ai[moondream]"` |
| Audio Processing | [Silero VAD](https://docs.pipecat.ai/server/utilities/audio/silero-vad-analyzer), [Krisp](https://docs.pipecat.ai/server/utilities/audio/krisp-filter), [Noisereduce](https://docs.pipecat.ai/server/utilities/audio/noisereduce-filter) | `pip install "pipecat-ai[silero]"` |
| Audio Processing | [Silero VAD](https://docs.pipecat.ai/server/utilities/audio/silero-vad-analyzer), [Krisp](https://docs.pipecat.ai/server/utilities/audio/krisp-filter), [Koala](https://docs.pipecat.ai/server/utilities/audio/koala-filter) | `pip install "pipecat-ai[silero]"` |
| Analytics & Metrics | [Canonical AI](https://docs.pipecat.ai/server/services/analytics/canonical), [Sentry](https://docs.pipecat.ai/server/services/analytics/sentry) | `pip install "pipecat-ai[canonical]"` |

📚 [View full services documentation →](https://docs.pipecat.ai/server/services/supported-services)
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -82,7 +82,10 @@ whisper = [ "faster-whisper~=1.1.0" ]
where = ["src"]

[tool.pytest.ini_options]
addopts = "--verbose --disable-warnings"
testpaths = ["tests"]
pythonpath = ["src"]
asyncio_default_fixture_loop_scope = "function"

[tool.setuptools_scm]
local_scheme = "no-local-version"
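With the pytest options moved into `[tool.pytest.ini_options]`, a bare `pytest` invocation now picks them up from `pyproject.toml`. A minimal sketch of the kind of async test this configuration supports (hypothetical test and names, not from the PR; `asyncio.run` stands in for pytest-asyncio's per-function event loop):

```python
import asyncio


async def aggregate(parts):
    """Stand-in for async frame processing: join text chunks."""
    await asyncio.sleep(0)  # yield to the event loop once
    return " ".join(parts)


def test_aggregate():
    # With asyncio_default_fixture_loop_scope = "function", pytest-asyncio
    # gives each test its own event loop; asyncio.run does the same here.
    result = asyncio.run(aggregate(["Hello,", "world."]))
    assert result == "Hello, world."
```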
228 changes: 129 additions & 99 deletions src/pipecat/processors/aggregators/llm_response.py
@@ -7,6 +7,7 @@
from typing import List, Type

from pipecat.frames.frames import (
BotInterruptionFrame,
Frame,
InterimTranscriptionFrame,
LLMFullResponseEndFrame,
@@ -40,6 +41,7 @@ def __init__(
interim_accumulator_frame: Type[TextFrame] | None = None,
handle_interruptions: bool = False,
expect_stripped_words: bool = True, # if True, need to add spaces between words
interrupt_double_accumulator: bool = True, # if True, interrupt if two or more accumulators are received
):
super().__init__()

@@ -51,8 +53,8 @@ def __init__(
self._interim_accumulator_frame = interim_accumulator_frame
self._handle_interruptions = handle_interruptions
self._expect_stripped_words = expect_stripped_words
self._interrupt_double_accumulator = interrupt_double_accumulator

# Reset our accumulator state.
self._reset()

@property
@@ -69,33 +71,30 @@ def role(self):

# Use cases implemented:
#
# S: Start, E: End, T: Transcription, I: Interim, X: Text
# S: Start, E: End, T: Transcription, I: Interim
#
# S E -> None
# S T E -> X
# S I T E -> X
# S I E T -> X
# S I E I T -> X
# S E T -> X
# S E I T -> X
#
# The following case would not be supported:
#
# S I E T1 I T2 -> X
#
# and T2 would be dropped.
# S E -> None -> User started speaking but no transcription.
# S T E -> T -> Transcription between user started and stopped speaking.
# S E T -> T -> Transcription after user stopped speaking.
# S I T E -> T -> Transcription between user started and stopped speaking (with interims).
# S I E T -> T -> Transcription after user stopped speaking (with interims).
# S I E I T -> T -> Transcription after user stopped speaking (with interims).
# S E I T -> T -> Transcription after user stopped speaking (with interims).
# S T1 I E S T2 E -> "T1 T2" -> Merge two transcriptions if we got a first interim.
# S I E T1 I T2 -> T1 [Interruption] T2 -> Single user started/stopped, double transcription.
# S T1 E T2 -> T1 [Interruption] T2 -> Single user started/stopped, double transcription.
# S E T1 B T2 -> T1 [Interruption] T2 -> Single user started/stopped, double transcription.
# S E T1 T2 -> T1 [Interruption] T2 -> Single user started/stopped, double transcription.
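The use cases listed above can be exercised with a minimal, hypothetical model of the documented state machine (plain tuples instead of frames, `"[Interruption]"` standing in for the upstream BotInterruptionFrame; a sketch of the documented behavior, not the real aggregator):

```python
def run_sequence(events, interrupt_double_accumulator=True):
    """Simplified model of the S/E/T/I use cases above.

    events: ("S",) start, ("E",) end, ("T", text) transcription, ("I",) interim.
    Returns the list of aggregations pushed, with "[Interruption]" marking where
    the real aggregator would push an upstream BotInterruptionFrame.
    """
    out = []
    aggregation = ""
    seen_start = seen_end = seen_interims = False
    sent_after_interruption = False

    def push():
        nonlocal aggregation, sent_after_interruption
        if aggregation:
            out.append(aggregation)
            aggregation = ""
            sent_after_interruption = True

    def maybe_interrupt():
        nonlocal sent_after_interruption
        if interrupt_double_accumulator and sent_after_interruption:
            out.append("[Interruption]")
            sent_after_interruption = False

    for event in events:
        kind = event[0]
        if kind == "S":
            # A pending aggregation is kept across a new start, which is what
            # lets the "S T1 I E S T2 E" case merge into "T1 T2".
            seen_start, seen_end, seen_interims = True, False, False
        elif kind == "E":
            seen_end = True
            if not seen_interims:  # only send if no interim results are pending
                push()
        elif kind == "T":
            maybe_interrupt()
            aggregation += f" {event[1]}" if aggregation else event[1]
            seen_interims = False
            if seen_end or not seen_start:  # late or post-end transcription
                push()
        elif kind == "I":
            maybe_interrupt()
            seen_interims = True
    return out


print(run_sequence([("S",), ("T", "T1"), ("I",), ("E",), ("S",), ("T", "T2"), ("E",)]))
# → ['T1 T2']
```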

async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)

send_aggregation = False

if isinstance(frame, self._start_frame):
self._aggregation = ""
self._aggregating = True
self._seen_start_frame = True
self._seen_end_frame = False
self._seen_interim_results = False
await self.push_frame(frame, direction)
elif isinstance(frame, self._end_frame):
self._seen_end_frame = True
@@ -109,23 +108,36 @@ async def process_frame(self, frame: Frame, direction: FrameDirection):
# Send the aggregation if we are not aggregating anymore (i.e. no
# more interim results received).
send_aggregation = not self._aggregating
await self.push_frame(frame, direction)
elif isinstance(frame, self._accumulator_frame):
if self._aggregating:
if self._expect_stripped_words:
self._aggregation += f" {frame.text}" if self._aggregation else frame.text
else:
self._aggregation += frame.text
# We have received a complete sentence, so if we have seen the
# end frame and we were still aggregating, it means we should
# send the aggregation.
send_aggregation = self._seen_end_frame
if (
self._interrupt_double_accumulator
and self._sent_aggregation_after_last_interruption
):
await self.push_frame(BotInterruptionFrame(), FrameDirection.UPSTREAM)
self._sent_aggregation_after_last_interruption = False

if self._expect_stripped_words:
self._aggregation += f" {frame.text}" if self._aggregation else frame.text
else:
self._aggregation += frame.text

# If we haven't seen the start frame but we got an accumulator frame
# it means one of two things: it was delivered before the end frame
# or it was delivered late. In both cases we want to send the
# aggregation.
send_aggregation = not self._seen_start_frame

# We just got our final result, so let's reset interim results.
self._seen_interim_results = False
elif self._interim_accumulator_frame and isinstance(frame, self._interim_accumulator_frame):
if (
self._interrupt_double_accumulator
and self._sent_aggregation_after_last_interruption
):
await self.push_frame(BotInterruptionFrame(), FrameDirection.UPSTREAM)
self._sent_aggregation_after_last_interruption = False
self._seen_interim_results = True
elif self._handle_interruptions and isinstance(frame, StartInterruptionFrame):
elif isinstance(frame, StartInterruptionFrame) and self._handle_interruptions:
await self._push_aggregation()
# Reset anyway
self._reset()
@@ -142,6 +154,9 @@ async def process_frame(self, frame: Frame, direction: FrameDirection):
if send_aggregation:
await self._push_aggregation()

if isinstance(frame, self._end_frame):
await self.push_frame(frame, direction)

async def _push_aggregation(self):
if len(self._aggregation) > 0:
self._messages.append({"role": self._role, "content": self._aggregation})
@@ -150,6 +165,8 @@ async def _push_aggregation(self):
# if the task gets cancelled we won't be able to clear things up.
self._aggregation = ""

self._sent_aggregation_after_last_interruption = True

frame = LLMMessagesFrame(self._messages)
await self.push_frame(frame)
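The `expect_stripped_words` accumulation rule used in `process_frame` can be sketched as a standalone helper (hypothetical function, not part of the PR):

```python
def join_words(parts, expect_stripped_words=True):
    # Mirrors the aggregator's accumulation: with stripped words a space is
    # inserted between chunks, otherwise chunks are concatenated verbatim.
    aggregation = ""
    for text in parts:
        if expect_stripped_words:
            aggregation += f" {text}" if aggregation else text
        else:
            aggregation += text
    return aggregation


print(join_words(["Hello,", "world."]))                                # → Hello, world.
print(join_words(["Hello,", " world."], expect_stripped_words=False))  # → Hello, world.
```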

@@ -172,84 +189,33 @@ def _reset(self):
self._seen_start_frame = False
self._seen_end_frame = False
self._seen_interim_results = False


class LLMAssistantResponseAggregator(LLMResponseAggregator):
def __init__(self, messages: List[dict] = []):
super().__init__(
messages=messages,
role="assistant",
start_frame=LLMFullResponseStartFrame,
end_frame=LLMFullResponseEndFrame,
accumulator_frame=TextFrame,
handle_interruptions=True,
)
self._sent_aggregation_after_last_interruption = False


class LLMUserResponseAggregator(LLMResponseAggregator):
def __init__(self, messages: List[dict] = []):
def __init__(self, messages: List[dict] = [], **kwargs):
super().__init__(
messages=messages,
role="user",
start_frame=UserStartedSpeakingFrame,
end_frame=UserStoppedSpeakingFrame,
accumulator_frame=TranscriptionFrame,
interim_accumulator_frame=InterimTranscriptionFrame,
**kwargs,
)


class LLMFullResponseAggregator(FrameProcessor):
"""This class aggregates Text frames until it receives a
LLMFullResponseEndFrame, then emits the concatenated text as
a single text frame.

given the following frames:

TextFrame("Hello,")
TextFrame(" world.")
TextFrame(" I am")
TextFrame(" an LLM.")
LLMFullResponseEndFrame()]

this processor will yield nothing for the first 4 frames, then

TextFrame("Hello, world. I am an LLM.")
LLMFullResponseEndFrame()

when passed the last frame.

>>> async def print_frames(aggregator, frame):
... async for frame in aggregator.process_frame(frame):
... if isinstance(frame, TextFrame):
... print(frame.text)
... else:
... print(frame.__class__.__name__)

>>> aggregator = LLMFullResponseAggregator()
>>> asyncio.run(print_frames(aggregator, TextFrame("Hello,")))
>>> asyncio.run(print_frames(aggregator, TextFrame(" world.")))
>>> asyncio.run(print_frames(aggregator, TextFrame(" I am")))
>>> asyncio.run(print_frames(aggregator, TextFrame(" an LLM.")))
>>> asyncio.run(print_frames(aggregator, LLMFullResponseEndFrame()))
Hello, world. I am an LLM.
LLMFullResponseEndFrame
"""

def __init__(self):
super().__init__()
self._aggregation = ""

async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)

if isinstance(frame, TextFrame):
self._aggregation += frame.text
elif isinstance(frame, LLMFullResponseEndFrame):
await self.push_frame(TextFrame(self._aggregation))
await self.push_frame(frame)
self._aggregation = ""
else:
await self.push_frame(frame, direction)
class LLMAssistantResponseAggregator(LLMResponseAggregator):
def __init__(self, messages: List[dict] = [], **kwargs):
super().__init__(
messages=messages,
role="assistant",
start_frame=LLMFullResponseStartFrame,
end_frame=LLMFullResponseEndFrame,
accumulator_frame=TextFrame,
handle_interruptions=True,
**kwargs,
)


class LLMContextAggregator(LLMResponseAggregator):
@@ -286,15 +252,14 @@ async def _push_aggregation(self):
# if the task gets cancelled we won't be able to clear things up.
self._aggregation = ""

self._sent_aggregation_after_last_interruption = True

frame = OpenAILLMContextFrame(self._context)
await self.push_frame(frame)

# Reset our accumulator state.
self._reset()


class LLMAssistantContextAggregator(LLMContextAggregator):
def __init__(self, context: OpenAILLMContext, *, expect_stripped_words: bool = True):
def __init__(self, context: OpenAILLMContext, **kwargs):
super().__init__(
messages=[],
context=context,
@@ -303,12 +268,12 @@ def __init__(self, context: OpenAILLMContext, *, expect_stripped_words: bool = T
end_frame=LLMFullResponseEndFrame,
accumulator_frame=TextFrame,
handle_interruptions=True,
expect_stripped_words=expect_stripped_words,
**kwargs,
)


class LLMUserContextAggregator(LLMContextAggregator):
def __init__(self, context: OpenAILLMContext):
def __init__(self, context: OpenAILLMContext, **kwargs):
super().__init__(
messages=[],
context=context,
@@ -317,4 +282,69 @@ def __init__(self, context: OpenAILLMContext):
end_frame=UserStoppedSpeakingFrame,
accumulator_frame=TranscriptionFrame,
interim_accumulator_frame=InterimTranscriptionFrame,
**kwargs,
)


class LLMFullResponseAggregator(FrameProcessor):
"""This class aggregates Text frames between LLMFullResponseStartFrame and
LLMFullResponseEndFrame, then emits the concatenated text as a single text
frame.

Given the following frames:

LLMFullResponseStartFrame()
TextFrame("Hello,")
TextFrame(" world.")
TextFrame(" I am")
TextFrame(" an LLM.")
LLMFullResponseEndFrame()

this processor will push:

LLMFullResponseStartFrame()
TextFrame("Hello, world. I am an LLM.")
LLMFullResponseEndFrame()

when passed the last frame.

>>> async def print_frames(aggregator, frame):
... async for frame in aggregator.process_frame(frame):
... if isinstance(frame, TextFrame):
... print(frame.text)
... else:
... print(frame.__class__.__name__)

>>> aggregator = LLMFullResponseAggregator()
>>> asyncio.run(print_frames(aggregator, LLMFullResponseStartFrame()))
>>> asyncio.run(print_frames(aggregator, TextFrame("Hello,")))
>>> asyncio.run(print_frames(aggregator, TextFrame(" world.")))
>>> asyncio.run(print_frames(aggregator, TextFrame(" I am")))
>>> asyncio.run(print_frames(aggregator, TextFrame(" an LLM.")))
>>> asyncio.run(print_frames(aggregator, LLMFullResponseEndFrame()))
LLMFullResponseStartFrame
Hello, world. I am an LLM.
LLMFullResponseEndFrame

"""

def __init__(self):
super().__init__()
self._aggregation = ""
self._seen_start_frame = False

async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)

if isinstance(frame, LLMFullResponseStartFrame):
self._seen_start_frame = True
await self.push_frame(frame, direction)
elif isinstance(frame, LLMFullResponseEndFrame):
self._seen_start_frame = False
await self.push_frame(TextFrame(self._aggregation))
await self.push_frame(frame)
self._aggregation = ""
elif isinstance(frame, TextFrame) and self._seen_start_frame:
self._aggregation += frame.text
else:
await self.push_frame(frame, direction)
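The behavior described in the docstring (aggregate text only between start and end frames, pass everything else through) can be modeled as a standalone sketch (hypothetical tuples instead of frame objects, not the real processor):

```python
def aggregate_response(frames):
    """frames: ("start",), ("end",) or ("text", s) tuples.
    Returns the frames a simplified full-response aggregator would push."""
    pushed = []
    aggregation = ""
    seen_start = False
    for frame in frames:
        if frame[0] == "start":
            seen_start = True
            pushed.append(frame)                  # start frames pass through
        elif frame[0] == "end":
            seen_start = False
            pushed.append(("text", aggregation))  # emit the concatenation
            pushed.append(frame)
            aggregation = ""
        elif frame[0] == "text" and seen_start:
            aggregation += frame[1]               # aggregate inside a response
        else:
            pushed.append(frame)                  # stray frames pass through
    return pushed


print(aggregate_response([("start",), ("text", "Hello,"), ("text", " world."), ("end",)]))
# → [('start',), ('text', 'Hello, world.'), ('end',)]
```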