diff --git a/.changeset/clever-lies-explode.md b/.changeset/clever-lies-explode.md new file mode 100644 index 000000000..1bf7ea69d --- /dev/null +++ b/.changeset/clever-lies-explode.md @@ -0,0 +1,7 @@ +--- +"livekit-plugins-anthropic": patch +"livekit-plugins-openai": patch +"livekit-agents": patch +--- + +Moved create_ai_function_info to function_context.py for better reusability and reduce repetition diff --git a/livekit-agents/livekit/agents/llm/__init__.py b/livekit-agents/livekit/agents/llm/__init__.py index acc5b0ce6..d3a06f520 100644 --- a/livekit-agents/livekit/agents/llm/__init__.py +++ b/livekit-agents/livekit/agents/llm/__init__.py @@ -15,6 +15,7 @@ FunctionContext, FunctionInfo, TypeInfo, + _create_ai_function_info, ai_callable, ) from .llm import ( @@ -54,4 +55,5 @@ "FallbackAdapter", "AvailabilityChangedEvent", "ToolChoice", + "_create_ai_function_info", ] diff --git a/livekit-agents/livekit/agents/llm/function_context.py b/livekit-agents/livekit/agents/llm/function_context.py index aa4df9842..4470492fe 100644 --- a/livekit-agents/livekit/agents/llm/function_context.py +++ b/livekit-agents/livekit/agents/llm/function_context.py @@ -18,6 +18,7 @@ import enum import functools import inspect +import json import types import typing from dataclasses import dataclass @@ -303,3 +304,96 @@ def _is_optional_type(typ) -> Tuple[bool, Any]: return True, non_none_args[0] return False, None + + +def _create_ai_function_info( + fnc_ctx: FunctionContext, + tool_call_id: str, + fnc_name: str, + raw_arguments: str, # JSON string +) -> FunctionCallInfo: + if fnc_name not in fnc_ctx.ai_functions: + raise ValueError(f"AI function {fnc_name} not found") + + parsed_arguments: dict[str, Any] = {} + try: + if raw_arguments: # ignore empty string + parsed_arguments = json.loads(raw_arguments) + except json.JSONDecodeError: + raise ValueError( + f"AI function {fnc_name} received invalid JSON arguments - {raw_arguments}" + ) + + fnc_info = fnc_ctx.ai_functions[fnc_name] + + # 
Ensure all necessary arguments are present and of the correct type. + sanitized_arguments: dict[str, Any] = {} + for arg_info in fnc_info.arguments.values(): + if arg_info.name not in parsed_arguments: + if arg_info.default is inspect.Parameter.empty: + raise ValueError( + f"AI function {fnc_name} missing required argument {arg_info.name}" + ) + continue + + arg_value = parsed_arguments[arg_info.name] + is_optional, inner_th = _is_optional_type(arg_info.type) + + if typing.get_origin(inner_th) is not None: + if not isinstance(arg_value, list): + raise ValueError( + f"AI function {fnc_name} argument {arg_info.name} should be a list" + ) + + inner_type = typing.get_args(inner_th)[0] + sanitized_value = [ + _sanitize_primitive( + value=v, + expected_type=inner_type, + choices=arg_info.choices, + ) + for v in arg_value + ] + else: + sanitized_value = _sanitize_primitive( + value=arg_value, + expected_type=inner_th, + choices=arg_info.choices, + ) + + sanitized_arguments[arg_info.name] = sanitized_value + + return FunctionCallInfo( + tool_call_id=tool_call_id, + raw_arguments=raw_arguments, + function_info=fnc_info, + arguments=sanitized_arguments, + ) + + +def _sanitize_primitive( + *, value: Any, expected_type: type, choices: tuple | None +) -> Any: + if expected_type is str: + if not isinstance(value, str): + raise ValueError(f"expected str, got {type(value)}") + elif expected_type in (int, float): + if not isinstance(value, (int, float)): + raise ValueError(f"expected number, got {type(value)}") + + if expected_type is int: + if value % 1 != 0: + raise ValueError("expected int, got float") + + value = int(value) + elif expected_type is float: + value = float(value) + + elif expected_type is bool: + if not isinstance(value, bool): + raise ValueError(f"expected bool, got {type(value)}") + + if choices and value not in choices: + raise ValueError(f"invalid value {value}, not in {choices}") + + return value diff --git 
a/livekit-plugins/livekit-plugins-anthropic/livekit/plugins/anthropic/llm.py b/livekit-plugins/livekit-plugins-anthropic/livekit/plugins/anthropic/llm.py index 9678c9381..69b468d23 100644 --- a/livekit-plugins/livekit-plugins-anthropic/livekit/plugins/anthropic/llm.py +++ b/livekit-plugins/livekit-plugins-anthropic/livekit/plugins/anthropic/llm.py @@ -24,7 +24,6 @@ Awaitable, List, Literal, - Tuple, Union, cast, get_args, @@ -41,7 +40,10 @@ utils, ) from livekit.agents.llm import ToolChoice -from livekit.agents.llm.function_context import _is_optional_type +from livekit.agents.llm.function_context import ( + _create_ai_function_info, + _is_optional_type, +) from livekit.agents.types import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions import anthropic @@ -487,67 +489,6 @@ def _build_anthropic_image_content( ) -def _create_ai_function_info( - fnc_ctx: llm.function_context.FunctionContext, - tool_call_id: str, - fnc_name: str, - raw_arguments: str, # JSON string -) -> llm.function_context.FunctionCallInfo: - if fnc_name not in fnc_ctx.ai_functions: - raise ValueError(f"AI function {fnc_name} not found") - - parsed_arguments: dict[str, Any] = {} - try: - if raw_arguments: # ignore empty string - parsed_arguments = json.loads(raw_arguments) - except json.JSONDecodeError: - raise ValueError( - f"AI function {fnc_name} received invalid JSON arguments - {raw_arguments}" - ) - - fnc_info = fnc_ctx.ai_functions[fnc_name] - - # Ensure all necessary arguments are present and of the correct type. 
- sanitized_arguments: dict[str, Any] = {} - for arg_info in fnc_info.arguments.values(): - if arg_info.name not in parsed_arguments: - if arg_info.default is inspect.Parameter.empty: - raise ValueError( - f"AI function {fnc_name} missing required argument {arg_info.name}" - ) - continue - - arg_value = parsed_arguments[arg_info.name] - is_optional, inner_th = _is_optional_type(arg_info.type) - - if get_origin(inner_th) is not None: - if not isinstance(arg_value, list): - raise ValueError( - f"AI function {fnc_name} argument {arg_info.name} should be a list" - ) - - inner_type = get_args(inner_th)[0] - sanitized_value = [ - _sanitize_primitive( - value=v, expected_type=inner_type, choices=arg_info.choices - ) - for v in arg_value - ] - else: - sanitized_value = _sanitize_primitive( - value=arg_value, expected_type=inner_th, choices=arg_info.choices - ) - - sanitized_arguments[arg_info.name] = sanitized_value - - return llm.function_context.FunctionCallInfo( - tool_call_id=tool_call_id, - raw_arguments=raw_arguments, - function_info=fnc_info, - arguments=sanitized_arguments, - ) - - def _build_function_description( fnc_info: llm.function_context.FunctionInfo, ) -> anthropic.types.ToolParam: @@ -598,31 +539,3 @@ def type2str(t: type) -> str: "description": fnc_info.description, "input_schema": input_schema, } - - -def _sanitize_primitive( - *, value: Any, expected_type: type, choices: Tuple[Any] | None -) -> Any: - if expected_type is str: - if not isinstance(value, str): - raise ValueError(f"expected str, got {type(value)}") - elif expected_type in (int, float): - if not isinstance(value, (int, float)): - raise ValueError(f"expected number, got {type(value)}") - - if expected_type is int: - if value % 1 != 0: - raise ValueError("expected int, got float") - - value = int(value) - elif expected_type is float: - value = float(value) - - elif expected_type is bool: - if not isinstance(value, bool): - raise ValueError(f"expected bool, got {type(value)}") - - if choices 
and value not in choices: - raise ValueError(f"invalid value {value}, not in {choices}") - - return value diff --git a/livekit-plugins/livekit-plugins-assemblyai/livekit/plugins/assemblyai/stt.py b/livekit-plugins/livekit-plugins-assemblyai/livekit/plugins/assemblyai/stt.py index a87eaf542..acef65b6a 100644 --- a/livekit-plugins/livekit-plugins-assemblyai/livekit/plugins/assemblyai/stt.py +++ b/livekit-plugins/livekit-plugins-assemblyai/livekit/plugins/assemblyai/stt.py @@ -289,6 +289,8 @@ async def recv_task(ws: aiohttp.ClientWebSocketResponse): except Exception: logger.exception("failed to process AssemblyAI message") + ws: aiohttp.ClientWebSocketResponse | None = None + while True: try: ws = await self._connect_ws() diff --git a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/_oai_api.py b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/_oai_api.py index 8bf05a19f..8dbc3a33e 100644 --- a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/_oai_api.py +++ b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/_oai_api.py @@ -15,79 +15,13 @@ from __future__ import annotations import inspect -import json import typing from typing import Any from livekit.agents.llm import function_context, llm from livekit.agents.llm.function_context import _is_optional_type -__all__ = ["build_oai_function_description", "create_ai_function_info"] - - -def create_ai_function_info( - fnc_ctx: function_context.FunctionContext, - tool_call_id: str, - fnc_name: str, - raw_arguments: str, # JSON string -) -> function_context.FunctionCallInfo: - if fnc_name not in fnc_ctx.ai_functions: - raise ValueError(f"AI function {fnc_name} not found") - - parsed_arguments: dict[str, Any] = {} - try: - if raw_arguments: # ignore empty string - parsed_arguments = json.loads(raw_arguments) - except json.JSONDecodeError: - raise ValueError( - f"AI function {fnc_name} received invalid JSON arguments - {raw_arguments}" - ) - - fnc_info = 
fnc_ctx.ai_functions[fnc_name] - - # Ensure all necessary arguments are present and of the correct type. - sanitized_arguments: dict[str, Any] = {} - for arg_info in fnc_info.arguments.values(): - if arg_info.name not in parsed_arguments: - if arg_info.default is inspect.Parameter.empty: - raise ValueError( - f"AI function {fnc_name} missing required argument {arg_info.name}" - ) - continue - - arg_value = parsed_arguments[arg_info.name] - is_optional, inner_th = _is_optional_type(arg_info.type) - - if typing.get_origin(inner_th) is not None: - if not isinstance(arg_value, list): - raise ValueError( - f"AI function {fnc_name} argument {arg_info.name} should be a list" - ) - - inner_type = typing.get_args(inner_th)[0] - sanitized_value = [ - _sanitize_primitive( - value=v, - expected_type=inner_type, - choices=arg_info.choices, - ) - for v in arg_value - ] - else: - sanitized_value = _sanitize_primitive( - value=arg_value, - expected_type=inner_th, - choices=arg_info.choices, - ) - - sanitized_arguments[arg_info.name] = sanitized_value - - return function_context.FunctionCallInfo( - tool_call_id=tool_call_id, - raw_arguments=raw_arguments, - function_info=fnc_info, - arguments=sanitized_arguments, - ) +__all__ = ["build_oai_function_description"] def build_oai_function_description( @@ -156,31 +90,3 @@ def type2str(t: type) -> str: }, }, } - - -def _sanitize_primitive( - *, value: Any, expected_type: type, choices: tuple | None -) -> Any: - if expected_type is str: - if not isinstance(value, str): - raise ValueError(f"expected str, got {type(value)}") - elif expected_type in (int, float): - if not isinstance(value, (int, float)): - raise ValueError(f"expected number, got {type(value)}") - - if expected_type is int: - if value % 1 != 0: - raise ValueError("expected int, got float") - - value = int(value) - elif expected_type is float: - value = float(value) - - elif expected_type is bool: - if not isinstance(value, bool): - raise ValueError(f"expected bool, got 
{type(value)}") - - if choices and value not in choices: - raise ValueError(f"invalid value {value}, not in {choices}") - - return value diff --git a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/llm.py b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/llm.py index 7dfbaff24..bcff2cfa9 100644 --- a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/llm.py +++ b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/llm.py @@ -29,17 +29,14 @@ APITimeoutError, llm, ) -from livekit.agents.llm import ToolChoice +from livekit.agents.llm import ToolChoice, _create_ai_function_info from livekit.agents.types import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions import openai from openai.types.chat import ChatCompletionChunk, ChatCompletionMessageParam from openai.types.chat.chat_completion_chunk import Choice -from ._oai_api import ( - build_oai_function_description, - create_ai_function_info, -) +from ._oai_api import build_oai_function_description from .log import logger from .models import ( CerebrasChatModels, @@ -840,7 +837,7 @@ def _try_build_function(self, id: str, choice: Choice) -> llm.ChatChunk | None: ) return None - fnc_info = create_ai_function_info( + fnc_info = _create_ai_function_info( self._fnc_ctx, self._tool_call_id, self._fnc_name, self._fnc_raw_arguments ) diff --git a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/realtime_model.py b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/realtime_model.py index 04bf14ac5..26bc2649b 100644 --- a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/realtime_model.py +++ b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/realtime_model.py @@ -12,10 +12,11 @@ import aiohttp from livekit import rtc from livekit.agents import llm, utils +from livekit.agents.llm.function_context import _create_ai_function_info from livekit.agents.metrics import MultimodalLLMError, 
MultimodalLLMMetrics from typing_extensions import TypedDict -from .._oai_api import build_oai_function_description, create_ai_function_info +from .._oai_api import build_oai_function_description from . import api_proto, remote_items from .log import logger @@ -1521,7 +1522,7 @@ def _handle_response_output_item_done( item = response_output_done["item"] assert item["type"] == "function_call" - fnc_call_info = create_ai_function_info( + fnc_call_info = _create_ai_function_info( self._fnc_ctx, item["call_id"], item["name"],