forked from shroominic/codeinterpreter-api
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix: update planners.py by using CodeInterpreterPlan
- Loading branch information
Showing
6 changed files
with
300 additions
and
90 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder, PromptTemplate | ||
|
||
SYSTEM_MESSAGE_TEMPLATE = ''' | ||
You are a super agent manager. | ||
First of all, understand the problem. | ||
Please make a plan to solve it. | ||
You cau use AI agent here. | ||
{agent_info}. | ||
Constraints:. | ||
- Think step-by-step and be clear about the actions to be taken at each step. | ||
- If the required information is missing, answer with agent_name=None and ask the user in task_description. | ||
- If the required agent for execution does not exist, answer with agent_name=None and answer about the required agent in the task_description. | ||
- Once sufficient information is obtained to complete the task, the final work plan should be output. | ||
- The last step should be agent_name=<END_OF_PLAN>. | ||
Translated with DeepL.com (free version) | ||
''' | ||
|
||
SYSTEM_MESSAGE_TEMPLATE_JA = ''' | ||
あなたは優秀なAIエージェントを管理するシニアエンジニアです。 | ||
まず、問題を理解し、問題を解決するための計画を立てましょう。 | ||
利用可能なAI agentのリスト: | ||
{agent_info} | ||
制約条件: | ||
- 段階的に考え、各ステップで取るアクションを明確にすること。 | ||
- 必要な情報が不足している場合は、agent_name=Noneで回答し、task_description でユーザーに質問すること。 | ||
- 実行に必要なagentが存在しない場合は、agent_name=Noneで回答し、task_description で必要なagentについて回答すること。 | ||
- タスクを完了するために十分な情報が得られたら、最終的な作業計画を出力すること。 | ||
- 最後のステップはagent_name=<END_OF_PLAN>とすること。 | ||
- 各ステップの思考と出力は日本語とする。 | ||
''' | ||
|
||
|
||
def create_planner_agent_prompt(is_ja: bool = True) -> PromptTemplate: | ||
if is_ja: | ||
prompt = PromptTemplate( | ||
input_variables=["input", "agent_scratchpad", "agent_info"], template=SYSTEM_MESSAGE_TEMPLATE_JA | ||
) | ||
else: | ||
prompt = PromptTemplate( | ||
input_variables=["input", "agent_scratchpad", "agent_info"], template=SYSTEM_MESSAGE_TEMPLATE | ||
) | ||
return prompt | ||
|
||
|
||
def create_planner_agent_chat_prompt(is_ja: bool = True) -> ChatPromptTemplate: | ||
if is_ja: | ||
prompt = ChatPromptTemplate.from_messages( | ||
[ | ||
("system", SYSTEM_MESSAGE_TEMPLATE_JA), | ||
MessagesPlaceholder(variable_name="messages"), | ||
("user", "上記の会話を踏まえて、最終的な作業計画を日本語で出力してください。"), | ||
] | ||
) | ||
else: | ||
prompt = ChatPromptTemplate.from_messages( | ||
[ | ||
("system", SYSTEM_MESSAGE_TEMPLATE), | ||
MessagesPlaceholder(variable_name="messages"), | ||
("user", "Given the conversation above, please output final plan."), | ||
] | ||
) | ||
return prompt |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
from langchain_core.prompts import ( | ||
AIMessagePromptTemplate, | ||
ChatPromptTemplate, | ||
HumanMessagePromptTemplate, | ||
MessagesPlaceholder, | ||
PromptTemplate, | ||
SystemMessagePromptTemplate, | ||
) | ||
|
||
from codeinterpreterapi.brain.params import CodeInterpreterParams | ||
|
||
|
||
class PromptUpdater: | ||
@staticmethod | ||
def show_prompt(prompt: PromptTemplate): | ||
prompt_str = prompt.format() | ||
print(f"show_prompt: {prompt_str}") | ||
|
||
@staticmethod | ||
def show_chat_prompt(prompt: ChatPromptTemplate): | ||
print("show_chat_prompt:") | ||
for message in prompt.messages: | ||
if isinstance(message, SystemMessagePromptTemplate): | ||
print(f"System: {PromptUpdater.show_prompt(message.prompt)}") | ||
elif isinstance(message, HumanMessagePromptTemplate): | ||
print(f"Human: {message.prompt.template}") | ||
elif isinstance(message, AIMessagePromptTemplate): | ||
print(f"AI: {message.prompt.template}") | ||
elif isinstance(message, MessagesPlaceholder): | ||
print(f"MessagesPlaceholder: {message.variable_name}") | ||
else: | ||
print(f"Other: {message}") | ||
|
||
@staticmethod | ||
def update_prompt(prompt: PromptTemplate, ci_params: CodeInterpreterParams) -> PromptTemplate: | ||
# Check if the prompt has the required input variables | ||
missing_vars = {"agent_info"}.difference(prompt.input_variables + list(prompt.partial_variables)) | ||
if missing_vars: | ||
raise ValueError(f"Prompt missing required variables: {missing_vars}") | ||
|
||
# Partial the prompt with tools and tool_names | ||
prompt = prompt.partial( | ||
agent_info=", ".join([agent_def.get_agent_info() for agent_def in ci_params.agent_def_list]), | ||
) | ||
return prompt | ||
|
||
@staticmethod | ||
def update_chat_prompt(prompt: ChatPromptTemplate, ci_params: CodeInterpreterParams) -> ChatPromptTemplate: | ||
# Check if the prompt has the required input variables | ||
missing_vars = {"agent_info"}.difference(prompt.input_variables) | ||
if missing_vars: | ||
raise ValueError(f"Prompt missing required variables: {missing_vars}") | ||
|
||
# Create a new ChatPromptTemplate with updated messages | ||
updated_messages = [] | ||
for message in prompt.messages: | ||
if isinstance(message, SystemMessagePromptTemplate): | ||
updated_content = PromptUpdater.update_prompt(message.prompt, ci_params) | ||
updated_messages.append(SystemMessagePromptTemplate(prompt=updated_content)) | ||
elif isinstance(message, (HumanMessagePromptTemplate, AIMessagePromptTemplate)): | ||
updated_messages.append(message) | ||
elif isinstance(message, MessagesPlaceholder): | ||
updated_messages.append(message) | ||
else: | ||
raise ValueError(f"Unexpected message type: {type(message)}") | ||
|
||
return ChatPromptTemplate(messages=updated_messages) | ||
|
||
@staticmethod | ||
def update_and_show_chat_prompt(prompt: ChatPromptTemplate, ci_params: CodeInterpreterParams) -> ChatPromptTemplate: | ||
updated_prompt = PromptUpdater.update_chat_prompt(prompt, ci_params) | ||
PromptUpdater.show_chat_prompt(updated_prompt) | ||
return updated_prompt |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.