From 7ac48daee19e80de1384146e406f27b73d3309ee Mon Sep 17 00:00:00 2001 From: jinno Date: Thu, 30 May 2024 03:56:04 +0900 Subject: [PATCH] fix: add CodeInterpreterParams --- src/codeinterpreterapi/__init__.py | 2 + src/codeinterpreterapi/agents/agents.py | 28 ++++---- .../agents/plan_and_execute/agent_executor.py | 29 +++----- src/codeinterpreterapi/planners/planners.py | 22 +++--- src/codeinterpreterapi/session.py | 68 ++++--------------- .../supervisors/supervisors.py | 10 ++- src/codeinterpreterapi/thoughts/thoughts.py | 17 +++-- src/codeinterpreterapi/tools/python.py | 35 +++++----- src/codeinterpreterapi/tools/tools.py | 11 ++- 9 files changed, 88 insertions(+), 134 deletions(-) diff --git a/src/codeinterpreterapi/__init__.py b/src/codeinterpreterapi/__init__.py index 6c6061c3..b4b3e5ee 100644 --- a/src/codeinterpreterapi/__init__.py +++ b/src/codeinterpreterapi/__init__.py @@ -1,3 +1,4 @@ +from codeinterpreterapi.brain.params import CodeInterpreterParams from codeinterpreterapi.config import settings from codeinterpreterapi.schema import File from codeinterpreterapi.session import CodeInterpreterSession @@ -5,6 +6,7 @@ from . import _patch_parser # noqa __all__ = [ + "CodeInterpreterParams", "CodeInterpreterSession", "File", "settings", diff --git a/src/codeinterpreterapi/agents/agents.py b/src/codeinterpreterapi/agents/agents.py index 0f0d0da1..831145c3 100644 --- a/src/codeinterpreterapi/agents/agents.py +++ b/src/codeinterpreterapi/agents/agents.py @@ -14,40 +14,40 @@ from langchain_openai import AzureChatOpenAI, ChatOpenAI from codeinterpreterapi.agents.plan_and_execute.agent_executor import load_agent_executor +from codeinterpreterapi.brain.params import CodeInterpreterParams from codeinterpreterapi.config import settings class CodeInterpreterAgent: @staticmethod - def create_agent_executor_lcel(llm, tools, verbose=False, chat_memory=None, callbacks=None, is_ja=True) -> Runnable: + def create_agent_executor_lcel(ci_params: CodeInterpreterParams) -> Runnable: # prompt prompt = hub.pull("hwchase17/openai-functions-agent") # agent - agent = create_tool_calling_agent(llm, tools, prompt) + agent = create_tool_calling_agent(ci_params.llm, ci_params.tools, prompt) # agent_executor agent_executor = AgentExecutor( agent=agent, - tools=tools, - verbose=verbose, + tools=ci_params.tools, + verbose=ci_params.verbose, # memory=ConversationBufferMemory( # memory_key="chat_history", # return_messages=True, # chat_memory=chat_memory, # ), - callbacks=callbacks, + callbacks=ci_params.callbacks, ) print("agent_executor.input_keys", agent_executor.input_keys) print("agent_executor.output_keys", agent_executor.output_keys) return agent_executor @staticmethod - def choose_single_chat_agent( - llm, - tools, - is_ja, - ) -> BaseSingleActionAgent: + def choose_single_chat_agent(ci_params: CodeInterpreterParams) -> BaseSingleActionAgent: + llm = ci_params.llm + tools = ci_params.tools + is_ja = ci_params.is_ja system_message = settings.SYSTEM_MESSAGE if is_ja else settings.SYSTEM_MESSAGE_JA if isinstance(llm, ChatOpenAI) or isinstance(llm, AzureChatOpenAI): print("choose_agent OpenAIFunctionsAgent") @@ -80,16 +80,16 @@ def choose_single_chat_agent( ) @staticmethod - def create_agent_and_executor(llm, tools, verbose, chat_memory, callbacks, is_ja=True) -> AgentExecutor: + def create_agent_and_executor(ci_params: CodeInterpreterParams) -> AgentExecutor: # agent - agent = CodeInterpreterAgent.choose_single_chat_agent(llm, tools, is_ja=is_ja) + agent = CodeInterpreterAgent.choose_single_chat_agent(ci_params) 
print("create_agent_and_executor agent=", str(type(agent))) return agent @staticmethod - def create_agent_and_executor_experimental(llm, tools, verbose, is_ja) -> AgentExecutor: + def create_agent_and_executor_experimental(ci_params: CodeInterpreterParams) -> AgentExecutor: # agent_executor - agent_executor = load_agent_executor(llm, tools, verbose=verbose, is_ja=is_ja) + agent_executor = load_agent_executor(ci_params) print("create_agent_and_executor_experimental") return agent_executor diff --git a/src/codeinterpreterapi/agents/plan_and_execute/agent_executor.py b/src/codeinterpreterapi/agents/plan_and_execute/agent_executor.py index 8257beb5..23d1d220 100644 --- a/src/codeinterpreterapi/agents/plan_and_execute/agent_executor.py +++ b/src/codeinterpreterapi/agents/plan_and_execute/agent_executor.py @@ -1,24 +1,12 @@ -from typing import List, Optional - -from langchain.agents.agent import AgentExecutor, AgentOutputParser +from langchain.agents.agent import AgentExecutor from langchain.agents.structured_chat.base import create_structured_chat_agent -from langchain.tools import BaseTool -from langchain_core.callbacks import BaseCallbackManager -from langchain_core.language_models import BaseLanguageModel from codeinterpreterapi.agents.plan_and_execute.prompts import create_structured_chat_agent_prompt +from codeinterpreterapi.brain.params import CodeInterpreterParams from codeinterpreterapi.llm.llm import prepare_test_llm -def load_agent_executor( - llm: BaseLanguageModel, - tools: List[BaseTool], - callback_manager: Optional[BaseCallbackManager] = None, - output_parser: Optional[AgentOutputParser] = None, - verbose: bool = False, - include_task_in_prompt: bool = False, - is_ja: str = True, -) -> AgentExecutor: +def load_agent_executor(ci_params: CodeInterpreterParams) -> AgentExecutor: """ Load an agent executor. 
@@ -33,11 +21,11 @@ def load_agent_executor(
     """
     input_variables = ["previous_steps", "current_step", "agent_scratchpad", "tools", "tool_names"]
     print("input_variables=", input_variables)
-    prompt = create_structured_chat_agent_prompt(is_ja)
+    prompt = create_structured_chat_agent_prompt(ci_params.is_ja)
     print("prompt=", prompt.get_prompts())
     agent = create_structured_chat_agent(
-        llm=llm,
-        tools=tools,
+        llm=ci_params.llm,
+        tools=ci_params.tools,
         # callback_manager=callback_manager,
         # output_parser=output_parser,
         # prefix=tools_prefix,
@@ -48,13 +36,14 @@ def load_agent_executor(
         # memory_prompts = memory_prompts,
     )
 
-    agent_executor = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=verbose)
+    agent_executor = AgentExecutor.from_agent_and_tools(agent=agent, tools=ci_params.tools, verbose=ci_params.verbose)
     return agent_executor
 
 
 def test():
     llm = prepare_test_llm()
-    agent_executor = load_agent_executor(llm=llm, tools=[])
+    ci_params = CodeInterpreterParams.get_test_params(llm=llm)
+    agent_executor = load_agent_executor(ci_params)
     test_input = "pythonで円周率を表示するプログラムを実行してください。"
     agent_executor_output = agent_executor.invoke({"input": test_input})
     print("agent_executor_output=", agent_executor_output)
diff --git a/src/codeinterpreterapi/planners/planners.py b/src/codeinterpreterapi/planners/planners.py
index 232ee246..5e819053 100644
--- a/src/codeinterpreterapi/planners/planners.py
+++ b/src/codeinterpreterapi/planners/planners.py
@@ -1,15 +1,13 @@
-from typing import List
-
 from langchain import hub
 from langchain.agents import create_react_agent
-from langchain.base_language import BaseLanguageModel
-from langchain.tools import BaseTool
 from langchain_core.runnables import Runnable
 
+from codeinterpreterapi.brain.params import CodeInterpreterParams
+
 
 class CodeInterpreterPlanner:
     @staticmethod
-    def choose_planner(llm: BaseLanguageModel, tools: List[BaseTool], is_ja: bool) -> Runnable:
+    def choose_planner(ci_params: CodeInterpreterParams) -> Runnable:
         """
         Load a chat planner.
 
@@ -20,11 +18,19 @@ def choose_planner(llm: BaseLanguageModel, tools: List[BaseTool], is_ja: bool) -
 
         Returns:
             LLMPlanner
-        """
+
+        Input:
+            tools: descriptions of the available tools
+            tool_names: comma-separated names of the available tools
+            input: the user request to plan for
+            agent_scratchpad: intermediate agent thoughts and observations
+        Output:
+            content: free-form text (str)
+ """ prompt_name = "nobu/simple_react" - if is_ja: + if ci_params.is_ja: prompt_name += "_ja" prompt = hub.pull(prompt_name) - planner_agent = create_react_agent(llm, tools, prompt) + planner_agent = create_react_agent(ci_params.llm, ci_params.tools, prompt) return planner_agent diff --git a/src/codeinterpreterapi/session.py b/src/codeinterpreterapi/session.py index 32392f1f..3bda7449 100644 --- a/src/codeinterpreterapi/session.py +++ b/src/codeinterpreterapi/session.py @@ -6,27 +6,22 @@ from codeboxapi import CodeBox # type: ignore from gui_agent_loop_core.schema.schema import GuiAgentInterpreterChatResponseStr -from langchain.agents import AgentExecutor from langchain.callbacks.base import Callbacks from langchain_community.chat_message_histories.in_memory import ChatMessageHistory from langchain_community.chat_message_histories.postgres import PostgresChatMessageHistory from langchain_community.chat_message_histories.redis import RedisChatMessageHistory from langchain_core.chat_history import BaseChatMessageHistory from langchain_core.language_models import BaseLanguageModel -from langchain_core.runnables import Runnable from langchain_core.tools import BaseTool from codeinterpreterapi.agents.agents import CodeInterpreterAgent +from codeinterpreterapi.brain.brain import CodeInterpreterBrain +from codeinterpreterapi.brain.params import CodeInterpreterParams from codeinterpreterapi.chains import aremove_download_link, remove_download_link from codeinterpreterapi.chat_history import CodeBoxChatMessageHistory from codeinterpreterapi.config import settings from codeinterpreterapi.llm.llm import CodeInterpreterLlm from codeinterpreterapi.schema import CodeInterpreterResponse, File, SessionStatus, UserRequest -from codeinterpreterapi.thoughts.thoughts import CodeInterpreterToT - -from .planners.planners import CodeInterpreterPlanner -from .supervisors.supervisors import CodeInterpreterSupervisor -from .tools.tools import CodeInterpreterTools def _handle_deprecated_kwargs(kwargs: dict) -> None: @@ -49,19 +44,19 @@ def __init__( **kwargs: Any, ) -> None: _handle_deprecated_kwargs(kwargs) - self.is_local = is_local - self.is_ja = is_ja - self.codebox = CodeBox(requirements=settings.CUSTOM_PACKAGES) self.verbose = kwargs.get("verbose", settings.DEBUG) - self.llm: BaseLanguageModel = llm or CodeInterpreterLlm.get_llm() - self.tools: list[BaseTool] = CodeInterpreterTools(additional_tools, self.llm).get_all_tools() + self.llm: BaseLanguageModel = llm or CodeInterpreterLlm.get_llm() # TODO: remove from session + ci_params = CodeInterpreterParams( + llm=self.llm, + tools=additional_tools, + callbacks=callbacks, + verbose=self.verbose, + is_local=is_local, + is_ja=is_ja, + ) + self.brain = CodeInterpreterBrain(ci_params) self.log("self.llm=" + str(self.llm)) - self.callbacks = callbacks - self.agent_executor: Optional[Runnable] = None - self.llm_planner: Optional[Runnable] = None - self.supervisor: Optional[AgentExecutor] = None - self.thought: Optional[Runnable] = None self.input_files: list[File] = [] self.output_files: list[File] = [] self.code_log: list[tuple[str, str]] = [] @@ -77,45 +72,6 @@ def from_id(cls, session_id: UUID, **kwargs: Any) -> "CodeInterpreterSession": def session_id(self) -> Optional[UUID]: return self.codebox.session_id - def initialize(self): - self.initialize_agent_executor() - self.initialize_llm_planner() - self.initialize_supervisor() - self.initialize_thought() - - def initialize_agent_executor(self): - is_experimental = True - if is_experimental: - self.agent_executor = 
CodeInterpreterAgent.create_agent_and_executor_experimental( - llm=self.llm, - tools=self.tools, - verbose=self.verbose, - is_ja=self.is_ja, - ) - else: - self.agent_executor = CodeInterpreterAgent.create_agent_executor_lcel( - llm=self.llm, - tools=self.tools, - verbose=self.verbose, - is_ja=self.is_ja, - chat_memory=self._history_backend(), - callbacks=self.callbacks, - ) - - def initialize_llm_planner(self): - self.llm_planner = CodeInterpreterPlanner.choose_planner(llm=self.llm, tools=self.tools, is_ja=self.is_ja) - - def initialize_supervisor(self): - self.supervisor = CodeInterpreterSupervisor.choose_supervisor( - planner=self.llm_planner, - executor=self.agent_executor, - tools=self.tools, - verbose=self.verbose, - ) - - def initialize_thought(self): - self.thought = CodeInterpreterToT.get_runnable_tot_chain(llm=self.llm, is_ja=self.is_ja, is_simple=False) - def start(self) -> SessionStatus: print("start") status = SessionStatus.from_codebox_status(self.codebox.start()) diff --git a/src/codeinterpreterapi/supervisors/supervisors.py b/src/codeinterpreterapi/supervisors/supervisors.py index c2935a66..51dacf88 100644 --- a/src/codeinterpreterapi/supervisors/supervisors.py +++ b/src/codeinterpreterapi/supervisors/supervisors.py @@ -1,25 +1,23 @@ import getpass import os import platform -from typing import List from langchain.agents import AgentExecutor -from langchain.tools import BaseTool from langchain_core.runnables import Runnable +from codeinterpreterapi.brain.params import CodeInterpreterParams + class CodeInterpreterSupervisor: @staticmethod - def choose_supervisor( - planner: Runnable, executor: Runnable, tools: List[BaseTool], verbose: bool = False - ) -> AgentExecutor: + def choose_supervisor(planner: Runnable, ci_params: CodeInterpreterParams) -> AgentExecutor: # prompt username = getpass.getuser() current_working_directory = os.getcwd() operating_system = platform.system() info = f"[User Info]\nName: {username}\nCWD: {current_working_directory}\nOS: {operating_system}" print("choose_supervisor info=", info) - agent_executor = AgentExecutor(agent=planner, tools=tools, verbose=verbose) + agent_executor = AgentExecutor(agent=planner, tools=ci_params.tools, verbose=ci_params.verbose) # prompt = hub.pull("nobu/chat_planner") # agent = create_react_agent(llm, [], prompt) # return agent diff --git a/src/codeinterpreterapi/thoughts/thoughts.py b/src/codeinterpreterapi/thoughts/thoughts.py index fccbd1e7..72a0208f 100644 --- a/src/codeinterpreterapi/thoughts/thoughts.py +++ b/src/codeinterpreterapi/thoughts/thoughts.py @@ -1,11 +1,14 @@ from typing import Any, Dict, List, Optional, Union -from codeinterpreterapi.thoughts.base import MyToTChain -from codeinterpreterapi.thoughts.checker import create_tot_chain_from_llm from langchain_core.runnables import RunnableSerializable from langchain_core.runnables.config import RunnableConfig from langchain_core.runnables.utils import Input, Output +from codeinterpreterapi.brain.params import CodeInterpreterParams +from codeinterpreterapi.llm.llm import prepare_test_llm +from codeinterpreterapi.thoughts.base import MyToTChain +from codeinterpreterapi.thoughts.checker import create_tot_chain_from_llm + class CodeInterpreterToT(RunnableSerializable): tot_chain: MyToTChain = None @@ -14,7 +17,7 @@ def __init__(self, llm=None, is_ja=True, is_simple=False): super().__init__() self.tot_chain = create_tot_chain_from_llm(llm=llm, is_ja=is_ja, is_simple=is_simple) - def run(self, input: Input): + def run(self, input: Input) -> Output: 
problem_description = input["input"] return self.tot_chain.run(problem_description=problem_description) @@ -41,14 +44,16 @@ async def abatch( raise NotImplementedError("Async not implemented yet") @classmethod - def get_runnable_tot_chain(cls, llm=None, is_ja=True, is_simple=False): + def get_runnable_tot_chain(cls, ci_params: CodeInterpreterParams, is_simple: bool = False): # ToTChainのインスタンスを作成 - tot_chain = cls(llm=llm, is_ja=is_ja, is_simple=is_simple) + tot_chain = cls(llm=ci_params.llm, is_ja=ci_params.is_ja, is_simple=is_simple) return tot_chain def test(): - tot_chain = CodeInterpreterToT.get_runnable_tot_chain(is_simple=True) + llm = prepare_test_llm() + ci_params = CodeInterpreterParams.get_test_params(llm=llm) + tot_chain = CodeInterpreterToT.get_runnable_tot_chain(ci_params=ci_params) tot_chain.invoke({"input": sample2}) diff --git a/src/codeinterpreterapi/tools/python.py b/src/codeinterpreterapi/tools/python.py index 3c0dae4e..9306d843 100644 --- a/src/codeinterpreterapi/tools/python.py +++ b/src/codeinterpreterapi/tools/python.py @@ -6,27 +6,25 @@ from io import BytesIO from uuid import uuid4 -from codeboxapi import CodeBox # type: ignore from codeboxapi.schema import CodeBoxOutput # type: ignore from langchain_core.tools import StructuredTool +from codeinterpreterapi.brain.params import CodeInterpreterParams from codeinterpreterapi.config import settings from codeinterpreterapi.llm.llm import prepare_test_llm from codeinterpreterapi.schema import CodeInput, File class PythonTools: - def __init__(self, llm, codebox: CodeBox = None, verbose: bool = False): - self.codebox = codebox - self.llm = llm - self.verbose = verbose + def __init__(self, ci_params: CodeInterpreterParams): + self.ci_params = ci_params self.code_log = [] self.input_files = [] self.output_files = [] @classmethod - def get_tools_python(cls, llm) -> None: - tools_instance = cls(llm=llm) + def get_tools_python(cls, ci_params: CodeInterpreterParams) -> None: + tools_instance = cls(ci_params=ci_params) tools = [ StructuredTool( name="python", @@ -92,9 +90,9 @@ async def _arun_handler_local(self, code: str): def _run_handler(self, code: str) -> str: """Run code in container and send the output to the user""" self.show_code(code) - if self.codebox is None: + if self.ci_params.codebox is None: return self._run_handler_local(code) - output: CodeBoxOutput = self.codebox.run(code) + output: CodeBoxOutput = self.ci_params.codebox.run(code) self.code_log.append((code, output.content)) if not isinstance(output.content, str): @@ -113,12 +111,12 @@ def _run_handler(self, code: str) -> str: r"ModuleNotFoundError: No module named '(.*)'", output.content, ): - self.codebox.install(package.group(1)) + self.ci_params.codebox.install(package.group(1)) return f"{package.group(1)} was missing but got installed now. Please try again." 
else: # TODO: pre-analyze error to optimize next code generation pass - if self.verbose: + if self.ci_params.verbose: print("Error:", output.content) # elif modifications := get_file_modifications(code, self.llm): @@ -137,9 +135,9 @@ def _run_handler(self, code: str) -> str: async def _arun_handler(self, code: str) -> str: """Run code in container and send the output to the user""" await self.ashow_code(code) - if self.codebox is None: + if self.ci_params.codebox is None: return self._arun_handler_local(code) - output: CodeBoxOutput = await self.codebox.arun(code) + output: CodeBoxOutput = await self.ci_params.codebox.arun(code) self.code_log.append((code, output.content)) if not isinstance(output.content, str): @@ -158,12 +156,12 @@ async def _arun_handler(self, code: str) -> str: r"ModuleNotFoundError: No module named '(.*)'", output.content, ): - await self.codebox.ainstall(package.group(1)) + await self.ci_params.codebox.ainstall(package.group(1)) return f"{package.group(1)} was missing but got installed now. Please try again." else: # TODO: pre-analyze error to optimize next code generation pass - if self.verbose: + if self.ci_params.verbose: print("Error:", output.content) # elif modifications := await aget_file_modifications(code): @@ -180,19 +178,20 @@ async def _arun_handler(self, code: str) -> str: return output.content def show_code(self, code: str) -> None: - if self.verbose: + if self.ci_params.verbose: print(code) async def ashow_code(self, code: str) -> None: """Callback function to show code to the user.""" - if self.verbose: + if self.ci_params.verbose: print(code) def test(): settings.WORK_DIR = "/tmp" llm = prepare_test_llm() - tools_instance = PythonTools(llm=llm) + ci_params = CodeInterpreterParams.get_test_params(llm=llm) + tools_instance = PythonTools(ci_params=ci_params) test_code = "print('test output')" result = tools_instance._run_handler(test_code) print("result=", result) diff --git a/src/codeinterpreterapi/tools/tools.py b/src/codeinterpreterapi/tools/tools.py index f61d591b..fcbeae68 100644 --- a/src/codeinterpreterapi/tools/tools.py +++ b/src/codeinterpreterapi/tools/tools.py @@ -1,21 +1,20 @@ from langchain_community.tools.shell.tool import BaseTool, ShellTool from langchain_community.tools.tavily_search import TavilySearchResults -from langchain_core.language_models import BaseLanguageModel +from codeinterpreterapi.brain.params import CodeInterpreterParams from codeinterpreterapi.tools.python import PythonTools class CodeInterpreterTools: def __init__( self, - additional_tools: list[BaseTool], - llm: BaseLanguageModel, + ci_params: CodeInterpreterParams, ): - self._additional_tools = additional_tools - self._llm = llm + self._ci_params = ci_params + self._additional_tools = ci_params.tools def get_all_tools(self) -> list[BaseTool]: - self._additional_tools.extend(PythonTools.get_tools_python(self._llm)) + self._additional_tools.extend(PythonTools.get_tools_python(self._ci_params)) self.add_tools_shell() self.add_tools_web_search() return self._additional_tools
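
Note on CodeInterpreterParams: codeinterpreterapi/brain/params.py (and brain.py), which define
CodeInterpreterParams and CodeInterpreterBrain, are referenced throughout this patch but are not
included in the diff. The sketch below is a hypothetical reconstruction of the parameter object,
inferred only from how it is constructed in session.py and read elsewhere in the patch
(ci_params.llm, ci_params.tools, ci_params.callbacks, ci_params.verbose, ci_params.is_local,
ci_params.is_ja, ci_params.codebox, and the get_test_params() helper used by the test functions).
The actual definition may differ in field types, defaults, and base class (it may well be a
pydantic model rather than a dataclass).

    # Hypothetical sketch only -- not the actual contents of brain/params.py.
    from dataclasses import dataclass, field
    from typing import List, Optional

    from codeboxapi import CodeBox  # type: ignore
    from langchain.callbacks.base import Callbacks
    from langchain_core.language_models import BaseLanguageModel
    from langchain_core.tools import BaseTool


    @dataclass
    class CodeInterpreterParams:
        llm: Optional[BaseLanguageModel] = None
        tools: List[BaseTool] = field(default_factory=list)
        callbacks: Callbacks = None
        verbose: bool = False
        is_local: bool = True
        is_ja: bool = True
        codebox: Optional[CodeBox] = None

        @classmethod
        def get_test_params(cls, llm: BaseLanguageModel) -> "CodeInterpreterParams":
            # Minimal parameter set matching the test() helpers in this patch.
            return cls(llm=llm, tools=[], verbose=True)

With such a container in place, the refactored entry points share one calling convention, as the
test() functions in this patch already show:

    llm = prepare_test_llm()
    ci_params = CodeInterpreterParams.get_test_params(llm=llm)
    agent_executor = load_agent_executor(ci_params)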