Skip to content

Commit

Permalink
fix: update prompts for LCEL version
Browse files Browse the repository at this point in the history
  • Loading branch information
nobu007 committed May 21, 2024
1 parent d1e76a7 commit 4aa38e0
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 27 deletions.
9 changes: 6 additions & 3 deletions src/codeinterpreterapi/planners/planners.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from typing import List

from langchain import hub
from langchain.agents import create_react_agent
from langchain.base_language import BaseLanguageModel
from langchain.tools import BaseTool
from langchain_core.runnables import Runnable

SYSTEM_PROMPT_PLANNER = """
Expand All @@ -25,7 +28,7 @@

class CodeInterpreterPlanner:
@staticmethod
def choose_planner(llm: BaseLanguageModel, is_ja: bool) -> Runnable:
def choose_planner(llm: BaseLanguageModel, tools: List[BaseTool], is_ja: bool) -> Runnable:
"""
Load a chat planner.
Expand All @@ -38,6 +41,6 @@ def choose_planner(llm: BaseLanguageModel, is_ja: bool) -> Runnable:
"""
system_prompt = SYSTEM_PROMPT_PLANNER_JA if is_ja else SYSTEM_PROMPT_PLANNER
print("system_prompt(planner)=", system_prompt)
prompt = hub.pull("nobu/chat_planner")
planner_agent = create_react_agent(llm, [], prompt)
prompt = hub.pull("nobu/simple_react")
planner_agent = create_react_agent(llm, tools, prompt)
return planner_agent
17 changes: 9 additions & 8 deletions src/codeinterpreterapi/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,8 @@

from codeboxapi import CodeBox # type: ignore
from codeboxapi.schema import CodeBoxOutput # type: ignore
from gui_agent_loop_core.schema.schema import (
GuiAgentInterpreterChatResponseStr,
)
from gui_agent_loop_core.schema.schema import GuiAgentInterpreterChatResponseStr
from langchain.agents import AgentExecutor
from langchain.callbacks.base import Callbacks
from langchain_community.chat_message_histories.in_memory import ChatMessageHistory
from langchain_community.chat_message_histories.postgres import PostgresChatMessageHistory
Expand Down Expand Up @@ -77,7 +76,7 @@ def __init__(
self.callbacks = callbacks
self.agent_executor: Optional[Runnable] = None
self.llm_planner: Optional[Runnable] = None
self.supervisor: Optional[Runnable] = None
self.supervisor: Optional[AgentExecutor] = None
self.input_files: list[File] = []
self.output_files: list[File] = []
self.code_log: list[tuple[str, str]] = []
Expand All @@ -96,7 +95,7 @@ def session_id(self) -> Optional[UUID]:
def initialize(self):
    """Initialize all session sub-components in dependency order.

    The supervisor is built last because it wraps the planner and the
    agent executor produced by the two preceding calls.
    """
    self.initialize_agent_executor()
    self.initialize_llm_planner()
    # Stale commented-out duplicate of the supervisor call removed.
    self.initialize_supervisor()

def initialize_agent_executor(self):
is_experimental = False
Expand All @@ -118,12 +117,13 @@ def initialize_agent_executor(self):
)

def initialize_llm_planner(self):
    """Create the planner runnable, sharing this session's LLM and tools.

    NOTE(review): assumes self.llm, self.tools and self.is_ja are set by
    __init__ before this is called — confirm against the constructor.
    """
    # The earlier duplicate assignment (immediately overwritten, dead code)
    # was removed; only the tools-aware call is kept.
    self.llm_planner = CodeInterpreterPlanner.choose_planner(llm=self.llm, tools=self.tools, is_ja=self.is_ja)

def initialize_supervisor(self):
    """Build the supervisor from the planner and the agent executor.

    Must run after initialize_llm_planner() and initialize_agent_executor(),
    since it consumes self.llm_planner and self.agent_executor.
    """
    self.supervisor = CodeInterpreterSupervisor.choose_supervisor(
        planner=self.llm_planner,
        executor=self.agent_executor,
        tools=self.tools,
        verbose=self.verbose,
    )

Expand Down Expand Up @@ -400,8 +400,9 @@ def generate_response(
input_message = {"input": user_request.content, "agent_scratchpad": agent_scratchpad}

# ======= ↓↓↓↓ LLM invoke ↓↓↓↓ #=======
response = self.agent_executor.invoke(input=input_message)
# response = self.supervisor.invoke(input=user_request)
# response = self.agent_executor.invoke(input=input_message)
# response = self.llm_planner.invoke(input=input_message)
response = self.supervisor.invoke(input=input_message)
# ======= ↑↑↑↑ LLM invoke ↑↑↑↑ #=======
print("response(type)=", type(response))
print("response=", response)
Expand Down
24 changes: 8 additions & 16 deletions src/codeinterpreterapi/supervisors/supervisors.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,28 @@
import getpass
import os
import platform
from typing import List

from langchain.agents import AgentExecutor
from langchain.chains.base import Chain
from langchain_core.runnables import RunnableAssign, RunnablePassthrough
from langchain_experimental.plan_and_execute.agent_executor import PlanAndExecute
from langchain_experimental.plan_and_execute.planners.base import LLMPlanner


class MySupervisorChain(Chain):
pass
from langchain.tools import BaseTool
from langchain_core.runnables import Runnable


class CodeInterpreterSupervisor:
    """Factory for the supervisor agent that drives a session's tool runs."""

    @staticmethod
    def choose_supervisor(
        planner: Runnable, executor: Runnable, tools: List[BaseTool], verbose: bool = False
    ) -> AgentExecutor:
        """Build an AgentExecutor that lets the planner agent drive the tools.

        The interleaved pre/post-commit diff residue (duplicate signature,
        stale PlanAndExecute construction, superseded return paths and
        commented-out experiments) was removed; only the post-commit
        implementation remains.

        Args:
            planner: React-style agent runnable that decides the next action.
            executor: Currently unused; kept in the signature so existing
                callers that pass an executor keep working.
            tools: Tools exposed to the supervisor agent.
            verbose: Forwarded to AgentExecutor for verbose logging.

        Returns:
            An AgentExecutor wrapping the planner agent and the given tools.
        """
        # Diagnostics only: printed so runs are traceable to a user/host.
        username = getpass.getuser()
        current_working_directory = os.getcwd()
        operating_system = platform.system()
        info = f"[User Info]\nName: {username}\nCWD: {current_working_directory}\nOS: {operating_system}"
        print("choose_supervisor info=", info)

        agent_executor = AgentExecutor(agent=planner, tools=tools, verbose=verbose)
        return agent_executor

0 comments on commit 4aa38e0

Please sign in to comment.