Skip to content

Commit

Permalink
fix: update planners.py by using CodeInterpreterPlan
Browse files Browse the repository at this point in the history
  • Loading branch information
nobu007 committed Jul 28, 2024
1 parent e4672c0 commit 59052ea
Show file tree
Hide file tree
Showing 6 changed files with 300 additions and 90 deletions.
7 changes: 5 additions & 2 deletions src/codeinterpreterapi/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,10 +151,13 @@ def get_llms(cls, model: str = settings.MODEL_LOCAL) -> List[BaseChatModel]:
return llms


def prepare_test_llm():
def prepare_test_llm(is_smart: bool = False):
load_dotenv(verbose=True, override=False)

llm = CodeInterpreterLlm.get_llm_switcher()
if is_smart:
llm = CodeInterpreterLlm.get_llm_smart()
else:
llm = CodeInterpreterLlm.get_llm_switcher()
llm_tools = CodeInterpreterLlm.get_llm_switcher_tools()
callbacks = [FullOutCallbackHandler()]
configurable = {"session_id": "123"}
Expand Down
117 changes: 50 additions & 67 deletions src/codeinterpreterapi/planners/planners.py
Original file line number Diff line number Diff line change
@@ -1,52 +1,62 @@
from langchain import PromptTemplate, hub
from typing import List, Union

from langchain.agents import AgentExecutor
from langchain.agents.agent import RunnableAgent
from langchain.tools.render import render_text_description
from langchain_core.runnables import Runnable
from pydantic import BaseModel, Field

from codeinterpreterapi.agents.agents import CodeInterpreterAgent
from codeinterpreterapi.brain.params import CodeInterpreterParams
from codeinterpreterapi.llm.llm import prepare_test_llm
from codeinterpreterapi.utils.output_parser import PlannerSingleOutputParser
from codeinterpreterapi.planners.prompts import create_planner_agent_chat_prompt, create_planner_agent_prompt
from codeinterpreterapi.test_prompts.test_prompt import TestPrompt
from codeinterpreterapi.utils.prompt import PromptUpdater
from codeinterpreterapi.utils.runnable import create_complement_input


class CodeInterpreterPlan(BaseModel):
'''Agent and Task definition. Plan and task is 1:1.'''

agent_name: str = Field(
description="The agent name for task. This is primary key. Agent responsible for task execution. Represents entity performing task."
)
task_description: str = Field(description="Descriptive text detailing task's purpose and execution.")
expected_output: str = Field(description="Clear definition of expected task outcome.")


class CodeInterpreterPlanList(BaseModel):
'''Sequential plans for the task.'''

agent_task_list: List[CodeInterpreterPlan] = Field(
description="The list of CodeInterpreterPlan. It means agent name and so on."
)


class CodeInterpreterPlanner:
@staticmethod
def choose_planner(ci_params: CodeInterpreterParams) -> Runnable:
def choose_planner(ci_params: CodeInterpreterParams) -> Union[Runnable, AgentExecutor]:
"""
Load a chat planner.
Args:
llm: Language model.
tools: List of tools this agent has access to.
is_ja: System prompt.
Returns:
LLMPlanner
<prompt: simple_react>
Input
tools:
tool_names:
input:
agent_scratchpad:
Output
content: Free text in str.
"""
# prompt_name_react = "nobu/simple_react"
prompt_name = "nobu/chat_planner"
if ci_params.is_ja:
prompt_name += "_ja"
prompt = hub.pull(prompt_name)
# prompt = CodeInterpreterPlanner.get_prompt()

is_chat_prompt = True
if is_chat_prompt:
prompt = create_planner_agent_chat_prompt()
prompt = PromptUpdater.update_and_show_chat_prompt(prompt, ci_params)
else:
prompt = create_planner_agent_prompt()
prompt = PromptUpdater.update_prompt(prompt, ci_params)
PromptUpdater.show_prompt(prompt)

# structured_llm
structured_llm = ci_params.llm.bind_tools(tools=[CodeInterpreterPlanList])

# runnable
runnable = (
create_complement_input(prompt)
| prompt
| ci_params.llm
create_complement_input(prompt) | prompt | structured_llm
# | StrOutputParser()
| PlannerSingleOutputParser()
# | PlannerSingleOutputParser()
)
# runnable = assign_runnable_history(runnable, ci_params.runnable_config)

Expand All @@ -56,50 +66,23 @@ def choose_planner(ci_params: CodeInterpreterParams) -> Runnable:
remapped_inputs = create_complement_input(prompt).invoke({})
agent = RunnableAgent(runnable=runnable, input_keys=list(remapped_inputs.keys()))

# agent_executor
agent_executor = AgentExecutor(agent=agent, tools=[], verbose=ci_params.verbose)

return agent_executor

