Skip to content

Commit

Permalink
fix: update about CodeInterpreterSession
Browse files Browse the repository at this point in the history
  • Loading branch information
nobu007 committed Aug 9, 2024
1 parent 94ebe19 commit 6fabd38
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 33 deletions.
4 changes: 1 addition & 3 deletions examples/chat_cli.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
from codeinterpreterapi import CodeInterpreterSession, settings

settings.MODEL = "gpt-4"
from codeinterpreterapi import CodeInterpreterSession

print("AI: Hello, I am the " "code interpreter agent.\n" "Ask me todo something and " "I will use python to do it!\n")

Expand Down
2 changes: 1 addition & 1 deletion examples/show_bitcoin_chart.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


def main() -> None:
with CodeInterpreterSession(local=True) as session:
with CodeInterpreterSession(is_local=True) as session:
currentdate = datetime.now().strftime("%Y-%m-%d")

response = session.generate_response(f"Plot the bitcoin chart of 2023 YTD (today is {currentdate})")
Expand Down
4 changes: 2 additions & 2 deletions src/codeinterpreterapi/brain/brain.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def batch(self, inputs: List[Output]) -> List[Output]:
return [self.run(input_item) for input_item in inputs]

async def ainvoke(self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any) -> Output:
raise NotImplementedError("Async not implemented yet")
return self.run(input, config)

