fix: output type bug of crew_agent
nobu007 committed Aug 11, 2024
1 parent 5e1bda2 commit 3d8b0dd
Showing 11 changed files with 140 additions and 64 deletions.
1 change: 1 addition & 0 deletions examples/show_bitcoin_chart.py
@@ -5,6 +5,7 @@

def main() -> None:
with CodeInterpreterSession(is_local=True) as session:
session.start_local()
currentdate = datetime.now().strftime("%Y-%m-%d")

response = session.generate_response(f"Plot the bitcoin chart of 2023 YTD (today is {currentdate})")
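For orientation, a minimal sketch of the whole example after this change might look as follows. The `codeinterpreterapi` import, the surrounding boilerplate, and the closing `response.show()` call are not part of this hunk and are assumptions based on the rest of the commit.

```python
from datetime import datetime

from codeinterpreterapi import CodeInterpreterSession  # assumed import path


def main() -> None:
    with CodeInterpreterSession(is_local=True) as session:
        # The line added in this commit: a local session is now started explicitly.
        session.start_local()
        currentdate = datetime.now().strftime("%Y-%m-%d")
        response = session.generate_response(
            f"Plot the bitcoin chart of 2023 YTD (today is {currentdate})"
        )
        response.show()  # CodeInterpreterResponse.show() as defined in schema.py


if __name__ == "__main__":
    main()
```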
4 changes: 2 additions & 2 deletions src/codeinterpreterapi/agents/agents.py
@@ -36,8 +36,8 @@ def choose_agent_executors(ci_params: CodeInterpreterParams) -> List[AgentExecut
config = yaml.safe_load(f)

try:
print(f"Agent: {agent_name}")
print(f"config: {config}")
print(f"choose_agent_executors Agent: {agent_name}")
# print(f"config: {config}")
agent_def = AgentDefinition(**config["agent_definition"])
agent_def.build_prompt()
print(agent_def)
88 changes: 71 additions & 17 deletions src/codeinterpreterapi/callbacks/markdown/callbacks.py
@@ -1,6 +1,7 @@
import datetime
import os
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional
from uuid import UUID

from langchain.callbacks import FileCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult
@@ -15,49 +16,102 @@ def __init__(self, filename: str = "langchain_log.md"):

def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
self.step_count += 1
self._write_to_file(f"## Step {self.step_count}: LLM Start\n\n")
self._write_to_file(f"**Timestamp:** {self._get_timestamp()}\n\n")
self._write_header("LLM Start")
self._write_serialized(serialized)
self._write_to_file("**Prompts:**\n\n")
for i, prompt in enumerate(prompts, 1):
self._write_to_file(f"```\nPrompt {i}:\n{prompt}\n```\n\n")

def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
self._write_to_file("**LLM Response:**\n\n")
self._write_header("LLM Response")
for generation in response.generations[0]:
self._write_to_file(f"```\n{generation.text}\n```\n\n")
self._write_to_file("---\n\n")

def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any) -> None:
self.step_count += 1
chain_name = serialized.get("name", "Unknown Chain")
self._write_to_file(f"## Step {self.step_count}: Chain Start - {chain_name}\n\n")
self._write_to_file(f"**Timestamp:** {self._get_timestamp()}\n\n")
self._write_to_file("**Inputs:**\n\n")
self._write_to_file(f"```\n{inputs}\n```\n\n")
self._write_header(f"Chain Start - {chain_name}")
self._write_serialized(serialized)
self._write_inputs(inputs)

def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
self._write_to_file("**Outputs:**\n\n")
self._write_to_file(f"```\n{outputs}\n```\n\n")
self._write_to_file("---\n\n")
self._write_header("Chain End")
self._write_outputs(outputs)

def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
def on_agent_action(self, action: AgentAction, color: Optional[str] = None, **kwargs: Any) -> Any:
self.step_count += 1
self._write_to_file(f"## Step {self.step_count}: Agent Action\n\n")
self._write_to_file(f"**Timestamp:** {self._get_timestamp()}\n\n")
self._write_header("Agent Action")
self._write_to_file(f"**Tool:** {action.tool}\n\n")
self._write_to_file("**Tool Input:**\n\n")
self._write_to_file(f"```\n{action.tool_input}\n```\n\n")

def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
self._write_to_file("## Agent Finish\n\n")
self._write_to_file(f"**Timestamp:** {self._get_timestamp()}\n\n")
def on_agent_finish(self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any) -> None:
self._write_header("Agent Finish")
self._write_to_file("**Output:**\n\n")
self._write_to_file(f"```\n{finish.return_values}\n```\n\n")
self._write_to_file("---\n\n")

def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
inputs: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
self.step_count += 1
class_name = serialized.get("name", serialized.get("id", ["<unknown>"])[-1])
self._write_header(f"Tool Start - {class_name}")
self._write_serialized(serialized)
self._write_inputs(inputs)

def on_tool_end(
self,
output: str,
color: Optional[str] = None,
observation_prefix: Optional[str] = None,
llm_prefix: Optional[str] = None,
**kwargs: Any,
) -> None:
self._write_header("Tool End")
self._write_to_file(f"**Output:**{output}\n\n")

def on_tool_error(
self,
error: BaseException,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
self._write_header("Tool Error")
self._write_to_file(f"{str(error)}\n\n")

def _write_to_file(self, text: str) -> None:
self.file.write(text)
self.file.flush()

def _get_timestamp(self) -> str:
return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")

def _write_header(self, title: str):
self._write_to_file(f"## Step {self.step_count}: {title}\n\n")
self._write_to_file(f"**Timestamp:** {self._get_timestamp()}\n\n")

def _write_serialized(self, serialized: Dict[str, Any]):
name = serialized.get("name", "Unknown")
self._write_to_file(f"**serialized.name:** {name}\n\n")
self._write_to_file(f"**serialized.keys:** {serialized.keys()}\n\n")

def _write_inputs(self, inputs: Dict[str, Any]):
self._write_to_file("**Inputs:**\n\n")
self._write_to_file(f"```\n{inputs}\n```\n\n")

def _write_outputs(self, outputs: Dict[str, Any]):
self._write_to_file("**Outputs:**\n\n")
self._write_to_file(f"```\n{outputs}\n```\n\n")
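To see what the new handler methods produce, here is a hedged usage sketch. The class name `MarkdownFileCallbackHandler`, the module path, and the assumption that `__init__` forwards the filename to `FileCallbackHandler` are not visible in this diff; the event payloads are invented.

```python
from uuid import uuid4

# Assumed class name and module path; only the handler methods appear in this diff.
from codeinterpreterapi.callbacks.markdown.callbacks import MarkdownFileCallbackHandler

handler = MarkdownFileCallbackHandler(filename="langchain_log.md")

# Each event is appended to the markdown log as a "## Step N: ..." section
# with a timestamp, matching the _write_header helper added above.
handler.on_llm_start(serialized={"name": "demo-llm"}, prompts=["Plot a sine wave"])
handler.on_tool_start(
    serialized={"name": "python_repl"},
    input_str="print(1 + 1)",
    run_id=uuid4(),
    inputs={"code": "print(1 + 1)"},
)
handler.on_tool_end(output="2")
```

In a real run the handler would usually be attached to the chain or agent instead, for example via `config={"callbacks": [handler]}` when calling `invoke()`.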
2 changes: 1 addition & 1 deletion src/codeinterpreterapi/crew/crew_agent.py
@@ -85,7 +85,7 @@ def test():
_ = CodeInterpreterAgent.choose_agent_executors(ci_params=ci_params)
inputs = {"input": TestPrompt.svg_input_str}
plan = CodeInterpreterPlan(agent_name="main_function_create_agent", task_description="", expected_output="")
plan_list = CodeInterpreterPlanList(agent_task_list=[plan, plan])
plan_list = CodeInterpreterPlanList(reliability=80, agent_task_list=[plan, plan])
result = CodeInterpreterCrew(ci_params).run(inputs, plan_list)
print(result)

8 changes: 7 additions & 1 deletion src/codeinterpreterapi/crew/custom_agent.py
@@ -30,14 +30,16 @@ def interpolate_inputs(self, inputs: Dict[str, Any]) -> None:

def execute_task(self, task: Any, context: Optional[str] = None, tools: Optional[List[Any]] = None):
# Execute the task using the AgentExecutor
print("execute_task task=", task)
# print("execute_task task=", task)
print("execute_task context=", context)
print("execute_task tools=", tools)
input_dict = {}
input_dict["input"] = task.description
input_dict["question"] = task.prompt_context
input_dict["message"] = "タスクを実行してください。\n" + task.expected_output
result = self.agent_executor.invoke(input=input_dict)
print("execute_task result=", result)
# TODO: return full dict when crewai is updated
return result["output"]

def create_agent_executor(self, tools=None) -> None:
@@ -54,6 +56,10 @@ def get_delegation_tools(self, agents: List[BaseAgent]):
return []

def get_output_converter(self, llm, text, model, instructions):
print("get_output_converter llm=", type(llm))
print("get_output_converter text=", type(text))
print("get_output_converter model=", type(model))
print("get_output_converter instructions=", type(instructions))
return lambda x: x  # No conversion by default

def execute(self, task_description: str, context: Optional[List[str]] = None):
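Because the commit title points at this return value, a short sketch of the fixed control flow may help. The stub task and executor below are illustrative stand-ins for the crewai task and the LangChain AgentExecutor; only the dict-vs-string handling mirrors the code above.

```python
from dataclasses import dataclass
from typing import Any, Dict


@dataclass
class StubTask:
    # Stand-in for the crewai task fields read by execute_task above.
    description: str
    prompt_context: str
    expected_output: str


class StubExecutor:
    # Stand-in for AgentExecutor: invoke() returns a dict, not a string,
    # which is the output-type mismatch this commit works around.
    def invoke(self, input: Dict[str, Any]) -> Dict[str, Any]:
        return {**input, "output": "done: " + input["input"]}


def execute_task(executor: StubExecutor, task: StubTask) -> str:
    input_dict = {
        "input": task.description,
        "question": task.prompt_context,
        "message": "Please carry out the task.\n" + task.expected_output,
    }
    result = executor.invoke(input=input_dict)
    # crewai expects a plain string, so only the "output" entry is returned;
    # returning the full dict is deferred until crewai supports it (see TODO).
    return result["output"]


print(execute_task(StubExecutor(), StubTask("draw an SVG circle", "", "an .svg file path")))
```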
4 changes: 1 addition & 3 deletions src/codeinterpreterapi/planners/planners.py
@@ -39,12 +39,10 @@ class CodeInterpreterPlanOutput(BaseModel):

class CustomPydanticOutputParser(PydanticOutputParser):
def parse(self, text) -> CodeInterpreterPlanList:
print("parse text=", text)
input_data = self.preprocess_input(text)
return super().parse(input_data)

def preprocess_input(self, input_data) -> str:
print("preprocess_input input_data=", input_data)
print("preprocess_input type input_data=", type(input_data))
if isinstance(input_data, AIMessage):
return input_data.content
@@ -71,7 +69,7 @@ def choose_planner(ci_params: CodeInterpreterParams) -> Union[Runnable, AgentEx
else:
prompt = create_planner_agent_prompt()
prompt = PromptUpdater.update_prompt(prompt, ci_params)
PromptUpdater.show_prompt(prompt)
# PromptUpdater.show_prompt(prompt)

# structured_llm
# structured_llm = ci_params.llm.bind_tools(tools=[CodeInterpreterPlanList])  # for some reason this produces an empty AgentPlan
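A hedged sketch of what the `CustomPydanticOutputParser` above guards against: the planner LLM can hand back an `AIMessage` instead of a plain string, so `preprocess_input` unwraps `.content` before the normal Pydantic parsing runs. The module paths follow the file names in this commit and the JSON payload is invented.

```python
from langchain_core.messages import AIMessage

from codeinterpreterapi.planners.planners import CustomPydanticOutputParser  # assumed import path
from codeinterpreterapi.schema import CodeInterpreterPlanList

parser = CustomPydanticOutputParser(pydantic_object=CodeInterpreterPlanList)

# An AIMessage, as the planner LLM may return it; parse() unwraps .content first.
message = AIMessage(
    content='{"reliability": 80, "agent_task_list": ['
    '{"agent_name": "main_function_create_agent", '
    '"task_description": "Write a main() that plots the chart.", '
    '"expected_output": "a python file"}]}'
)
plan_list = parser.parse(message)
print(plan_list.reliability, [p.agent_name for p in plan_list.agent_task_list])
```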
14 changes: 8 additions & 6 deletions src/codeinterpreterapi/planners/prompts.py
@@ -18,17 +18,19 @@

SYSTEM_MESSAGE_TEMPLATE_JA = '''
あなたは優秀なAIエージェントを管理するシニアエンジニアです。
次の明確な手続きを実施して、問題を理解し、問題を解決するための計画を立ててください
次の手順で問題を解決するための計画(CodeInterpreterPlanList)を作成してください
# 手順
手順1: 問題を理解する
手順2: 利用可能なAI agentのリスト(agent_info)を確認する
手順3: 問題解決に最適なAI agentがあるか判断する
手順4: CodeInterpreterPlanList を作成して計画を回答する
手順3: 問題解決に利用するべきAI agentをピックアップする
手順4: 最適な順番でAI agentを利用するようにinput/outputなどを検討する
手順5: CodeInterpreterPlanList として最終的な計画を出力する
利用可能なAI agentのリスト:
# 利用可能なAI agent
{agent_info}
制約条件:
# 制約条件
- ステップバイステップで精密に思考し回答する。
- 作業として何を求められているか正しく理解する。
- AI agentの機能を正確に理解してから回答する。
Expand All @@ -40,7 +42,7 @@
-- 何らかの理由で作業の実現が困難な場合
- 各ステップの思考と出力は日本語とする。
問題は以下に示します。注意深く問題を理解して回答してください。
# 問題
'''


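If it helps to see the revised template in use, a minimal sketch follows. The diff does not show how the project fills the template (plain `str.format` vs. a LangChain `PromptTemplate`), so the call below assumes `{agent_info}` is the only placeholder; the agent catalogue is invented.

```python
from codeinterpreterapi.planners.prompts import SYSTEM_MESSAGE_TEMPLATE_JA

# Illustrative agent list; in the real flow agent_info comes from ci_params.
agent_info = (
    "- main_function_create_agent: writes a runnable main()\n"
    "- plot_agent: renders charts from data"
)
system_message = SYSTEM_MESSAGE_TEMPLATE_JA.format(agent_info=agent_info)
print(system_message)
```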
27 changes: 17 additions & 10 deletions src/codeinterpreterapi/schema.py
@@ -1,5 +1,5 @@
import asyncio
from typing import Any, List
from typing import Any, List, Optional

from codeboxapi.schema import CodeBoxStatus
from langchain_core.messages import AIMessage, HumanMessage
@@ -113,8 +113,8 @@ class CodeInterpreterResponse(AIMessage):
code_log: list[tuple[str, str]] = []
"""

files: list[File] = []
code_log: list[tuple[str, str]] = []
files: Optional[list[File]] = []
code_log: Optional[dict[str, str]] = []
agent_name: str = ""

def show(self) -> None:
@@ -140,18 +140,25 @@ def __repr__(self) -> str:


class CodeInterpreterPlan(BaseModel):
'''Agent and Task definition. Plan and task is 1:1.'''
'''単一PlanのAgentとTaskの説明です。
PlanとAgentとTaskは常に1:1:1の関係です。
'''

agent_name: str = Field(
description="The agent name for task. This is primary key. Agent responsible for task execution. Represents entity performing task."
description="Agentの名前です。タスクの名前も、このagent_nameと常に同じになり、primary keyとして使われます。利用可能な文字は[a-Z_]です。task likeな名前にしてください。"
)
task_description: str = Field(
description="タスクの説明です。可能な範囲でpurpose, execution plan, input, outputの詳細を含めます。"
)
expected_output: str = Field(
description="タスクの最終的な出力形式を明確に定義します。例えばjson/csvというフォーマットや、カラム名やサイズ情報です。"
)
task_description: str = Field(description="Descriptive text detailing task's purpose and execution.")
expected_output: str = Field(description="Clear definition of expected task outcome.")


class CodeInterpreterPlanList(BaseModel):
'''Sequential plans for the task.'''
'''CodeInterpreterPlanの配列をもつ計画全体です。'''

agent_task_list: List[CodeInterpreterPlan] = Field(
description="The list of CodeInterpreterPlan. It means agent name and so on."
reliability: int = Field(
description="計画の信頼度[0-100]です。100が完全な計画を意味します。50未満だと不完全計画でオリジナルの問題を直接llmに渡した方が良い結果になります。"
)
agent_task_list: List[CodeInterpreterPlan] = Field(description="CodeInterpreterPlanの配列です。")
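Putting the schema changes together, a short construction sketch shows why the crew_agent test above now has to pass `reliability`. Field meanings follow the Japanese descriptions in this diff; the concrete values are illustrative.

```python
from codeinterpreterapi.schema import CodeInterpreterPlan, CodeInterpreterPlanList

plan = CodeInterpreterPlan(
    agent_name="main_function_create_agent",  # primary key, [a-Z_] only
    task_description="Create a main() that draws an SVG circle and saves it to a file.",
    expected_output="The path of the generated python file.",
)

# reliability is a new required field: 0-100, where values under 50 signal a
# plan too weak to be worth routing through the agents.
plan_list = CodeInterpreterPlanList(reliability=80, agent_task_list=[plan, plan])

if plan_list.reliability < 50:
    print("Plan judged unreliable; fall back to asking the LLM directly.")
else:
    print(f"Running {len(plan_list.agent_task_list)} planned task(s).")
```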