Skip to content

Commit

Permalink
fix: change agent tool
Browse files Browse the repository at this point in the history
  • Loading branch information
nobu007 committed Aug 13, 2024
1 parent 4e0b704 commit 390800d
Show file tree
Hide file tree
Showing 6 changed files with 25 additions and 16 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
agent_definition:
agent_name: "main_function_create_agent"
agent_type: structured_chat
agent_type: tool_calling
agent_role: |
pythonプログラムのmain関数を作成する作業をお願いします。
別のagentはプログラムの分割やテストの作成を実施することになります。
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
agent_definition:
agent_name: "code_split_agent"
agent_type: structured_chat
agent_type: tool_calling
agent_role: |
プログラムを分割する作業をお願いします。
別のagentはプログラムの作成やテストの作成を実施することになります。
Expand Down
2 changes: 1 addition & 1 deletion src/codeinterpreterapi/agents/structured_chat/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def create_structured_chat_agent(
| llm_with_stop
)
agent = assign_runnable_history(agent, runnable_config)
agent = agent | output_parser
# agent = agent | output_parser
return agent


Expand Down
14 changes: 11 additions & 3 deletions src/codeinterpreterapi/crew/custom_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from langchain_core.tools import BaseTool
from pydantic import Field

from codeinterpreterapi.utils.multi_converter import MultiConverter


class CustomAgent(BaseAgent):
agent_executor: Any = Field(default=None, description="Verbose mode for the Agent Execution")
Expand All @@ -28,7 +30,11 @@ def interpolate_inputs(self, inputs: Dict[str, Any]) -> None:
print("interpolate_inputs inputs=", inputs)
super().interpolate_inputs(inputs)

def execute_task(self, task: Any, context: Optional[str] = None, tools: Optional[List[Any]] = None):
def execute_task(self, task: Any, context: Optional[str] = None, tools: Optional[List[Any]] = None) -> str:
# Notice: ValidationError - 1 validation error for TaskOutput | raw: Input should be a valid string
# crewaiのTaskOutputのrawに入るのでstrで返す必要がある。
# TODO: 直接dictを返せるようにcrewaiを直す?

# AgentExecutorを使用してタスクを実行
# print("execute_task task=", task)
print("execute_task context=", context)
Expand All @@ -38,9 +44,11 @@ def execute_task(self, task: Any, context: Optional[str] = None, tools: Optional
input_dict["question"] = task.prompt_context
input_dict["message"] = "タスクを実行してください。\n" + task.expected_output
result = self.agent_executor.invoke(input=input_dict)
print("execute_task result=", result)
result_str = MultiConverter.to_str(result)
print("execute_task result=", result_str)

# TODO: return full dict when crewai is updated
return result["output"]
return result_str

def create_agent_executor(self, tools=None) -> None:
pass
Expand Down
11 changes: 6 additions & 5 deletions src/codeinterpreterapi/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from codeinterpreterapi.config import settings
from codeinterpreterapi.llm.llm import CodeInterpreterLlm
from codeinterpreterapi.schema import CodeInterpreterResponse, File, SessionStatus, UserRequest
from codeinterpreterapi.utils.multi_converter import MultiConverter


def _handle_deprecated_kwargs(kwargs: dict) -> None:
Expand All @@ -49,6 +50,7 @@ def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], **k
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> Any:
"""Run when chain ends running."""
print("AgentCallbackHandler on_chain_end type(outputs)=", type(outputs))
print("AgentCallbackHandler on_chain_end type(outputs)=", outputs)
self.agent_callback_func(outputs)

def on_chat_model_start(
Expand Down Expand Up @@ -228,9 +230,10 @@ async def _ainput_handler(self, request: UserRequest) -> None:

def _output_handler_pre(self, response: Any) -> str:
print("_output_handler_pre response(type)=", type(response))
if isinstance(response, str):
output_str = response
elif isinstance(response, dict):
output_str = MultiConverter.to_str(response)

# TODO: MultiConverterに共通化
if isinstance(response, dict):
output_str = ""
code_log_item = {}
if "output" in response:
Expand All @@ -242,8 +245,6 @@ def _output_handler_pre(self, response: Any) -> str:
if "log" in response:
code_log_item["log"] = str(response["log"])
self.output_code_log = code_log_item
else:
output_str = "response=" + str(response)
return output_str

def _output_handler_post(self, final_response: str) -> CodeInterpreterResponse:
Expand Down
10 changes: 5 additions & 5 deletions src/codeinterpreterapi/tools/code_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from codeinterpreterapi.brain.params import CodeInterpreterParams
from codeinterpreterapi.config import settings
from codeinterpreterapi.llm.llm import prepare_test_llm
from codeinterpreterapi.schema import CodeInput
from codeinterpreterapi.schema import FileInput
from codeinterpreterapi.utils.file_util import FileUtil


Expand All @@ -27,13 +27,13 @@ def get_tools_code_checker(cls, ci_params: CodeInterpreterParams) -> None:
"Start from this latest code first.\n",
func=tools_instance._get_latest_code,
coroutine=tools_instance._aget_latest_code,
args_schema=CodeInput, # type: ignore
args_schema=FileInput, # type: ignore
),
]

return tools

def _get__get_latest_code_common(self, filename=""):
def _get_latest_code_common(self, filename=""):
target_dir = "./"
target_filename = "main.py"

Expand All @@ -51,10 +51,10 @@ def _get__get_latest_code_common(self, filename=""):
return ""

def _get_latest_code(self, filename=""):
return self._get__get_latest_code_common(filename)
return self._get_latest_code_common(filename)

async def _aget_latest_code(self, filename=""):
return self._get__get_latest_code_common(filename)
return self._get_latest_code_common(filename)


def test():
Expand Down

0 comments on commit 390800d

Please sign in to comment.