From ef310ca61f8bc7da90f3886def329855f7bfc8e4 Mon Sep 17 00:00:00 2001
From: juan
Date: Thu, 19 Dec 2024 14:51:53 +0100
Subject: [PATCH] Remove duplicated function messages: messages created inside
 function calls are no longer added to the chat context, keeping the context
 strictly to what the LLM sees in its instructions

---
 .../livekit/agents/pipeline/pipeline_agent.py | 41 ++-----------------
 1 file changed, 4 insertions(+), 37 deletions(-)

diff --git a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py
index a08291ea4..fd75b7b48 100644
--- a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py
+++ b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py
@@ -69,7 +69,6 @@ def __init__(self, assistant: "VoicePipelineAgent", llm_stream: LLMStream) -> No
         self._assistant = assistant
         self._metadata = dict[str, Any]()
         self._llm_stream = llm_stream
-        self._extra_chat_messages: list[ChatMessage] = []
 
     @staticmethod
     def get_current() -> "AgentCallContext":
@@ -92,15 +91,6 @@ def get_metadata(self, key: str, default: Any = None) -> Any:
     def llm_stream(self) -> LLMStream:
         return self._llm_stream
 
-    def add_extra_chat_message(self, message: ChatMessage) -> None:
-        """Append chat message to the end of function outputs for the answer LLM call"""
-        self._extra_chat_messages.append(message)
-
-    @property
-    def extra_chat_messages(self) -> list[ChatMessage]:
-        return self._extra_chat_messages
-
-
 def _default_before_llm_cb(
     agent: VoicePipelineAgent, chat_ctx: ChatContext
 ) -> LLMStream:
@@ -446,7 +436,6 @@ async def say(
         await self._track_published_fut
 
         call_ctx = None
-        fnc_source: str | AsyncIterable[str] | None = None
         if add_to_chat_ctx:
             try:
                 call_ctx = AgentCallContext.get_current()
@@ -454,14 +443,9 @@ async def say(
                 # no active call context, ignore
                 pass
             else:
-                if isinstance(source, LLMStream):
-                    logger.warning(
-                        "LLMStream will be ignored for function call chat context"
-                    )
-                elif isinstance(source, AsyncIterable):
-                    source, fnc_source = utils.aio.itertools.tee(source, 2)  # type: ignore
-                else:
-                    fnc_source = source
+                if call_ctx is not None:
+                    # Don't add to chat context if we're in a function call
+                    add_to_chat_ctx = False
 
         new_handle = SpeechHandle.create_assistant_speech(
             allow_interruptions=allow_interruptions, add_to_chat_ctx=add_to_chat_ctx
@@ -474,23 +458,6 @@ async def say(
         else:
             self._add_speech_for_playout(new_handle)
 
-        # add the speech to the function call context if needed
-        if call_ctx is not None and fnc_source is not None:
-            if isinstance(fnc_source, AsyncIterable):
-                text = ""
-                async for chunk in fnc_source:
-                    text += chunk
-            else:
-                text = fnc_source
-
-            call_ctx.add_extra_chat_message(
-                ChatMessage.create(text=text, role="assistant")
-            )
-            logger.debug(
-                "added speech to function call chat context",
-                extra={"text": text},
-            )
-
         return new_handle
 
     def _update_state(self, state: AgentState, delay: float = 0.0):
@@ -805,6 +772,7 @@ def _commit_user_question_if_needed() -> None:
             collected_text
             and speech_handle.add_to_chat_ctx
             and (not user_question or speech_handle.user_committed)
+            and not is_using_tools
         ):
             if speech_handle.extra_tools_messages:
                 self._chat_ctx.messages.extend(speech_handle.extra_tools_messages)
@@ -919,7 +887,6 @@ async def _execute_function_calls() -> None:
             # synthesize the tool speech with the chat ctx from llm_stream
             chat_ctx = call_ctx.chat_ctx.copy()
             chat_ctx.messages.extend(extra_tools_messages)
-            chat_ctx.messages.extend(call_ctx.extra_chat_messages)
 
             answer_llm_stream = self._llm.chat(chat_ctx=chat_ctx, fnc_ctx=self.fnc_ctx)
             synthesis_handle = self._synthesize_agent_speech(
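
Usage note: the snippet below is a minimal sketch of the behavior this patch
targets, not part of the patch itself. It assumes the livekit-agents 0.x
pipeline API (llm.FunctionContext, AgentCallContext.get_current(), and its
.agent property); the lookup_weather tool and its body are hypothetical.

from livekit.agents import llm
from livekit.agents.pipeline import AgentCallContext

fnc_ctx = llm.FunctionContext()


@fnc_ctx.ai_callable()
async def lookup_weather(location: str) -> str:
    """Hypothetical tool invoked by the LLM during a conversation."""
    agent = AgentCallContext.get_current().agent

    # Filler speech played while the tool runs. With this patch, the spoken
    # text is no longer appended to the chat context, so the follow-up LLM
    # call sees only the function-call messages themselves, with nothing
    # duplicated.
    await agent.say("Let me check the weather for you...")

    return f"It is sunny in {location}."

# fnc_ctx would then be passed to the agent, e.g.
# VoicePipelineAgent(..., fnc_ctx=fnc_ctx).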