Skip to content

Commit

Permalink
fix: tools support issue
Browse files Browse the repository at this point in the history
  • Loading branch information
nitin4real committed Nov 22, 2024
1 parent 4aee22d commit 9f9f2ad
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 11 deletions.
24 changes: 23 additions & 1 deletion realtime_agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from agora_realtime_ai_api.rtc import Channel, ChatMessage, RtcEngine, RtcOptions

from .logger import setup_logger
from .realtime.struct import InputAudioBufferCommitted, InputAudioBufferSpeechStarted, InputAudioBufferSpeechStopped, InputAudioTranscription, ItemCreated, ItemInputAudioTranscriptionCompleted, RateLimitsUpdated, ResponseAudioDelta, ResponseAudioDone, ResponseAudioTranscriptDelta, ResponseAudioTranscriptDone, ResponseContentPartAdded, ResponseContentPartDone, ResponseCreated, ResponseDone, ResponseOutputItemAdded, ResponseOutputItemDone, ServerVADUpdateParams, SessionUpdate, SessionUpdateParams, SessionUpdated, Voices, to_json
from .realtime.struct import FunctionCallOutputItemParam, InputAudioBufferCommitted, InputAudioBufferSpeechStarted, InputAudioBufferSpeechStopped, InputAudioTranscription, ItemCreate, ItemCreated, ItemInputAudioTranscriptionCompleted, RateLimitsUpdated, ResponseAudioDelta, ResponseAudioDone, ResponseAudioTranscriptDelta, ResponseAudioTranscriptDone, ResponseContentPartAdded, ResponseContentPartDone, ResponseCreate, ResponseCreated, ResponseDone, ResponseFunctionCallArgumentsDelta, ResponseFunctionCallArgumentsDone, ResponseOutputItemAdded, ResponseOutputItemDone, ServerVADUpdateParams, SessionUpdate, SessionUpdateParams, SessionUpdated, Voices, to_json
from .realtime.connection import RealtimeApiConnection
from .tools import ClientToolCallResponse, ToolContext
from .utils import PCMWriter
Expand Down Expand Up @@ -240,6 +240,21 @@ async def model_to_rtc(self) -> None:
await pcm_writer.flush()
raise # Re-raise the cancelled exception to properly exit the task

async def handle_funtion_call(self, message: ResponseFunctionCallArgumentsDone) -> None:
function_call_response = await self.tools.execute_tool(message.name, message.arguments)
logger.info(f"Function call response: {function_call_response}")
await self.connection.send_request(
ItemCreate(
item = FunctionCallOutputItemParam(
call_id=message.call_id,
output=function_call_response.json_encoded_output
)
)
)
await self.connection.send_request(
ResponseCreate()
)

async def _process_model_messages(self) -> None:
async for message in self.connection.listen():
# logger.info(f"Received message {message=}")
Expand Down Expand Up @@ -312,5 +327,12 @@ async def _process_model_messages(self) -> None:
pass
case RateLimitsUpdated():
pass
case ResponseFunctionCallArgumentsDone():
asyncio.create_task(
self.handle_funtion_call(message)
)
case ResponseFunctionCallArgumentsDelta():
pass

case _:
logger.warning(f"Unhandled message {message=}")
3 changes: 3 additions & 0 deletions realtime_agent/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from dotenv import load_dotenv
from pydantic import BaseModel, Field, ValidationError

from realtime_agent.realtime.tools_example import AgentTools

from .realtime.struct import PCM_CHANNELS, PCM_SAMPLE_RATE, ServerVADUpdateParams, Voices

from .agent import InferenceConfig, RealtimeKitAgent
Expand Down Expand Up @@ -82,6 +84,7 @@ def run_agent_in_process(
),
inference_config=inference_config,
tools=None,
# tools=AgentTools() # tools example, replace with this line
)
)

Expand Down
44 changes: 44 additions & 0 deletions realtime_agent/realtime/tools_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@

from typing import Any
from realtime_agent.tools import ToolContext

# Function calling Example
# This is an example of how to add a new function to the agent tools.

class AgentTools(ToolContext):
def __init__(self) -> None:
super().__init__()

# create multiple functions here as per requirement
self.register_function(
name="get_avg_temp",
description="Returns average temperature of a country",
parameters={
"type": "object",
"properties": {
"country": {
"type": "string",
"description": "Name of country",
},
},
"required": ["country"],
},
fn=self._get_avg_temperature_by_country_name,
)

async def _get_avg_temperature_by_country_name(
self,
country: str,
) -> dict[str, Any]:
try:
result = "24 degree C" # Dummy data (Get the Required value here, like a DB call or API call)
return {
"status": "success",
"message": f"Average temperature of {country} is {result}",
"result": result,
}
except Exception as e:
return {
"status": "error",
"message": f"Failed to get : {str(e)}",
}
16 changes: 6 additions & 10 deletions realtime_agent/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,9 @@ class LocalFunctionToolDeclaration:
def model_description(self) -> dict[str, Any]:
return {
"type": "function",
"function": {
"name": self.name,
"description": self.description,
"parameters": self.parameters,
},
"name": self.name,
"description": self.description,
"parameters": self.parameters,
}


Expand All @@ -43,11 +41,9 @@ class PassThroughFunctionToolDeclaration:
def model_description(self) -> dict[str, Any]:
return {
"type": "function",
"function": {
"name": self.name,
"description": self.description,
"parameters": self.parameters,
},
"name": self.name,
"description": self.description,
"parameters": self.parameters,
}


Expand Down

0 comments on commit 9f9f2ad

Please sign in to comment.