diff --git a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py index ca7baca72..ecbab409a 100644 --- a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py +++ b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py @@ -40,6 +40,20 @@ WillSynthesizeAssistantReply = BeforeLLMCallback +AfterLLMCallback = Callable[ + ["VoicePipelineAgent", + Union[ + AsyncGenerator[str, None], + str, + AsyncIterable[str]] + ], + Union[ + AsyncGenerator[str, None], + str, + AsyncIterable[str] + ], +] + BeforeTTSCallback = Callable[ ["VoicePipelineAgent", Union[str, AsyncIterable[str]]], SpeechSource, @@ -104,6 +118,10 @@ class SpeechData: SpeechDataContextVar = contextvars.ContextVar[SpeechData]("voice_assistant_speech_data") +def _default_after_llm_cb( + agent: VoicePipelineAgent, text: AsyncIterable[str, None] | str | AsyncIterable[str] +) -> AsyncIterable[str, None] | str | AsyncIterable[str]: + return text def _default_before_tts_cb( agent: VoicePipelineAgent, text: str | AsyncIterable[str] @@ -120,6 +138,7 @@ class _ImplOptions: max_nested_fnc_calls: int preemptive_synthesis: bool before_llm_cb: BeforeLLMCallback + after_llm_cb: AfterLLMCallback before_tts_cb: BeforeTTSCallback plotting: bool transcription: AgentTranscriptionOptions @@ -172,6 +191,7 @@ def __init__( preemptive_synthesis: bool = False, transcription: AgentTranscriptionOptions = AgentTranscriptionOptions(), before_llm_cb: BeforeLLMCallback = _default_before_llm_cb, + after_llm_cb: AfterLLMCallback = _default_after_llm_cb, before_tts_cb: BeforeTTSCallback = _default_before_tts_cb, plotting: bool = False, loop: asyncio.AbstractEventLoop | None = None, @@ -204,6 +224,10 @@ def __init__( stream by calling the llm.chat() method. Returning False will cancel the synthesis of the reply. + after_llm_cb: Callback called after the LLM stream is generated. + It can be used to modify the LLM stream text (e.g. editing the response). + + Modify the LLM stream and return it to alter the synthesized reply. before_tts_cb: Callback called when the assistant is about to synthesize a speech. This can be used to customize text before the speech synthesis. (e.g: editing the pronunciation of a word). @@ -229,6 +253,7 @@ def __init__( preemptive_synthesis=preemptive_synthesis, transcription=transcription, before_llm_cb=before_llm_cb, + after_llm_cb=after_llm_cb, before_tts_cb=before_tts_cb, ) self._plotter = AssistantPlotter(self._loop) @@ -881,6 +906,8 @@ async def _llm_stream_to_str_generator( if isinstance(source, LLMStream): source = _llm_stream_to_str_generator(source) + source = self._opts.after_llm_cb(self, source) + og_source = source transcript_source = source if isinstance(og_source, AsyncIterable):