diff --git a/ai/agents/oai/completion.py b/ai/agents/oai/completion.py index bd12c0d9..e6d99b47 100644 --- a/ai/agents/oai/completion.py +++ b/ai/agents/oai/completion.py @@ -303,6 +303,18 @@ def _get_response(cls, config: Dict, raise_on_ratelimit_or_timeout=False, use_ca del config['agent_name'] if "api_type" in config: del config['api_type'] + + tools = [] + if "functions" in config: + functions = config.pop("functions") + for function in functions: + tools.append({ + "type": "function", + "function": function + }) + if len(tools) > 0: + config['tools'] = tools + openai_completion = ( openai.ChatCompletion if config["model"].replace("gpt-35-turbo", "gpt-3.5-turbo") in cls.chat_models @@ -313,6 +325,14 @@ def _get_response(cls, config: Dict, raise_on_ratelimit_or_timeout=False, use_ca response = openai_completion.create(**config) else: response = openai_completion.create(request_timeout=request_timeout, **config) + if "tool_calls" in response.choices[0].message and len(response.choices[0].message.tool_calls) > 0: + tool_calls = response.choices[0].message.pop("tool_calls") + function_call = tool_calls[0] + response.choices[0].message['function_call'] = { + "name": function_call.function.name, + "arguments": function_call.function.arguments + } + response.choices[0]['finish_reason'] = "function_call" except ( ServiceUnavailableError,