@staticmethod
def update_prompt(prompt: PromptTemplate, ci_params: CodeInterpreterParams) -> PromptTemplate:
# Check if the prompt has the required input variables
missing_vars = {"tools", "tool_names", "agent_scratchpad"}.difference(
prompt.input_variables + list(prompt.partial_variables)
)
if missing_vars:
raise ValueError(f"Prompt missing required variables: {missing_vars}")

# Partial the prompt with tools and tool_names
prompt = prompt.partial(
tools=render_text_description(list(ci_params.tools)),
tool_names=", ".join([t.name for t in ci_params.tools]),
)
return prompt

@staticmethod
def get_prompt():
prompt_template = """
あなたは親切なアシスタントです。以下の制約条件を守ってタスクを完了させてください。
制約条件:
- 段階的に考え、各ステップで取るアクションを明確にすること。
- 必要な情報が不足している場合は、ユーザーに質問すること。
- タスクを完了するために十分な情報が得られたら、最終的な回答を出力すること。
# return executor or runnable
return_as_executor = False
if return_as_executor:
# TODO: handle step by step by original OutputParser
agent_executor = AgentExecutor(agent=agent, tools=ci_params.tools, verbose=ci_params.verbose)
return agent_executor

タスク: {input}
agent_scratchpad: {agent_scratchpad}
"""
prompt = PromptTemplate(input_variables=["input", "agent_scratchpad"], template=prompt_template)
return prompt
else:
return runnable


def test():
sample = "ステップバイステップで2*5+2を計算して。"
llm, llm_tools, runnable_config = prepare_test_llm()
ci_params = CodeInterpreterParams.get_test_params(llm=llm, llm_tools=llm_tools, runnable_config=runnable_config)
_ = CodeInterpreterAgent.choose_agent_executors(ci_params=ci_params)
planner = CodeInterpreterPlanner.choose_planner(ci_params=ci_params)
result = planner.invoke({"input": sample, "agent_scratchpad": ""})
result = planner.invoke({"input": "", "agent_scratchpad": "", "messages": [TestPrompt.svg_input_str]})
print("result=", result)


Expand Down
68 changes: 68 additions & 0 deletions src/codeinterpreterapi/planners/prompts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder, PromptTemplate

SYSTEM_MESSAGE_TEMPLATE = '''
You are a super agent manager.
First of all, understand the problem.
Please make a plan to solve it.
You cau use AI agent here.
{agent_info}.
Constraints:.
- Think step-by-step and be clear about the actions to be taken at each step.
- If the required information is missing, answer with agent_name=None and ask the user in task_description.
- If the required agent for execution does not exist, answer with agent_name=None and answer about the required agent in the task_description.
- Once sufficient information is obtained to complete the task, the final work plan should be output.
- The last step should be agent_name=<END_OF_PLAN>.
Translated with DeepL.com (free version)
'''

SYSTEM_MESSAGE_TEMPLATE_JA = '''
あなたは優秀なAIエージェントを管理するシニアエンジニアです。
まず、問題を理解し、問題を解決するための計画を立てましょう。
利用可能なAI agentのリスト:
{agent_info}
制約条件:
- 段階的に考え、各ステップで取るアクションを明確にすること。
- 必要な情報が不足している場合は、agent_name=Noneで回答し、task_description でユーザーに質問すること。
- 実行に必要なagentが存在しない場合は、agent_name=Noneで回答し、task_description で必要なagentについて回答すること。
- タスクを完了するために十分な情報が得られたら、最終的な作業計画を出力すること。
- 最後のステップはagent_name=<END_OF_PLAN>とすること。
- 各ステップの思考と出力は日本語とする。
'''


def create_planner_agent_prompt(is_ja: bool = True) -> PromptTemplate:
if is_ja:
prompt = PromptTemplate(
input_variables=["input", "agent_scratchpad", "agent_info"], template=SYSTEM_MESSAGE_TEMPLATE_JA
)
else:
prompt = PromptTemplate(
input_variables=["input", "agent_scratchpad", "agent_info"], template=SYSTEM_MESSAGE_TEMPLATE
)
return prompt


def create_planner_agent_chat_prompt(is_ja: bool = True) -> ChatPromptTemplate:
if is_ja:
prompt = ChatPromptTemplate.from_messages(
[
("system", SYSTEM_MESSAGE_TEMPLATE_JA),
MessagesPlaceholder(variable_name="messages"),
("user", "上記の会話を踏まえて、最終的な作業計画を日本語で出力してください。"),
]
)
else:
prompt = ChatPromptTemplate.from_messages(
[
("system", SYSTEM_MESSAGE_TEMPLATE),
MessagesPlaceholder(variable_name="messages"),
("user", "Given the conversation above, please output final plan."),
]
)
return prompt
73 changes: 73 additions & 0 deletions src/codeinterpreterapi/utils/prompt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from langchain_core.prompts import (
AIMessagePromptTemplate,
ChatPromptTemplate,
HumanMessagePromptTemplate,
MessagesPlaceholder,
PromptTemplate,
SystemMessagePromptTemplate,
)

