Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support function pre-execute system prompt #714

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/pipecat/processors/aggregators/openai_llm_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,18 @@ def messages(self) -> List[ChatCompletionMessageParam]:
def tools(self) -> List[ChatCompletionToolParam] | NotGiven:
    """Tools currently configured on this context (NOT_GIVEN when unset)."""
    return self._tools

@tools.setter
def tools(self, value: List[ChatCompletionToolParam] | NotGiven):
    # Setter lets callers temporarily swap tools out (e.g. to disable
    # function calling during a nested completion) and restore them later.
    self._tools = value

@property
def tool_choice(self) -> ChatCompletionToolChoiceOptionParam | NotGiven:
    """Tool-choice setting currently configured on this context."""
    return self._tool_choice

@tool_choice.setter
def tool_choice(self, value: ChatCompletionToolChoiceOptionParam | NotGiven):
    # Paired with the tools setter: allows saving/restoring the tool-choice
    # setting around a completion run with function calling disabled.
    self._tool_choice = value

def add_message(self, message: ChatCompletionMessageParam):
    """Append a single message to the end of this context's message history."""
    self._messages += [message]

Expand Down
15 changes: 13 additions & 2 deletions src/pipecat/services/ai_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,20 +134,31 @@ def __init__(self, **kwargs):
super().__init__(**kwargs)
self._callbacks = {}
self._start_callbacks = {}
self._pre_execute_prompts = {}

# TODO-CB: callback function type
def register_function(self, function_name: str | None, callback, start_callback=None):
def register_function(
self,
function_name: str | None,
callback,
start_callback=None,
pre_execute_prompt: str | None = None
):
# Registering a function with the function_name set to None will run that callback
# for all functions
self._callbacks[function_name] = callback
# QUESTION FOR CB: maybe this isn't needed anymore?
if start_callback:
self._start_callbacks[function_name] = start_callback
if pre_execute_prompt:
self._pre_execute_prompts[function_name] = pre_execute_prompt

def unregister_function(self, function_name: str | None):
del self._callbacks[function_name]
if self._start_callbacks[function_name]:
if function_name in self._start_callbacks:
del self._start_callbacks[function_name]
if function_name in self._pre_execute_prompts:
del self._pre_execute_prompts[function_name]

def has_function(self, function_name: str):
if None in self._callbacks.keys():
Expand Down
29 changes: 29 additions & 0 deletions src/pipecat/services/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,34 @@ async def _stream_chat_completions(
chunks = await self.get_chat_completions(context, messages)

return chunks

async def _handle_pre_execute_prompt(self, context: OpenAILLMContext, function_name: str):
    """Run a registered pre-execute system prompt before a function call.

    If a pre-execute prompt was registered for function_name, append it to
    the context as a system message and run one full LLM completion over
    the context with function calling temporarily disabled (so the model
    cannot recursively trigger another tool call). The context's tool
    settings are restored afterwards.

    Args:
        context: The LLM context being processed.
        function_name: Name of the function about to be executed.
    """
    pre_execute_prompt = self._pre_execute_prompts.get(function_name)
    if not pre_execute_prompt:
        return

    logger.debug(f"Handling pre_execute_prompt for function: {function_name}")

    # Add the pre-execute prompt as a system message to the context.
    context.add_message({"role": "system", "content": pre_execute_prompt})

    # Temporarily disable function calling to prevent recursion.
    original_tools = context.tools
    original_tool_choice = context.tool_choice
    context.tools = NOT_GIVEN
    context.tool_choice = NOT_GIVEN

    try:
        # Process the context normally.
        await self.push_frame(LLMFullResponseStartFrame())
        await self.start_processing_metrics()
        await self._process_context(context)
        await self.stop_processing_metrics()
        await self.push_frame(LLMFullResponseEndFrame())
    finally:
        # Restore function calling even if the completion raised; otherwise
        # the context would be left with tools permanently disabled.
        context.tools = original_tools
        context.tool_choice = original_tool_choice

async def _process_context(self, context: OpenAILLMContext):
functions_list = []
Expand Down Expand Up @@ -250,6 +278,7 @@ async def _process_context(self, context: OpenAILLMContext):
if tool_call.function and tool_call.function.name:
function_name += tool_call.function.name
tool_call_id = tool_call.id
await self._handle_pre_execute_prompt(context, function_name)
await self.call_start_function(context, function_name)
if tool_call.function and tool_call.function.arguments:
# Keep iterating through the response to collect all the argument fragments
Expand Down