async def abatch(
self,
Expand All @@ -156,7 +156,7 @@ async def abatch(
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> List[Output]:
raise NotImplementedError("Async not implemented yet")
return [self.run(input_item) for input_item in inputs]

def update_agent_score(self):
self.current_agent_score = CodeInterpreterBrain.AGENT_SCORE_MIN - 1 # temp: switch every time
Expand Down
16 changes: 12 additions & 4 deletions src/codeinterpreterapi/planners/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,23 @@
次の明確な手続きを実施して、問題を理解し、問題を解決するための計画を立ててください。
手順1: 問題を理解する
手順2: 利用可能なAI agentのリストを確認する
手順3: CodeInterpreterPlanList を作成して計画を回答する
手順2: 利用可能なAI agentのリスト(agent_info)を確認する
手順3: 問題解決に最適なAI agentがあるか判断する
手順4: CodeInterpreterPlanList を作成して計画を回答する
利用可能なAI agentのリスト:
{agent_info}
制約条件:
- 段階的に考え、各ステップで取るアクションを明確にすること。
- agent_infoに示されたagent_name以外のagentを利用しないこと。
- ステップバイステップで精密に思考し回答する。
- 作業として何を求められているか正しく理解する。
- AI agentの機能を正確に理解してから回答する。
- 各ステップの入力と出力を明確にする。
- agent_infoに示されたagent_name以外のagentを利用しない。
- 次の場合は計画を作成せずに長さ0のリストを返す。
-- 利用可能なagentが不足している
-- 計画を立てるまでもない簡単な問題の場合
-- 何らかの理由で作業の実現が困難な場合
- 各ステップの思考と出力は日本語とする。
問題は以下に示します。注意深く問題を理解して回答してください。
Expand Down
52 changes: 33 additions & 19 deletions src/codeinterpreterapi/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from uuid import UUID

from codeboxapi import CodeBox # type: ignore
from codeboxapi.schema import CodeBoxStatus # type: ignore
from langchain.callbacks.base import BaseCallbackHandler, Callbacks
from langchain_community.chat_message_histories.in_memory import ChatMessageHistory
from langchain_community.chat_message_histories.postgres import PostgresChatMessageHistory
Expand Down Expand Up @@ -139,29 +140,35 @@ def session_id(self) -> Optional[UUID]:

def start(self) -> SessionStatus:
print("start")
status = SessionStatus.from_codebox_status(self.ci_params.codebox.start())
codebox_status = CodeBoxStatus(status="unknown")
if self.ci_params.codebox:
codebox_status = self.ci_params.codebox.start()
self.ci_params.codebox.run(
f"!pip install -q {' '.join(settings.CUSTOM_PACKAGES)}",
)
self.brain.initialize()
self.ci_params.codebox.run(
f"!pip install -q {' '.join(settings.CUSTOM_PACKAGES)}",
)
return status
return SessionStatus.from_codebox_status(codebox_status)

async def astart(self) -> SessionStatus:
print("astart")
status = SessionStatus.from_codebox_status(await self.ci_params.codebox.astart())
codebox_status = CodeBoxStatus(status="unknown")
if self.ci_params.codebox:
codebox_status = self.ci_params.codebox.astart()
self.ci_params.codebox.arun(
f"!pip install -q {' '.join(settings.CUSTOM_PACKAGES)}",
)
self.brain.initialize()
await self.ci_params.codebox.arun(
f"!pip install -q {' '.join(settings.CUSTOM_PACKAGES)}",
)
return status
return SessionStatus.from_codebox_status(codebox_status)

def start_local(self) -> SessionStatus:
# TODO: delete it and use start()
print("start_local")
self.brain.initialize()
status = SessionStatus(status="started")
return status

async def astart_local(self) -> SessionStatus:
# TODO: delete it and use astart()
print("astart_local")
status = self.start_local()
self.brain.initialize()
Expand Down Expand Up @@ -260,8 +267,9 @@ def _output_handler(self, response: Any) -> CodeInterpreterResponse:
response = self._output_handler_post(final_response)
return response

async def _aoutput_handler(self, final_response: str) -> CodeInterpreterResponse:
async def _aoutput_handler(self, response: str) -> CodeInterpreterResponse:
"""Embed images in the response"""
final_response = self._output_handler_pre(response)
for file in self.output_files:
if str(file.name) in final_response:
# rm ![Any](file.name) from the response
Expand Down Expand Up @@ -335,13 +343,12 @@ async def agenerate_response(
user_request = UserRequest(content=user_msg, files=files)
try:
await self._ainput_handler(user_request)

agent_scratchpad = ""
input_message = {"input": user_request.content, "agent_scratchpad": agent_scratchpad}
# ======= ↓↓↓↓ LLM invoke ↓↓↓↓ #=======
response = await self.brain.ainvoke(input=user_request.content)
response = await self.brain.ainvoke(input=input_message)
# ======= ↑↑↑↑ LLM invoke ↑↑↑↑ #=======

output_str = self._output_handler_pre(response)
return await self._aoutput_handler(output_str)
return await self._aoutput_handler(response)
except Exception as e:
if self.verbose:
traceback.print_exc()
Expand Down Expand Up @@ -446,13 +453,19 @@ def log(self, msg: str) -> None:
print(msg)

def stop(self) -> SessionStatus:
return SessionStatus.from_codebox_status(self.ci_params.codebox.stop())
codebox_status = CodeBoxStatus(status="unknown")
if self.ci_params.codebox:
codebox_status = self.ci_params.codebox.stop()
return SessionStatus.from_codebox_status(codebox_status)

async def astop(self) -> SessionStatus:
return SessionStatus.from_codebox_status(await self.ci_params.codebox.astop())
codebox_status = CodeBoxStatus(status="unknown")
if self.ci_params.codebox:
codebox_status = await self.ci_params.codebox.astop()
return SessionStatus.from_codebox_status(codebox_status)

def __enter__(self) -> "CodeInterpreterSession":
if self.is_local:
if self.ci_params.is_local:
self.start_local()
else:
self.start()
Expand All @@ -477,3 +490,4 @@ async def __aexit__(
traceback: Optional[TracebackType],
) -> None:
await self.astop()
await self.astop()
25 changes: 21 additions & 4 deletions src/codeinterpreterapi/supervisors/supervisors.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def __init__(self, planner: Runnable, ci_params: CodeInterpreterParams):
self.planner = ci_params.planner_agent
self.ci_params = ci_params
self.supervisor_chain = None
self.supervisor_chain_no_agent = None
self.initialize()

def initialize(self) -> None:
Expand Down Expand Up @@ -64,6 +65,9 @@ class RouteSchema(BaseModel):
# if self.ci_params.runnable_config:
# self.supervisor_chain = self.supervisor_chain.with_config(self.ci_params.runnable_config)

# supervisor_chain_no_agent
self.supervisor_chain_no_agent = self.ci_params.llm

def get_executor(self) -> AgentExecutor:
# TODO: use own executor(not crewai)
# agent_executor
Expand All @@ -79,7 +83,17 @@ def get_executor(self) -> AgentExecutor:
def invoke(self, input: Input) -> Output:
result = self.planner.invoke(input)
print("type result=", type(result))
result = self.ci_params.crew_agent.run(input, result)
if isinstance(result, CodeInterpreterPlanList):
plan_list: CodeInterpreterPlanList = result
if len(plan_list.agent_task_list) > 0:
print("supervisor.invoke use crew_agent plan_list=", plan_list)
result = self.ci_params.crew_agent.run(input, result)
else:
print("supervisor.invoke no_agent plan_list=", plan_list)
result = self.supervisor_chain_no_agent.invoke(input)
else:
print("supervisor.invoke no_agent no plan_list")
result = self.supervisor_chain_no_agent.invoke(input)
return result

def execute_plan(self, plan_list: CodeInterpreterPlanList) -> Dict[str, Any]:
Expand Down Expand Up @@ -119,16 +133,19 @@ def execute_plan(self, plan_list: CodeInterpreterPlanList) -> Dict[str, Any]:


def test():
use_simple_prompt = True
if use_simple_prompt:
test_prompt = TestPrompt.python_input_str
else:
test_prompt = TestPrompt.svg_input_str
llm, llm_tools, runnable_config = prepare_test_llm()
ci_params = CodeInterpreterParams.get_test_params(llm=llm, llm_tools=llm_tools, runnable_config=runnable_config)
_ = CodeInterpreterAgent.choose_agent_executors(ci_params=ci_params)
crew_agent = CodeInterpreterCrew(ci_params=ci_params)
ci_params.crew_agent = crew_agent
planner = CodeInterpreterPlanner.choose_planner(ci_params=ci_params)
supervisor = CodeInterpreterSupervisor(planner=planner, ci_params=ci_params)
result = supervisor.invoke(
{"input": TestPrompt.svg_input_str, "agent_scratchpad": "", "messages": [TestPrompt.svg_input_str]}
)
result = supervisor.invoke({"input": test_prompt, "agent_scratchpad": "", "messages": [test_prompt]})
print("result=", result)


Expand Down

0 comments on commit 6fabd38

Please sign in to comment.