from codeinterpreterapi.brain.params import CodeInterpreterParams


class PromptUpdater:
@staticmethod
def show_prompt(prompt: PromptTemplate):
prompt_str = prompt.format()
print(f"show_prompt: {prompt_str}")

@staticmethod
def show_chat_prompt(prompt: ChatPromptTemplate):
print("show_chat_prompt:")
for message in prompt.messages:
if isinstance(message, SystemMessagePromptTemplate):
print(f"System: {PromptUpdater.show_prompt(message.prompt)}")
elif isinstance(message, HumanMessagePromptTemplate):
print(f"Human: {message.prompt.template}")
elif isinstance(message, AIMessagePromptTemplate):
print(f"AI: {message.prompt.template}")
elif isinstance(message, MessagesPlaceholder):
print(f"MessagesPlaceholder: {message.variable_name}")
else:
print(f"Other: {message}")

@staticmethod
def update_prompt(prompt: PromptTemplate, ci_params: CodeInterpreterParams) -> PromptTemplate:
# Check if the prompt has the required input variables
missing_vars = {"agent_info"}.difference(prompt.input_variables + list(prompt.partial_variables))
if missing_vars:
raise ValueError(f"Prompt missing required variables: {missing_vars}")

# Partial the prompt with tools and tool_names
prompt = prompt.partial(
agent_info=", ".join([agent_def.get_agent_info() for agent_def in ci_params.agent_def_list]),
)
return prompt

@staticmethod
def update_chat_prompt(prompt: ChatPromptTemplate, ci_params: CodeInterpreterParams) -> ChatPromptTemplate:
# Check if the prompt has the required input variables
missing_vars = {"agent_info"}.difference(prompt.input_variables)
if missing_vars:
raise ValueError(f"Prompt missing required variables: {missing_vars}")

# Create a new ChatPromptTemplate with updated messages
updated_messages = []
for message in prompt.messages:
if isinstance(message, SystemMessagePromptTemplate):
updated_content = PromptUpdater.update_prompt(message.prompt, ci_params)
updated_messages.append(SystemMessagePromptTemplate(prompt=updated_content))
elif isinstance(message, (HumanMessagePromptTemplate, AIMessagePromptTemplate)):
updated_messages.append(message)
elif isinstance(message, MessagesPlaceholder):
updated_messages.append(message)
else:
raise ValueError(f"Unexpected message type: {type(message)}")

return ChatPromptTemplate(messages=updated_messages)

@staticmethod
def update_and_show_chat_prompt(prompt: ChatPromptTemplate, ci_params: CodeInterpreterParams) -> ChatPromptTemplate:
updated_prompt = PromptUpdater.update_chat_prompt(prompt, ci_params)
PromptUpdater.show_chat_prompt(updated_prompt)
return updated_prompt
30 changes: 9 additions & 21 deletions tests/one_test.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,15 @@
from codeinterpreterapi import CodeInterpreterSession
from codeinterpreterapi.brain.params import CodeInterpreterParams
from codeinterpreterapi.llm.llm import prepare_test_llm
from codeinterpreterapi.planners.planners import CodeInterpreterPlanner


def test():
model = "claude-3-haiku-20240307"
# model = "gemini-1.0-pro"
test_message = "pythonで円周率を表示するプログラムを実行してください。"
verbose = False
is_streaming = False
print("test_message=", test_message)
session = CodeInterpreterSession(model=model, verbose=verbose)
status = session.start_local()
print("status=", status)
if is_streaming:
# response_inner: CodeInterpreterResponse
response_inner = session.generate_response_stream(test_message)
for response_inner_chunk in response_inner:
print("response_inner_chunk.content=", response_inner.content)
print("response_inner_chunk code_log=", response_inner.code_log)
else:
# response_inner: CodeInterpreterResponse
response_inner = session.generate_response(test_message)
print("response_inner.content=", response_inner.content)
print("response_inner code_log=", response_inner.code_log)
sample = "ステップバイステップで2*5+2を計算して。"
llm, llm_tools, runnable_config = prepare_test_llm()
ci_params = CodeInterpreterParams.get_test_params(llm=llm, llm_tools=llm_tools, runnable_config=runnable_config)
planner = CodeInterpreterPlanner.choose_planner(ci_params=ci_params)
result = planner.invoke({"input": sample, "agent_scratchpad": "", "messages": [sample]})
print("result=", result.content)


if __name__ == "__main__":
Expand Down
Loading

0 comments on commit 59052ea

Please sign in to comment.