
Commit

fix: add CodeInterpreterParams
nobu007 committed May 29, 2024
1 parent f0a9b74 commit 7ac48da
Showing 9 changed files with 88 additions and 134 deletions.
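
The class this commit is named for, CodeInterpreterParams, lives in src/codeinterpreterapi/brain/params.py, one of the nine changed files whose diff is not expanded in this view. Purely as a reading aid, here is a minimal sketch of what the class plausibly looks like, inferred from how the diffs below construct it (llm, tools, callbacks, verbose, is_local, is_ja) and from the get_test_params helper the tests call; the base class, the defaults, and the body of get_test_params are assumptions, not the actual implementation.

from dataclasses import dataclass, field
from typing import List, Optional

from langchain.callbacks.base import Callbacks
from langchain_core.language_models import BaseLanguageModel
from langchain_core.tools import BaseTool


@dataclass
class CodeInterpreterParams:
    """Shared parameter bundle passed to agents, planners, supervisors, and ToT chains (sketch)."""

    llm: Optional[BaseLanguageModel] = None
    tools: List[BaseTool] = field(default_factory=list)
    callbacks: Callbacks = None
    verbose: bool = False
    is_local: bool = True
    is_ja: bool = True

    @classmethod
    def get_test_params(cls, llm: BaseLanguageModel) -> "CodeInterpreterParams":
        # Hypothetical helper mirroring the test() functions in the diffs below.
        return cls(llm=llm, tools=[], verbose=True)
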
2 changes: 2 additions & 0 deletions src/codeinterpreterapi/__init__.py
@@ -1,10 +1,12 @@
from codeinterpreterapi.brain.params import CodeInterpreterParams
from codeinterpreterapi.config import settings
from codeinterpreterapi.schema import File
from codeinterpreterapi.session import CodeInterpreterSession

from . import _patch_parser # noqa

__all__ = [
"CodeInterpreterParams",
"CodeInterpreterSession",
"File",
"settings",
28 changes: 14 additions & 14 deletions src/codeinterpreterapi/agents/agents.py
@@ -14,40 +14,40 @@
from langchain_openai import AzureChatOpenAI, ChatOpenAI

from codeinterpreterapi.agents.plan_and_execute.agent_executor import load_agent_executor
from codeinterpreterapi.brain.params import CodeInterpreterParams
from codeinterpreterapi.config import settings


class CodeInterpreterAgent:
@staticmethod
def create_agent_executor_lcel(llm, tools, verbose=False, chat_memory=None, callbacks=None, is_ja=True) -> Runnable:
def create_agent_executor_lcel(ci_params: CodeInterpreterParams) -> Runnable:
# prompt
prompt = hub.pull("hwchase17/openai-functions-agent")

# agent
agent = create_tool_calling_agent(llm, tools, prompt)
agent = create_tool_calling_agent(ci_params.llm, ci_params.tools, prompt)

# agent_executor
agent_executor = AgentExecutor(
agent=agent,
tools=tools,
verbose=verbose,
tools=ci_params.tools,
verbose=ci_params.verbose,
# memory=ConversationBufferMemory(
# memory_key="chat_history",
# return_messages=True,
# chat_memory=chat_memory,
# ),
callbacks=callbacks,
callbacks=ci_params.callbacks,
)
print("agent_executor.input_keys", agent_executor.input_keys)
print("agent_executor.output_keys", agent_executor.output_keys)
return agent_executor

@staticmethod
def choose_single_chat_agent(
llm,
tools,
is_ja,
) -> BaseSingleActionAgent:
def choose_single_chat_agent(ci_params: CodeInterpreterParams) -> BaseSingleActionAgent:
llm = ci_params.llm
tools = ci_params.tools
is_ja = ci_params.is_ja
system_message = settings.SYSTEM_MESSAGE if is_ja else settings.SYSTEM_MESSAGE_JA
if isinstance(llm, ChatOpenAI) or isinstance(llm, AzureChatOpenAI):
print("choose_agent OpenAIFunctionsAgent")
@@ -80,16 +80,16 @@ def choose_single_chat_agent(
)

@staticmethod
def create_agent_and_executor(llm, tools, verbose, chat_memory, callbacks, is_ja=True) -> AgentExecutor:
def create_agent_and_executor(ci_params: CodeInterpreterParams) -> AgentExecutor:
# agent
agent = CodeInterpreterAgent.choose_single_chat_agent(llm, tools, is_ja=is_ja)
agent = CodeInterpreterAgent.choose_single_chat_agent(ci_params)
print("create_agent_and_executor agent=", str(type(agent)))
return agent

@staticmethod
def create_agent_and_executor_experimental(llm, tools, verbose, is_ja) -> AgentExecutor:
def create_agent_and_executor_experimental(ci_params: CodeInterpreterParams) -> AgentExecutor:
# agent_executor
agent_executor = load_agent_executor(llm, tools, verbose=verbose, is_ja=is_ja)
agent_executor = load_agent_executor(ci_params)
print("create_agent_and_executor_experimental")

return agent_executor
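
For orientation, this is roughly how the refactored factory is now driven. Every call below appears in this commit; the test LLM helper and the input string are just placeholders, and the sketch assumes the CodeInterpreterParams fields listed above.

from codeinterpreterapi.agents.agents import CodeInterpreterAgent
from codeinterpreterapi.brain.params import CodeInterpreterParams
from codeinterpreterapi.llm.llm import prepare_test_llm

llm = prepare_test_llm()
ci_params = CodeInterpreterParams.get_test_params(llm=llm)

# Previously: create_agent_executor_lcel(llm, tools, verbose=..., chat_memory=..., callbacks=...)
# Now every factory takes the single params object.
agent_executor = CodeInterpreterAgent.create_agent_executor_lcel(ci_params)
result = agent_executor.invoke({"input": "print('hello world')"})
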
29 changes: 9 additions & 20 deletions src/codeinterpreterapi/agents/plan_and_execute/agent_executor.py
@@ -1,24 +1,12 @@
from typing import List, Optional

from langchain.agents.agent import AgentExecutor, AgentOutputParser
from langchain.agents.agent import AgentExecutor
from langchain.agents.structured_chat.base import create_structured_chat_agent
from langchain.tools import BaseTool
from langchain_core.callbacks import BaseCallbackManager
from langchain_core.language_models import BaseLanguageModel

from codeinterpreterapi.agents.plan_and_execute.prompts import create_structured_chat_agent_prompt
from codeinterpreterapi.brain.params import CodeInterpreterParams
from codeinterpreterapi.llm.llm import prepare_test_llm


def load_agent_executor(
llm: BaseLanguageModel,
tools: List[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
output_parser: Optional[AgentOutputParser] = None,
verbose: bool = False,
include_task_in_prompt: bool = False,
is_ja: str = True,
) -> AgentExecutor:
def load_agent_executor(ci_params: CodeInterpreterParams) -> AgentExecutor:
"""
Load an agent executor.
@@ -33,11 +21,11 @@ def load_agent_executor(
"""
input_variables = ["previous_steps", "current_step", "agent_scratchpad", "tools", "tool_names"]
print("input_variables=", input_variables)
prompt = create_structured_chat_agent_prompt(is_ja)
prompt = create_structured_chat_agent_prompt(ci_params.is_ja)
print("prompt=", prompt.get_prompts())
agent = create_structured_chat_agent(
llm=llm,
tools=tools,
llm=ci_params.llm,
tools=ci_params.tools,
# callback_manager=callback_manager,
# output_parser=output_parser,
# prefix=tools_prefix,
@@ -48,13 +36,14 @@
# memory_prompts = memory_prompts,
)

agent_executor = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=verbose)
agent_executor = AgentExecutor.from_agent_and_tools(agent=agent, tools=ci_params.tools, verbose=ci_params.verbose)
return agent_executor


def test():
llm = prepare_test_llm()
agent_executor = load_agent_executor(llm=llm, tools=[])
ci_params = CodeInterpreterParams.get_test_params(llm=llm)
agent_executor = load_agent_executor(ci_params)
test_input = "pythonで円周率を表示するプログラムを実行してください。"
agent_executor_output = agent_executor.invoke({"input": test_input})
print("agent_executor_output=", agent_executor_output)
22 changes: 14 additions & 8 deletions src/codeinterpreterapi/planners/planners.py
@@ -1,15 +1,13 @@
from typing import List

from langchain import hub
from langchain.agents import create_react_agent
from langchain.base_language import BaseLanguageModel
from langchain.tools import BaseTool
from langchain_core.runnables import Runnable

from codeinterpreterapi.brain.params import CodeInterpreterParams


class CodeInterpreterPlanner:
@staticmethod
def choose_planner(llm: BaseLanguageModel, tools: List[BaseTool], is_ja: bool) -> Runnable:
def choose_planner(ci_params: CodeInterpreterParams) -> Runnable:
"""
Load a chat planner.
@@ -20,11 +18,19 @@ def choose_planner(llm: BaseLanguageModel, tools: List[BaseTool], is_ja: bool) -
Returns:
LLMPlanner
"""
<prompt: simple_react>
Input
tools:
tool_names:
input:
agent_scratchpad:
Output
content: Free text in str.
"""
prompt_name = "nobu/simple_react"
if is_ja:
if ci_params.is_ja:
prompt_name += "_ja"
prompt = hub.pull(prompt_name)
planner_agent = create_react_agent(llm, tools, prompt)
planner_agent = create_react_agent(ci_params.llm, ci_params.tools, prompt)
return planner_agent
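
The planner factory follows the same pattern; a minimal usage sketch under the same assumptions. The returned object is the create_react_agent runnable built on the "nobu/simple_react" hub prompt, or its "_ja" variant when is_ja is set.

from codeinterpreterapi.brain.params import CodeInterpreterParams
from codeinterpreterapi.llm.llm import prepare_test_llm
from codeinterpreterapi.planners.planners import CodeInterpreterPlanner

llm = prepare_test_llm()
ci_params = CodeInterpreterParams.get_test_params(llm=llm)
ci_params.is_ja = False  # selects "nobu/simple_react" instead of "nobu/simple_react_ja"

planner = CodeInterpreterPlanner.choose_planner(ci_params)
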
68 changes: 12 additions & 56 deletions src/codeinterpreterapi/session.py
@@ -6,27 +6,22 @@

from codeboxapi import CodeBox # type: ignore
from gui_agent_loop_core.schema.schema import GuiAgentInterpreterChatResponseStr
from langchain.agents import AgentExecutor
from langchain.callbacks.base import Callbacks
from langchain_community.chat_message_histories.in_memory import ChatMessageHistory
from langchain_community.chat_message_histories.postgres import PostgresChatMessageHistory
from langchain_community.chat_message_histories.redis import RedisChatMessageHistory
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.language_models import BaseLanguageModel
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool

from codeinterpreterapi.agents.agents import CodeInterpreterAgent
from codeinterpreterapi.brain.brain import CodeInterpreterBrain
from codeinterpreterapi.brain.params import CodeInterpreterParams
from codeinterpreterapi.chains import aremove_download_link, remove_download_link
from codeinterpreterapi.chat_history import CodeBoxChatMessageHistory
from codeinterpreterapi.config import settings
from codeinterpreterapi.llm.llm import CodeInterpreterLlm
from codeinterpreterapi.schema import CodeInterpreterResponse, File, SessionStatus, UserRequest
from codeinterpreterapi.thoughts.thoughts import CodeInterpreterToT

from .planners.planners import CodeInterpreterPlanner
from .supervisors.supervisors import CodeInterpreterSupervisor
from .tools.tools import CodeInterpreterTools


def _handle_deprecated_kwargs(kwargs: dict) -> None:
@@ -49,19 +44,19 @@ def __init__(
**kwargs: Any,
) -> None:
_handle_deprecated_kwargs(kwargs)
self.is_local = is_local
self.is_ja = is_ja
self.codebox = CodeBox(requirements=settings.CUSTOM_PACKAGES)
self.verbose = kwargs.get("verbose", settings.DEBUG)
self.llm: BaseLanguageModel = llm or CodeInterpreterLlm.get_llm()
self.tools: list[BaseTool] = CodeInterpreterTools(additional_tools, self.llm).get_all_tools()
self.llm: BaseLanguageModel = llm or CodeInterpreterLlm.get_llm() # TODO: remove from session
ci_params = CodeInterpreterParams(
llm=self.llm,
tools=additional_tools,
callbacks=callbacks,
verbose=self.verbose,
is_local=is_local,
is_ja=is_ja,
)
self.brain = CodeInterpreterBrain(ci_params)
self.log("self.llm=" + str(self.llm))

self.callbacks = callbacks
self.agent_executor: Optional[Runnable] = None
self.llm_planner: Optional[Runnable] = None
self.supervisor: Optional[AgentExecutor] = None
self.thought: Optional[Runnable] = None
self.input_files: list[File] = []
self.output_files: list[File] = []
self.code_log: list[tuple[str, str]] = []
@@ -77,45 +72,6 @@ def from_id(cls, session_id: UUID, **kwargs: Any) -> "CodeInterpreterSession":
def session_id(self) -> Optional[UUID]:
return self.codebox.session_id

def initialize(self):
self.initialize_agent_executor()
self.initialize_llm_planner()
self.initialize_supervisor()
self.initialize_thought()

def initialize_agent_executor(self):
is_experimental = True
if is_experimental:
self.agent_executor = CodeInterpreterAgent.create_agent_and_executor_experimental(
llm=self.llm,
tools=self.tools,
verbose=self.verbose,
is_ja=self.is_ja,
)
else:
self.agent_executor = CodeInterpreterAgent.create_agent_executor_lcel(
llm=self.llm,
tools=self.tools,
verbose=self.verbose,
is_ja=self.is_ja,
chat_memory=self._history_backend(),
callbacks=self.callbacks,
)

def initialize_llm_planner(self):
self.llm_planner = CodeInterpreterPlanner.choose_planner(llm=self.llm, tools=self.tools, is_ja=self.is_ja)

def initialize_supervisor(self):
self.supervisor = CodeInterpreterSupervisor.choose_supervisor(
planner=self.llm_planner,
executor=self.agent_executor,
tools=self.tools,
verbose=self.verbose,
)

def initialize_thought(self):
self.thought = CodeInterpreterToT.get_runnable_tot_chain(llm=self.llm, is_ja=self.is_ja, is_simple=False)

def start(self) -> SessionStatus:
print("start")
status = SessionStatus.from_codebox_status(self.codebox.start())
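
The four initialize_* methods removed above are not re-created anywhere in the visible diffs; src/codeinterpreterapi/brain/brain.py is among the changed files that did not load, so presumably CodeInterpreterBrain now owns this wiring. The following is only a guess at what that constructor might do, stitched together from the new factory signatures elsewhere in this commit; treat every line as an assumption rather than the actual implementation.

from codeinterpreterapi.agents.agents import CodeInterpreterAgent
from codeinterpreterapi.brain.params import CodeInterpreterParams
from codeinterpreterapi.planners.planners import CodeInterpreterPlanner
from codeinterpreterapi.supervisors.supervisors import CodeInterpreterSupervisor
from codeinterpreterapi.thoughts.thoughts import CodeInterpreterToT


class CodeInterpreterBrain:
    """Hypothetical owner of the component wiring that used to live in the session (sketch)."""

    def __init__(self, ci_params: CodeInterpreterParams) -> None:
        self.ci_params = ci_params
        # Mirrors the removed initialize_* methods, but driven entirely by ci_params.
        self.agent_executor = CodeInterpreterAgent.create_agent_and_executor_experimental(ci_params)
        self.llm_planner = CodeInterpreterPlanner.choose_planner(ci_params)
        self.supervisor = CodeInterpreterSupervisor.choose_supervisor(
            planner=self.llm_planner, ci_params=ci_params
        )
        self.thought = CodeInterpreterToT.get_runnable_tot_chain(ci_params=ci_params)
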
10 changes: 4 additions & 6 deletions src/codeinterpreterapi/supervisors/supervisors.py
@@ -1,25 +1,23 @@
import getpass
import os
import platform
from typing import List

from langchain.agents import AgentExecutor
from langchain.tools import BaseTool
from langchain_core.runnables import Runnable

from codeinterpreterapi.brain.params import CodeInterpreterParams


class CodeInterpreterSupervisor:
@staticmethod
def choose_supervisor(
planner: Runnable, executor: Runnable, tools: List[BaseTool], verbose: bool = False
) -> AgentExecutor:
def choose_supervisor(planner: Runnable, ci_params: CodeInterpreterParams) -> AgentExecutor:
# prompt
username = getpass.getuser()
current_working_directory = os.getcwd()
operating_system = platform.system()
info = f"[User Info]\nName: {username}\nCWD: {current_working_directory}\nOS: {operating_system}"
print("choose_supervisor info=", info)
agent_executor = AgentExecutor(agent=planner, tools=tools, verbose=verbose)
agent_executor = AgentExecutor(agent=planner, tools=ci_params.tools, verbose=ci_params.verbose)
# prompt = hub.pull("nobu/chat_planner")
# agent = create_react_agent(llm, [], prompt)
# return agent
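
Note that choose_supervisor still takes the planner as a separate argument and only reads tools and verbose from the params. A short usage sketch under the same assumptions as above:

from codeinterpreterapi.brain.params import CodeInterpreterParams
from codeinterpreterapi.llm.llm import prepare_test_llm
from codeinterpreterapi.planners.planners import CodeInterpreterPlanner
from codeinterpreterapi.supervisors.supervisors import CodeInterpreterSupervisor

llm = prepare_test_llm()
ci_params = CodeInterpreterParams.get_test_params(llm=llm)
planner = CodeInterpreterPlanner.choose_planner(ci_params)

# Wraps the planner in an AgentExecutor that shares the tools from ci_params.
supervisor = CodeInterpreterSupervisor.choose_supervisor(planner=planner, ci_params=ci_params)
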
17 changes: 11 additions & 6 deletions src/codeinterpreterapi/thoughts/thoughts.py
@@ -1,11 +1,14 @@
from typing import Any, Dict, List, Optional, Union

from codeinterpreterapi.thoughts.base import MyToTChain
from codeinterpreterapi.thoughts.checker import create_tot_chain_from_llm
from langchain_core.runnables import RunnableSerializable
from langchain_core.runnables.config import RunnableConfig
from langchain_core.runnables.utils import Input, Output

from codeinterpreterapi.brain.params import CodeInterpreterParams
from codeinterpreterapi.llm.llm import prepare_test_llm
from codeinterpreterapi.thoughts.base import MyToTChain
from codeinterpreterapi.thoughts.checker import create_tot_chain_from_llm


class CodeInterpreterToT(RunnableSerializable):
tot_chain: MyToTChain = None
@@ -14,7 +17,7 @@ def __init__(self, llm=None, is_ja=True, is_simple=False):
super().__init__()
self.tot_chain = create_tot_chain_from_llm(llm=llm, is_ja=is_ja, is_simple=is_simple)

def run(self, input: Input):
def run(self, input: Input) -> Output:
problem_description = input["input"]
return self.tot_chain.run(problem_description=problem_description)

@@ -41,14 +44,16 @@
raise NotImplementedError("Async not implemented yet")

@classmethod
def get_runnable_tot_chain(cls, llm=None, is_ja=True, is_simple=False):
def get_runnable_tot_chain(cls, ci_params: CodeInterpreterParams, is_simple: bool = False):
# ToTChainのインスタンスを作成
tot_chain = cls(llm=llm, is_ja=is_ja, is_simple=is_simple)
tot_chain = cls(llm=ci_params.llm, is_ja=ci_params.is_ja, is_simple=is_simple)
return tot_chain


def test():
tot_chain = CodeInterpreterToT.get_runnable_tot_chain(is_simple=True)
llm = prepare_test_llm()
ci_params = CodeInterpreterParams.get_test_params(llm=llm)
tot_chain = CodeInterpreterToT.get_runnable_tot_chain(ci_params=ci_params)
tot_chain.invoke({"input": sample2})


(Diffs for the remaining two changed files did not load in this view.)
