From da523b8aefba5f858d899cbd863d6a981b1b405d Mon Sep 17 00:00:00 2001 From: Liqun Li Date: Fri, 15 Dec 2023 18:05:14 +0800 Subject: [PATCH 01/13] init implementation --- docs/example.md | 2 +- project/plugins/anomaly_detection.yaml | 2 +- project/plugins/ascii_render.py | 12 ++ project/plugins/ascii_render.yaml | 18 +++ project/plugins/klarna_search.yaml | 2 +- project/plugins/paper_summary.yaml | 2 +- project/plugins/sql_pull_data.yaml | 2 +- project/plugins/tell_joke.py | 8 ++ project/plugins/tell_joke.yaml | 17 +++ taskweaver/code_interpreter/__init__.py | 1 + .../code_generator/__init__.py | 1 + .../code_generator/code_generator.py | 2 +- .../code_generator_plugin_only.py | 131 ++++++++++++++++++ ...prompt.yaml => code_generator_prompt.yaml} | 0 .../code_generator_prompt_plugin_only.yaml | 4 + .../code_interpreter_plugin_only.py | 113 +++++++++++++++ taskweaver/llm/__init__.py | 4 + taskweaver/llm/base.py | 1 + taskweaver/llm/openai.py | 21 +++ taskweaver/memory/plugin.py | 30 ++++ taskweaver/session/session.py | 4 +- 21 files changed, 369 insertions(+), 8 deletions(-) create mode 100644 project/plugins/ascii_render.py create mode 100644 project/plugins/ascii_render.yaml create mode 100644 project/plugins/tell_joke.py create mode 100644 project/plugins/tell_joke.yaml create mode 100644 taskweaver/code_interpreter/code_generator/code_generator_plugin_only.py rename taskweaver/code_interpreter/code_generator/{code_generator_json_prompt.yaml => code_generator_prompt.yaml} (100%) create mode 100644 taskweaver/code_interpreter/code_generator/code_generator_prompt_plugin_only.yaml create mode 100644 taskweaver/code_interpreter/code_interpreter_plugin_only.py diff --git a/docs/example.md b/docs/example.md index 9aae4a71..461d390f 100644 --- a/docs/example.md +++ b/docs/example.md @@ -79,7 +79,7 @@ rounds: A code interpreter example tells LLMs how to generate code or orchestrate plugins to perform a specific task. The task is from the planner. Before constructing the code interpreter example, we strongly encourage you to -read the [code generator prompt](../taskweaver/code_interpreter/code_generator/code_generator_json_prompt.yaml). +read the [code generator prompt](../taskweaver/code_interpreter/code_generator/code_generator_prompt.yaml). The following is an example of a code interpreter example which contains 2 posts. Each post contains a message, a sender, a receiver, and a list of attachments. diff --git a/project/plugins/anomaly_detection.yaml b/project/plugins/anomaly_detection.yaml index 29c68cdc..2f8a3393 100644 --- a/project/plugins/anomaly_detection.yaml +++ b/project/plugins/anomaly_detection.yaml @@ -1,5 +1,5 @@ name: anomaly_detection -enabled: true +enabled: false required: false description: >- anomaly_detection function identifies anomalies from an input DataFrame of diff --git a/project/plugins/ascii_render.py b/project/plugins/ascii_render.py new file mode 100644 index 00000000..23a8a752 --- /dev/null +++ b/project/plugins/ascii_render.py @@ -0,0 +1,12 @@ +from taskweaver.plugin import Plugin, register_plugin + + +@register_plugin +class AsciiRenderPlugin(Plugin): + def __call__(self, text: str): + import pyfiglet + + ASCII_art_1 = pyfiglet.figlet_format(text, font="isometric1") + result = ASCII_art_1 + + return result diff --git a/project/plugins/ascii_render.yaml b/project/plugins/ascii_render.yaml new file mode 100644 index 00000000..76f1dd9e --- /dev/null +++ b/project/plugins/ascii_render.yaml @@ -0,0 +1,18 @@ +name: ascii_render +enabled: true +required: true +description: >- + This plugin renders the input text into ASCII art form. The input should be a string and the output is also a string in ASCII art. + +parameters: + - name: text + type: str + required: true + description: >- + This is the input text to be rendered into ASCII art form. + +returns: + - name: result + type: str + description: >- + The rendered text in ASCII art. \ No newline at end of file diff --git a/project/plugins/klarna_search.yaml b/project/plugins/klarna_search.yaml index 18907092..013379bf 100644 --- a/project/plugins/klarna_search.yaml +++ b/project/plugins/klarna_search.yaml @@ -1,5 +1,5 @@ name: klarna_search -enabled: true +enabled: false required: false description: >- Search and compare prices from thousands of online shops. Only available in the US. diff --git a/project/plugins/paper_summary.yaml b/project/plugins/paper_summary.yaml index 4a07306a..c73b1547 100644 --- a/project/plugins/paper_summary.yaml +++ b/project/plugins/paper_summary.yaml @@ -1,5 +1,5 @@ name: paper_summary -enabled: true +enabled: false required: false description: >- summarize_paper function iteratively summarizes a given paper page by page, diff --git a/project/plugins/sql_pull_data.yaml b/project/plugins/sql_pull_data.yaml index 2dc3ff14..de91afb9 100644 --- a/project/plugins/sql_pull_data.yaml +++ b/project/plugins/sql_pull_data.yaml @@ -1,5 +1,5 @@ name: sql_pull_data -enabled: true +enabled: false required: false description: >- Pull data from a SQL database. This plugin takes user requests when obtaining data from database is explicitly mentioned. diff --git a/project/plugins/tell_joke.py b/project/plugins/tell_joke.py new file mode 100644 index 00000000..d8b8866d --- /dev/null +++ b/project/plugins/tell_joke.py @@ -0,0 +1,8 @@ +from taskweaver.plugin import Plugin, register_plugin + + +@register_plugin +class TellJoke(Plugin): + def __call__(self, context: str): + # Define the API endpoint and parameters + return " Why don't cats play poker in the jungle? Too many cheetahs!" diff --git a/project/plugins/tell_joke.yaml b/project/plugins/tell_joke.yaml new file mode 100644 index 00000000..e563171e --- /dev/null +++ b/project/plugins/tell_joke.yaml @@ -0,0 +1,17 @@ +name: tell_joke +enabled: true +required: false +description: >- + Call this plugin to tell a joke. + +parameters: + - name: context + type: str + required: true + description: the context of the joke. + + +returns: + - name: joke + type: str + description: the joke. diff --git a/taskweaver/code_interpreter/__init__.py b/taskweaver/code_interpreter/__init__.py index c458f384..34ef27d8 100644 --- a/taskweaver/code_interpreter/__init__.py +++ b/taskweaver/code_interpreter/__init__.py @@ -1 +1,2 @@ from .code_interpreter import CodeInterpreter +from .code_interpreter_plugin_only import CodeInterpreterPluginOnly diff --git a/taskweaver/code_interpreter/code_generator/__init__.py b/taskweaver/code_interpreter/code_generator/__init__.py index d1cde92c..2437a1d5 100644 --- a/taskweaver/code_interpreter/code_generator/__init__.py +++ b/taskweaver/code_interpreter/code_generator/__init__.py @@ -1,2 +1,3 @@ from .code_generator import CodeGenerator, CodeGeneratorConfig, format_code_revision_message +from .code_generator_plugin_only import CodeGeneratorPluginOnly from .code_verification import CodeVerificationConfig, code_snippet_verification, format_code_correction_message diff --git a/taskweaver/code_interpreter/code_generator/code_generator.py b/taskweaver/code_interpreter/code_generator/code_generator.py index 4c5416fd..fd84464f 100644 --- a/taskweaver/code_interpreter/code_generator/code_generator.py +++ b/taskweaver/code_interpreter/code_generator/code_generator.py @@ -27,7 +27,7 @@ def _configure(self) -> None: "prompt_file_path", os.path.join( os.path.dirname(os.path.abspath(__file__)), - "code_generator_json_prompt.yaml", + "code_generator_prompt.yaml", ), ) self.example_base_path = self._get_path( diff --git a/taskweaver/code_interpreter/code_generator/code_generator_plugin_only.py b/taskweaver/code_interpreter/code_generator/code_generator_plugin_only.py new file mode 100644 index 00000000..80457fee --- /dev/null +++ b/taskweaver/code_interpreter/code_generator/code_generator_plugin_only.py @@ -0,0 +1,131 @@ +import os +from typing import List, Tuple + +from injector import inject + +from taskweaver.code_interpreter.code_generator import CodeGeneratorConfig +from taskweaver.code_interpreter.code_generator.plugin_selection import PluginSelector, SelectedPluginPool +from taskweaver.config.module_config import ModuleConfig +from taskweaver.llm import LLMApi, format_chat_message +from taskweaver.logging import TelemetryLogger +from taskweaver.memory import Attachment, Memory, Post, Round +from taskweaver.memory.plugin import PluginEntry, PluginRegistry +from taskweaver.role import PostTranslator, Role +from taskweaver.utils import read_yaml + + +class CodeGeneratorPluginOnlyConfig(ModuleConfig): + def _configure(self) -> None: + self._set_name("code_generator") + self.role_name = self._get_str("role_name", "ProgramApe") + + self.prompt_file_path = self._get_path( + "prompt_file_path", + os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "code_generator_prompt_plugin_only.yaml", + ), + ) + self.prompt_compression = self._get_bool("prompt_compression", False) + self.compression_prompt_path = self._get_path( + "compression_prompt_path", + os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "compression_prompt.yaml", + ), + ) + self.enable_auto_plugin_selection = self._get_bool("enable_auto_plugin_selection", False) + self.auto_plugin_selection_topk = self._get_int("auto_plugin_selection_topk", 3) + + +class CodeGeneratorPluginOnly(Role): + @inject + def __init__( + self, + config: CodeGeneratorConfig, + plugin_registry: PluginRegistry, + logger: TelemetryLogger, + llm_api: LLMApi, + ): + self.config = config + self.logger = logger + self.llm_api = llm_api + + self.role_name = self.config.role_name + + self.post_translator = PostTranslator(logger) + self.prompt_data = read_yaml(self.config.prompt_file_path) + self.plugin_pool = plugin_registry.get_list() + self.instruction_template = self.prompt_data["content"] + + if self.config.enable_auto_plugin_selection: + self.plugin_selector = PluginSelector(plugin_registry, self.llm_api) + self.plugin_selector.generate_plugin_embeddings() + logger.info("Plugin embeddings generated") + self.selected_plugin_pool = SelectedPluginPool() + + def select_plugins_for_prompt( + self, + user_query, + ) -> List[PluginEntry]: + selected_plugins = self.plugin_selector.plugin_select( + user_query, + self.config.auto_plugin_selection_topk, + ) + self.selected_plugin_pool.add_selected_plugins(selected_plugins) + self.logger.info(f"Selected plugins: {[p.name for p in selected_plugins]}") + self.logger.info(f"Selected plugin pool: {[p.name for p in self.selected_plugin_pool.get_plugins()]}") + + return self.selected_plugin_pool.get_plugins() + + def reply(self, memory: Memory, event_handler: callable) -> Post: + # extract all rounds from memory + rounds = memory.get_role_rounds( + role="CodeInterpreter", + include_failure_rounds=False, + ) + + user_query = rounds[-1].user_query + if self.config.enable_auto_plugin_selection: + self.plugin_pool = self.select_plugins_for_prompt(user_query) + + # obtain the user query from the last round + prompt, tools = compose_prompt( + system_instructions=self.instruction_template.format( + ROLE_NAME=self.role_name, + ), + rounds=rounds, + plugin_pool=self.plugin_pool, + ) + post = Post.create(message=None, send_from="CodeInterpreter", send_to="Planner") + + llm_response = self.llm_api.chat_completion(messages=prompt, tools=tools, stream=False) + if llm_response["name"] == "assistant": + post.message = llm_response["content"] + event_handler("CodeInterpreter->Planner", post.message) + return post + elif llm_response["name"] == "tool_calls": + post.add_attachment(Attachment.create(type="function", content=llm_response["content"])) + event_handler("function", llm_response["content"]) + + if self.config.enable_auto_plugin_selection: + # here the code is in json format, not really code + self.selected_plugin_pool.filter_unused_plugins(code=llm_response["content"]) + return post + else: + raise ValueError(f"Unexpected response from LLM: {llm_response}") + + +def compose_prompt(system_instructions: str, rounds: List[Round], plugin_pool: List[PluginEntry]) -> Tuple[List, List]: + functions = [plugin.format_function_calling() for plugin in plugin_pool] + prompt = [format_chat_message(role="system", message=system_instructions)] + for _round in rounds: + for post in _round.post_list: + if post.send_from == "Planner" and post.send_to == "CodeInterpreter": + user_query = post.message + prompt.append(format_chat_message(role="user", message=user_query)) + elif post.send_from == "CodeInterpreter" and post.send_to == "Planner": + assistant_message = post.message + prompt.append(format_chat_message(role="assistant", message=assistant_message)) + + return prompt, functions diff --git a/taskweaver/code_interpreter/code_generator/code_generator_json_prompt.yaml b/taskweaver/code_interpreter/code_generator/code_generator_prompt.yaml similarity index 100% rename from taskweaver/code_interpreter/code_generator/code_generator_json_prompt.yaml rename to taskweaver/code_interpreter/code_generator/code_generator_prompt.yaml diff --git a/taskweaver/code_interpreter/code_generator/code_generator_prompt_plugin_only.yaml b/taskweaver/code_interpreter/code_generator/code_generator_prompt_plugin_only.yaml new file mode 100644 index 00000000..4f69f6cc --- /dev/null +++ b/taskweaver/code_interpreter/code_generator/code_generator_prompt_plugin_only.yaml @@ -0,0 +1,4 @@ +version: 0.1 +content: |- + {ROLE_NAME} can understand the user request and leverage pre-defined tools to complete tasks. + diff --git a/taskweaver/code_interpreter/code_interpreter_plugin_only.py b/taskweaver/code_interpreter/code_interpreter_plugin_only.py new file mode 100644 index 00000000..d43c336a --- /dev/null +++ b/taskweaver/code_interpreter/code_interpreter_plugin_only.py @@ -0,0 +1,113 @@ +import json +from typing import Literal, Optional + +from injector import inject + +from taskweaver.code_interpreter.code_executor import CodeExecutor +from taskweaver.code_interpreter.code_generator import CodeGeneratorPluginOnly +from taskweaver.config.module_config import ModuleConfig +from taskweaver.logging import TelemetryLogger +from taskweaver.memory import Attachment, Memory, Post +from taskweaver.role import Role + + +class CodeInterpreterConfig(ModuleConfig): + def _configure(self): + self._set_name("code_interpreter_plugin_only") + self.use_local_uri = self._get_bool("use_local_uri", False) + self.max_retry_count = self._get_int("max_retry_count", 3) + + +def update_verification( + response: Post, + status: Literal["NONE", "INCORRECT", "CORRECT"] = "NONE", + error: str = "No verification is done.", +): + response.add_attachment(Attachment.create("verification", status)) + response.add_attachment( + Attachment.create("code_error", error), + ) + + +def update_execution( + response: Post, + status: Literal["NONE", "SUCCESS", "FAILURE"] = "NONE", + result: str = "No code is executed.", +): + response.add_attachment(Attachment.create("execution_status", status)) + response.add_attachment( + Attachment.create("execution_result", result), + ) + + +class CodeInterpreterPluginOnly(Role): + @inject + def __init__( + self, + generator: CodeGeneratorPluginOnly, + executor: CodeExecutor, + logger: TelemetryLogger, + config: CodeInterpreterConfig, + ): + self.generator = generator + self.executor = executor + self.logger = logger + self.config = config + self.retry_count = 0 + + self.logger.info("CodeInterpreter initialized successfully.") + + def reply( + self, + memory: Memory, + event_handler: callable, + prompt_log_path: Optional[str] = None, + use_back_up_engine: Optional[bool] = False, + ) -> Post: + response: Post = self.generator.reply( + memory, + event_handler, + ) + + if response.message is not None: + return response + + functions = json.loads(response.get_attachment(type="function")[0]) + if len(functions) > 0: + code = [] + for i, f in enumerate(functions): + function_name = f["name"] + function_args = json.loads(f["arguments"]) + function_call = ( + f"r{i}={function_name}(" + + ", ".join( + [ + f'{key}="{value}"' if isinstance(value, str) else f"{key}={value}" + for key, value in function_args.items() + ], + ) + + ")" + ) + code.append(function_call) + code.append(f'{", ".join([f"r{i}" for i in range(len(functions))])}') + + event_handler("code", "\n".join(code)) + exec_result = self.executor.execute_code( + exec_id=response.id, + code="\n".join(code), + ) + if exec_result.is_success: + response.message = self.executor.format_code_output( + exec_result, + with_code=True, + use_local_uri=self.config.use_local_uri, + ) + event_handler("CodeInterpreter-> Planner", response.message) + else: + response.message = self.executor.format_code_output( + exec_result, + with_code=True, + use_local_uri=self.config.use_local_uri, + ) + event_handler("CodeInterpreter-> Planner", response.message) + return response diff --git a/taskweaver/llm/__init__.py b/taskweaver/llm/__init__.py index f6a725b3..314638d3 100644 --- a/taskweaver/llm/__init__.py +++ b/taskweaver/llm/__init__.py @@ -59,6 +59,7 @@ def chat_completion( max_tokens: Optional[int] = None, top_p: Optional[float] = None, stop: Optional[List[str]] = None, + tools: Optional[List] = None, **kwargs: Any, ) -> ChatMessageType: msg: ChatMessageType = format_chat_message("assistant", "") @@ -70,10 +71,13 @@ def chat_completion( max_tokens, top_p, stop, + tools, **kwargs, ): msg["role"] = msg_chunk["role"] msg["content"] += msg_chunk["content"] + if "name" in msg_chunk: + msg["name"] = msg_chunk["name"] return msg def chat_completion_stream( diff --git a/taskweaver/llm/base.py b/taskweaver/llm/base.py index 0efc77c2..2bd6350a 100644 --- a/taskweaver/llm/base.py +++ b/taskweaver/llm/base.py @@ -70,6 +70,7 @@ def chat_completion( max_tokens: Optional[int] = None, top_p: Optional[float] = None, stop: Optional[List[str]] = None, + tools: Optional[List] = None, **kwargs: Any, ) -> Generator[ChatMessageType, None, None]: """ diff --git a/taskweaver/llm/openai.py b/taskweaver/llm/openai.py index 5e9bd959..f891ee1e 100644 --- a/taskweaver/llm/openai.py +++ b/taskweaver/llm/openai.py @@ -4,6 +4,7 @@ import openai from injector import inject from openai import AzureOpenAI, OpenAI +from openai._types import NOT_GIVEN from taskweaver.llm.util import ChatMessageType, format_chat_message @@ -134,6 +135,7 @@ def chat_completion( max_tokens: Optional[int] = None, top_p: Optional[float] = None, stop: Optional[List[str]] = None, + tools: Optional[List] = NOT_GIVEN, **kwargs: Any, ) -> Generator[ChatMessageType, None, None]: engine = self.config.model @@ -148,6 +150,13 @@ def chat_completion( try: if use_backup_engine: engine = backup_engine + + if tools is not NOT_GIVEN and tools is not None: + stream = False + tool_choice = "auto" + else: + tools = NOT_GIVEN + tool_choice = NOT_GIVEN res: Any = self.client.chat.completions.create( model=engine, messages=messages, # type: ignore @@ -162,6 +171,8 @@ def chat_completion( response_format=( {"type": "json_object"} if self.config.response_format == "json_object" else None # type: ignore ), + tool_choice=tool_choice, + tools=tools, ) if stream: role: Any = None @@ -184,7 +195,17 @@ def chat_completion( response: ChatMessageType = format_chat_message( role=oai_response.role if oai_response.role is not None else "assistant", message=oai_response.content if oai_response.content is not None else "", + name="assistant", ) + if oai_response.tool_calls is not None: + response["name"] = "tool_calls" + response["content"] = ( + "[" + + ",".join( + [t.function.model_dump_json() for t in oai_response.tool_calls], + ) + + "]" + ) yield response except openai.APITimeoutError as e: diff --git a/taskweaver/memory/plugin.py b/taskweaver/memory/plugin.py index 386094a4..80b4070e 100644 --- a/taskweaver/memory/plugin.py +++ b/taskweaver/memory/plugin.py @@ -156,6 +156,36 @@ def to_dict(self): "enabled": self.enabled, } + def format_function_calling(self) -> Dict: + def map_type(t: str) -> str: + if t.lower() == "string" or t.lower() == "str" or t.lower() == "text": + return "string" + if t.lower() == "integer" or t.lower() == "int": + return "integer" + if t.lower() == "float" or t.lower() == "double" or t.lower() == "number": + return "number" + if t.lower() == "boolean" or t.lower() == "bool": + return "boolean" + if t.lower() == "null" or t.lower() == "none": + return "null" + raise Exception(f"unknown type {t}") + + function = {"type": "function", "function": {}} + required_params = [] + function["function"]["name"] = self.name + function["function"]["description"] = self.spec.description + function["function"]["parameters"] = {"type": "object", "properties": {}} + for arg in self.spec.args: + function["function"]["parameters"]["properties"][arg.name] = { + "type": map_type(arg.type), + "description": arg.description, + } + if arg.required: + required_params.append(arg.name) + function["function"]["parameters"]["required"] = required_params + + return function + class PluginRegistry(ComponentRegistry[PluginEntry]): def __init__( diff --git a/taskweaver/session/session.py b/taskweaver/session/session.py index 82d32ba1..147a2560 100644 --- a/taskweaver/session/session.py +++ b/taskweaver/session/session.py @@ -4,7 +4,7 @@ from injector import Injector, inject -from taskweaver.code_interpreter import CodeInterpreter +from taskweaver.code_interpreter import CodeInterpreterPluginOnly from taskweaver.code_interpreter.code_executor import CodeExecutor from taskweaver.config.module_config import ModuleConfig from taskweaver.logging import TelemetryLogger @@ -60,7 +60,7 @@ def __init__( }, ) self.session_injector.binder.bind(CodeExecutor, self.code_executor) - self.code_interpreter = self.session_injector.get(CodeInterpreter) + self.code_interpreter = self.session_injector.get(CodeInterpreterPluginOnly) self.max_internal_chat_round_num = self.config.max_internal_chat_round_num self.internal_chat_num = 0 From 3ca4793a6a63a45f00cd8596e3c1a1dd99f571ac Mon Sep 17 00:00:00 2001 From: Liqun Li Date: Thu, 21 Dec 2023 15:24:02 +0800 Subject: [PATCH 02/13] remove executor_name --- taskweaver/code_interpreter/code_generator/code_generator.py | 3 --- .../code_interpreter/code_generator/code_generator_prompt.yaml | 2 +- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/taskweaver/code_interpreter/code_generator/code_generator.py b/taskweaver/code_interpreter/code_generator/code_generator.py index 178c8c07..878d47d9 100644 --- a/taskweaver/code_interpreter/code_generator/code_generator.py +++ b/taskweaver/code_interpreter/code_generator/code_generator.py @@ -20,7 +20,6 @@ class CodeGeneratorConfig(ModuleConfig): def _configure(self) -> None: self._set_name("code_generator") self.role_name = self._get_str("role_name", "ProgramApe") - self.executor_name = self._get_str("executor_name", "CodeExecutor") self.load_plugin = self._get_bool("load_plugin", True) self.load_example = self._get_bool("load_example", True) self.prompt_file_path = self._get_path( @@ -67,7 +66,6 @@ def __init__( self.llm_api = llm_api self.role_name = self.config.role_name - self.executor_name = self.config.executor_name self.post_translator = PostTranslator(logger) self.prompt_data = read_yaml(self.config.prompt_file_path) @@ -86,7 +84,6 @@ def __init__( self.instruction = self.instruction_template.format( ROLE_NAME=self.role_name, - EXECUTOR_NAME=self.executor_name, ) self.round_compressor: RoundCompressor = round_compressor diff --git a/taskweaver/code_interpreter/code_generator/code_generator_prompt.yaml b/taskweaver/code_interpreter/code_generator/code_generator_prompt.yaml index 3403cef7..cda67f5c 100644 --- a/taskweaver/code_interpreter/code_generator/code_generator_prompt.yaml +++ b/taskweaver/code_interpreter/code_generator/code_generator_prompt.yaml @@ -25,7 +25,7 @@ content: |- - {ROLE_NAME} generates the reply to the user with 'type' that must be one of the following: - "thought": the thoughts on the intermediate steps - "sample": textual descriptions including the sample code - - "python": the code that can be executed by {EXECUTOR_NAME}; comments must be added calling functions from the pre-defined plugins, including the description of the function and the parameters. + - "python": the code that can be executed by the User; comments must be added calling functions from the pre-defined plugins, including the description of the function and the parameters. - "text": the direct response in text without code - The "response" array can include multiple thought replies, but it can have only one of sample, python, or text, exclusively. - The value of "content" is a string that contains the actual content and {ROLE_NAME} must be very careful about escaping the special characters (e.g., '\', '/', and '"') in the string for JSON format. From 1a3ed4db5821ae26181f28b7ee09e9ff47e4eeda Mon Sep 17 00:00:00 2001 From: Liqun Li Date: Thu, 21 Dec 2023 15:40:19 +0800 Subject: [PATCH 03/13] fix ut --- tests/unit_tests/data/prompts/generator_prompt.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit_tests/data/prompts/generator_prompt.yaml b/tests/unit_tests/data/prompts/generator_prompt.yaml index 16818a27..a062802b 100644 --- a/tests/unit_tests/data/prompts/generator_prompt.yaml +++ b/tests/unit_tests/data/prompts/generator_prompt.yaml @@ -25,7 +25,7 @@ content: |- - {ROLE_NAME} generates the reply to the user with 'type' that must be one of the following: - "thought": the thoughts on the intermediate steps - "sample": textual descriptions including the sample code - - "python": the code that can be executed by {EXECUTOR_NAME}; comments must be added calling functions from the pre-defined plugins, including the description of the function and the parameters. + - "python": the code that can be executed by the User; comments must be added calling functions from the pre-defined plugins, including the description of the function and the parameters. - "text": the direct response in text without code - The "response" array can include multiple thought replies, but it can have only one of sample, python, or text, exclusively. - The value of "content" is a string that contains the actual content and {ROLE_NAME} must be very careful about escaping the special characters (e.g., '\', '/', and '"') in the string for JSON format. From f6ffee216efbf1cf45786895ebf10b409e3316b7 Mon Sep 17 00:00:00 2001 From: Liqun Li Date: Thu, 21 Dec 2023 16:20:58 +0800 Subject: [PATCH 04/13] fix argument issues --- .../code_generator_plugin_only.py | 17 ++++++++--- .../code_interpreter_plugin_only.py | 29 +++---------------- taskweaver/llm/__init__.py | 2 -- taskweaver/llm/openai.py | 24 +++++++-------- taskweaver/memory/attachment.py | 3 ++ 5 files changed, 32 insertions(+), 43 deletions(-) diff --git a/taskweaver/code_interpreter/code_generator/code_generator_plugin_only.py b/taskweaver/code_interpreter/code_generator/code_generator_plugin_only.py index 80457fee..be0946d5 100644 --- a/taskweaver/code_interpreter/code_generator/code_generator_plugin_only.py +++ b/taskweaver/code_interpreter/code_generator/code_generator_plugin_only.py @@ -3,12 +3,12 @@ from injector import inject -from taskweaver.code_interpreter.code_generator import CodeGeneratorConfig from taskweaver.code_interpreter.code_generator.plugin_selection import PluginSelector, SelectedPluginPool from taskweaver.config.module_config import ModuleConfig from taskweaver.llm import LLMApi, format_chat_message from taskweaver.logging import TelemetryLogger from taskweaver.memory import Attachment, Memory, Post, Round +from taskweaver.memory.attachment import AttachmentType from taskweaver.memory.plugin import PluginEntry, PluginRegistry from taskweaver.role import PostTranslator, Role from taskweaver.utils import read_yaml @@ -42,7 +42,7 @@ class CodeGeneratorPluginOnly(Role): @inject def __init__( self, - config: CodeGeneratorConfig, + config: CodeGeneratorPluginOnlyConfig, plugin_registry: PluginRegistry, logger: TelemetryLogger, llm_api: LLMApi, @@ -99,13 +99,19 @@ def reply(self, memory: Memory, event_handler: callable) -> Post: ) post = Post.create(message=None, send_from="CodeInterpreter", send_to="Planner") - llm_response = self.llm_api.chat_completion(messages=prompt, tools=tools, stream=False) + llm_response = self.llm_api.chat_completion( + messages=prompt, + tools=tools, + tool_choice="auto", + response_format=None, + stream=False, + ) if llm_response["name"] == "assistant": post.message = llm_response["content"] event_handler("CodeInterpreter->Planner", post.message) return post elif llm_response["name"] == "tool_calls": - post.add_attachment(Attachment.create(type="function", content=llm_response["content"])) + post.add_attachment(Attachment.create(type=AttachmentType.function, content=llm_response["content"])) event_handler("function", llm_response["content"]) if self.config.enable_auto_plugin_selection: @@ -115,6 +121,9 @@ def reply(self, memory: Memory, event_handler: callable) -> Post: else: raise ValueError(f"Unexpected response from LLM: {llm_response}") + def configure_verification(self, code_verification_on, plugin_only, allowed_modules): + pass + def compose_prompt(system_instructions: str, rounds: List[Round], plugin_pool: List[PluginEntry]) -> Tuple[List, List]: functions = [plugin.format_function_calling() for plugin in plugin_pool] diff --git a/taskweaver/code_interpreter/code_interpreter_plugin_only.py b/taskweaver/code_interpreter/code_interpreter_plugin_only.py index d43c336a..d440e4ab 100644 --- a/taskweaver/code_interpreter/code_interpreter_plugin_only.py +++ b/taskweaver/code_interpreter/code_interpreter_plugin_only.py @@ -1,5 +1,5 @@ import json -from typing import Literal, Optional +from typing import Optional from injector import inject @@ -7,7 +7,8 @@ from taskweaver.code_interpreter.code_generator import CodeGeneratorPluginOnly from taskweaver.config.module_config import ModuleConfig from taskweaver.logging import TelemetryLogger -from taskweaver.memory import Attachment, Memory, Post +from taskweaver.memory import Memory, Post +from taskweaver.memory.attachment import AttachmentType from taskweaver.role import Role @@ -18,28 +19,6 @@ def _configure(self): self.max_retry_count = self._get_int("max_retry_count", 3) -def update_verification( - response: Post, - status: Literal["NONE", "INCORRECT", "CORRECT"] = "NONE", - error: str = "No verification is done.", -): - response.add_attachment(Attachment.create("verification", status)) - response.add_attachment( - Attachment.create("code_error", error), - ) - - -def update_execution( - response: Post, - status: Literal["NONE", "SUCCESS", "FAILURE"] = "NONE", - result: str = "No code is executed.", -): - response.add_attachment(Attachment.create("execution_status", status)) - response.add_attachment( - Attachment.create("execution_result", result), - ) - - class CodeInterpreterPluginOnly(Role): @inject def __init__( @@ -72,7 +51,7 @@ def reply( if response.message is not None: return response - functions = json.loads(response.get_attachment(type="function")[0]) + functions = json.loads(response.get_attachment(type=AttachmentType.function)[0]) if len(functions) > 0: code = [] for i, f in enumerate(functions): diff --git a/taskweaver/llm/__init__.py b/taskweaver/llm/__init__.py index 03e47e2e..03167d51 100644 --- a/taskweaver/llm/__init__.py +++ b/taskweaver/llm/__init__.py @@ -84,7 +84,6 @@ def chat_completion( max_tokens: Optional[int] = None, top_p: Optional[float] = None, stop: Optional[List[str]] = None, - tools: Optional[List] = None, **kwargs: Any, ) -> ChatMessageType: msg: ChatMessageType = format_chat_message("assistant", "") @@ -96,7 +95,6 @@ def chat_completion( max_tokens, top_p, stop, - tools, **kwargs, ): msg["role"] = msg_chunk["role"] diff --git a/taskweaver/llm/openai.py b/taskweaver/llm/openai.py index 55380e36..60875cd1 100644 --- a/taskweaver/llm/openai.py +++ b/taskweaver/llm/openai.py @@ -4,7 +4,6 @@ import openai from injector import inject from openai import AzureOpenAI, OpenAI -from openai._types import NOT_GIVEN from taskweaver.llm.util import ChatMessageType, format_chat_message @@ -135,7 +134,6 @@ def chat_completion( max_tokens: Optional[int] = None, top_p: Optional[float] = None, stop: Optional[List[str]] = None, - tools: Optional[List] = NOT_GIVEN, **kwargs: Any, ) -> Generator[ChatMessageType, None, None]: engine = self.config.model @@ -151,12 +149,17 @@ def chat_completion( if use_backup_engine: engine = backup_engine - if tools is not NOT_GIVEN and tools is not None: - stream = False - tool_choice = "auto" + tools_kwargs = {} + if "tools" in kwargs and "tool_choice" in kwargs: + tools_kwargs["tools"] = kwargs["tools"] + tools_kwargs["tool_choice"] = kwargs["tool_choice"] + if "response_format" in kwargs: + response_format = kwargs["response_format"] + elif self.config.response_format == "json_object": + response_format = {"type": "json_object"} else: - tools = NOT_GIVEN - tool_choice = NOT_GIVEN + response_format = None + res: Any = self.client.chat.completions.create( model=engine, messages=messages, # type: ignore @@ -168,11 +171,8 @@ def chat_completion( stop=stop, stream=stream, seed=seed, - response_format=( - {"type": "json_object"} if self.config.response_format == "json_object" else None # type: ignore - ), - tool_choice=tool_choice, - tools=tools, + response_format=response_format, + **tools_kwargs, ) if stream: role: Any = None diff --git a/taskweaver/memory/attachment.py b/taskweaver/memory/attachment.py index 4c3f658a..941253a1 100644 --- a/taskweaver/memory/attachment.py +++ b/taskweaver/memory/attachment.py @@ -35,6 +35,9 @@ class AttachmentType(Enum): # CodeInterpreter - revise code revise_message = "revise_message" + # function calling + function = "function" + # Misc invalid_response = "invalid_response" From 127d48d6b57b365ecf8e90893d66472375476e51 Mon Sep 17 00:00:00 2001 From: Liqun Li Date: Thu, 21 Dec 2023 16:44:03 +0800 Subject: [PATCH 05/13] adding new role function --- .../code_generator/code_generator_plugin_only.py | 7 ++----- taskweaver/llm/openai.py | 2 +- taskweaver/llm/util.py | 2 +- 3 files changed, 4 insertions(+), 7 deletions(-) diff --git a/taskweaver/code_interpreter/code_generator/code_generator_plugin_only.py b/taskweaver/code_interpreter/code_generator/code_generator_plugin_only.py index be0946d5..8bf46ef3 100644 --- a/taskweaver/code_interpreter/code_generator/code_generator_plugin_only.py +++ b/taskweaver/code_interpreter/code_generator/code_generator_plugin_only.py @@ -106,11 +106,11 @@ def reply(self, memory: Memory, event_handler: callable) -> Post: response_format=None, stream=False, ) - if llm_response["name"] == "assistant": + if llm_response["role"] == "assistant": post.message = llm_response["content"] event_handler("CodeInterpreter->Planner", post.message) return post - elif llm_response["name"] == "tool_calls": + elif llm_response["role"] == "function": post.add_attachment(Attachment.create(type=AttachmentType.function, content=llm_response["content"])) event_handler("function", llm_response["content"]) @@ -121,9 +121,6 @@ def reply(self, memory: Memory, event_handler: callable) -> Post: else: raise ValueError(f"Unexpected response from LLM: {llm_response}") - def configure_verification(self, code_verification_on, plugin_only, allowed_modules): - pass - def compose_prompt(system_instructions: str, rounds: List[Round], plugin_pool: List[PluginEntry]) -> Tuple[List, List]: functions = [plugin.format_function_calling() for plugin in plugin_pool] diff --git a/taskweaver/llm/openai.py b/taskweaver/llm/openai.py index 60875cd1..614a6837 100644 --- a/taskweaver/llm/openai.py +++ b/taskweaver/llm/openai.py @@ -198,7 +198,7 @@ def chat_completion( name="assistant", ) if oai_response.tool_calls is not None: - response["name"] = "tool_calls" + response["role"] = "function" response["content"] = ( "[" + ",".join( diff --git a/taskweaver/llm/util.py b/taskweaver/llm/util.py index 3610bb5d..dc7a346c 100644 --- a/taskweaver/llm/util.py +++ b/taskweaver/llm/util.py @@ -1,6 +1,6 @@ from typing import Dict, Literal, Optional -ChatMessageRoleType = Literal["system", "user", "assistant"] +ChatMessageRoleType = Literal["system", "user", "assistant", "function"] ChatMessageType = Dict[Literal["role", "name", "content"], str] From e075a128ddcb7c6400217d5d3d98565574b94592 Mon Sep 17 00:00:00 2001 From: Liqun Li Date: Fri, 22 Dec 2023 11:28:20 +0800 Subject: [PATCH 06/13] before ut --- .../example1-codeinterpreter.yaml | 1 - .../example2-codeinterpreter.yaml | 1 - .../example3-codeinterpreter.yaml | 70 ------------------- project/plugins/klarna_search.yaml | 2 +- .../code_generator/code_generator.py | 19 +---- .../code_interpreter/code_interpreter.py | 10 +-- .../code_interpreter_plugin_only.py | 28 ++++---- taskweaver/llm/openai.py | 1 - taskweaver/misc/example.py | 7 +- taskweaver/session/session.py | 10 +-- 10 files changed, 25 insertions(+), 124 deletions(-) delete mode 100644 project/codeinterpreter_examples/example3-codeinterpreter.yaml diff --git a/project/codeinterpreter_examples/example1-codeinterpreter.yaml b/project/codeinterpreter_examples/example1-codeinterpreter.yaml index 97a7c717..21e4a1f6 100644 --- a/project/codeinterpreter_examples/example1-codeinterpreter.yaml +++ b/project/codeinterpreter_examples/example1-codeinterpreter.yaml @@ -1,5 +1,4 @@ enabled: True -plugin_only: False rounds: - user_query: hello state: finished diff --git a/project/codeinterpreter_examples/example2-codeinterpreter.yaml b/project/codeinterpreter_examples/example2-codeinterpreter.yaml index e17108c7..f2d77f8f 100644 --- a/project/codeinterpreter_examples/example2-codeinterpreter.yaml +++ b/project/codeinterpreter_examples/example2-codeinterpreter.yaml @@ -1,5 +1,4 @@ enabled: True -plugin_only: False rounds: - user_query: read file /abc/def.txt state: finished diff --git a/project/codeinterpreter_examples/example3-codeinterpreter.yaml b/project/codeinterpreter_examples/example3-codeinterpreter.yaml deleted file mode 100644 index 5f6a53af..00000000 --- a/project/codeinterpreter_examples/example3-codeinterpreter.yaml +++ /dev/null @@ -1,70 +0,0 @@ -enabled: True -plugin_only: True -plugins: - - name: read_csv - enabled: true - required: false - description: read the content of a csv file - - parameters: - - name: file_path - type: string - required: true - description: the path of the file - - returns: - - name: df - type: DataFrame - description: This DataFrame contains the content of the csv file. - - name: description - type: str - description: This is a string describing the csv schema. - - - name: write_csv - enabled: true - required: false - description: write the content of a DataFrame to a csv file - - parameters: - - name: df - type: DataFrame - required: true - description: the DataFrame to be written to the csv file - - name: file_path - type: string - required: true - description: the path of the file - - returns: - - name: description - type: str - description: This is a string describing success or failure of the write operation. -rounds: - - user_query: read file /abc/def.csv and write to /abc/backup.csv - state: finished - post_list: - - message: read file /abc/def.csv and write to /abc/backup.csv. You can only use the pre-defined plugins. - send_from: Planner - send_to: CodeInterpreter - attachment_list: [] - - message: I have read the file /abc/def.csv and written the content to /abc/backup.csv. - send_from: CodeInterpreter - send_to: Planner - attachment_list: - - type: thought - content: "{ROLE_NAME} will generate a code snippet to read the file /abc/def.csv and add 1 to each value in column \"value\"." - - type: thought - content: "{ROLE_NAME} is prohibited to generate any code other than variable assignments and plugin calls." - - type: python - content: |- - df = read_csv("/abc/def.csv") - status = write_csv(df, "/abc/backup.csv") - status - - type: verification - content: CORRECT - - type: code_error - content: No code error. - - type: execution_status - content: SUCCESS - - type: execution_result - content: The file /abc/def.csv has been read and the content has been written to /abc/backup.csv. diff --git a/project/plugins/klarna_search.yaml b/project/plugins/klarna_search.yaml index 238ca844..187e8c1e 100644 --- a/project/plugins/klarna_search.yaml +++ b/project/plugins/klarna_search.yaml @@ -1,5 +1,5 @@ name: klarna_search -enabled: false +enabled: true required: false description: >- Search and compare prices from thousands of online shops. Only available in the US. diff --git a/taskweaver/code_interpreter/code_generator/code_generator.py b/taskweaver/code_interpreter/code_generator/code_generator.py index 878d47d9..c11d15e7 100644 --- a/taskweaver/code_interpreter/code_generator/code_generator.py +++ b/taskweaver/code_interpreter/code_generator/code_generator.py @@ -80,7 +80,6 @@ def __init__( self.examples = None self.code_verification_on: bool = False self.allowed_modules: List[str] = [] - self.plugin_only: bool = False self.instruction = self.instruction_template.format( ROLE_NAME=self.role_name, @@ -98,10 +97,8 @@ def __init__( def configure_verification( self, code_verification_on: bool, - plugin_only: bool, allowed_modules: Optional[List[str]] = None, ): - self.plugin_only = plugin_only self.allowed_modules = allowed_modules if allowed_modules is not None else [] self.code_verification_on = code_verification_on @@ -113,23 +110,13 @@ def compose_verification_requirements( if not self.code_verification_on: return "" - if self.plugin_only: - requirements.append( - f"- {self.role_name} should only use the following plugins and" - + " Python built-in functions to complete the task: " - + ", ".join([f"{plugin.name}" for plugin in plugin_list]), - ) - requirements.append( - f"- {self.role_name} cannot define new functions or plugins.", - ) - if len(self.allowed_modules) > 0: requirements.append( f"- {self.role_name} can only import the following Python modules: " + ", ".join([f"{module}" for module in self.allowed_modules]), ) - if len(self.allowed_modules) == 0 and self.plugin_only: + if len(self.allowed_modules) == 0: requirements.append(f"- {self.role_name} cannot import any Python modules.") return "\n".join(requirements) @@ -141,7 +128,7 @@ def compose_prompt( chat_history = [format_chat_message(role="system", message=self.instruction)] if self.examples is None: - self.examples = self.load_examples(plugin_only=self.plugin_only) + self.examples = self.load_examples() for i, example in enumerate(self.examples): chat_history.extend( self.compose_conversation(example.rounds, example.plugins, add_requirements=False), @@ -366,12 +353,10 @@ def format_plugins( def load_examples( self, - plugin_only: bool, ) -> List[Conversation]: if self.config.load_example: return load_examples( folder=self.config.example_base_path, - plugin_only=plugin_only, ) return [] diff --git a/taskweaver/code_interpreter/code_interpreter.py b/taskweaver/code_interpreter/code_interpreter.py index f649c554..7153a61e 100644 --- a/taskweaver/code_interpreter/code_interpreter.py +++ b/taskweaver/code_interpreter/code_interpreter.py @@ -22,16 +22,11 @@ def _configure(self): # for verification self.code_verification_on = self._get_bool("code_verification_on", False) - self.plugin_only = self._get_bool("plugin_only", False) self.allowed_modules = self._get_list( "allowed_modules", ["pandas", "matplotlib", "numpy", "sklearn", "scipy", "seaborn", "datetime", "typing"], ) - if self.plugin_only: - self.code_verification_on = True - self.allowed_modules = [] - def update_verification( response: Post, @@ -69,7 +64,6 @@ def __init__( self.generator = generator self.generator.configure_verification( code_verification_on=self.config.code_verification_on, - plugin_only=self.config.plugin_only, allowed_modules=self.config.allowed_modules, ) @@ -130,8 +124,8 @@ def reply( code.content, [plugin.name for plugin in self.generator.get_plugin_pool()], self.config.code_verification_on, - self.config.plugin_only, - self.config.allowed_modules, + plugin_only=False, + allowed_modules=self.config.allowed_modules, ) if code_verify_errors is None: diff --git a/taskweaver/code_interpreter/code_interpreter_plugin_only.py b/taskweaver/code_interpreter/code_interpreter_plugin_only.py index d440e4ab..7ee3b379 100644 --- a/taskweaver/code_interpreter/code_interpreter_plugin_only.py +++ b/taskweaver/code_interpreter/code_interpreter_plugin_only.py @@ -33,6 +33,7 @@ def __init__( self.logger = logger self.config = config self.retry_count = 0 + self.return_id = 0 self.logger.info("CodeInterpreter initialized successfully.") @@ -58,7 +59,7 @@ def reply( function_name = f["name"] function_args = json.loads(f["arguments"]) function_call = ( - f"r{i}={function_name}(" + f"r{self.return_id + i}={function_name}(" + ", ".join( [ f'{key}="{value}"' if isinstance(value, str) else f"{key}={value}" @@ -68,25 +69,20 @@ def reply( + ")" ) code.append(function_call) - code.append(f'{", ".join([f"r{i}" for i in range(len(functions))])}') + code.append(f'{", ".join([f"r{self.return_id + i}" for i in range(len(functions))])}') + self.return_id += len(functions) event_handler("code", "\n".join(code)) exec_result = self.executor.execute_code( exec_id=response.id, code="\n".join(code), ) - if exec_result.is_success: - response.message = self.executor.format_code_output( - exec_result, - with_code=True, - use_local_uri=self.config.use_local_uri, - ) - event_handler("CodeInterpreter-> Planner", response.message) - else: - response.message = self.executor.format_code_output( - exec_result, - with_code=True, - use_local_uri=self.config.use_local_uri, - ) - event_handler("CodeInterpreter-> Planner", response.message) + + response.message = self.executor.format_code_output( + exec_result, + with_code=True, + use_local_uri=self.config.use_local_uri, + ) + event_handler("CodeInterpreter-> Planner", response.message) + return response diff --git a/taskweaver/llm/openai.py b/taskweaver/llm/openai.py index 614a6837..4e1d3e2a 100644 --- a/taskweaver/llm/openai.py +++ b/taskweaver/llm/openai.py @@ -195,7 +195,6 @@ def chat_completion( response: ChatMessageType = format_chat_message( role=oai_response.role if oai_response.role is not None else "assistant", message=oai_response.content if oai_response.content is not None else "", - name="assistant", ) if oai_response.tool_calls is not None: response["role"] = "function" diff --git a/taskweaver/misc/example.py b/taskweaver/misc/example.py index c4f76672..de70b60d 100644 --- a/taskweaver/misc/example.py +++ b/taskweaver/misc/example.py @@ -5,7 +5,7 @@ from taskweaver.memory.conversation import Conversation -def load_examples(folder: str, plugin_only: bool = False) -> List[Conversation]: +def load_examples(folder: str) -> List[Conversation]: """ Load all the examples from a folder. @@ -17,8 +17,5 @@ def load_examples(folder: str, plugin_only: bool = False) -> List[Conversation]: example_conv_pool: List[Conversation] = [] for yaml_path in example_file_list: conversation = Conversation.from_yaml(yaml_path) - if plugin_only and conversation.plugin_only: - example_conv_pool.append(conversation) - elif not plugin_only and not conversation.plugin_only: - example_conv_pool.append(conversation) + example_conv_pool.append(conversation) return example_conv_pool diff --git a/taskweaver/session/session.py b/taskweaver/session/session.py index c1b78bff..94e43928 100644 --- a/taskweaver/session/session.py +++ b/taskweaver/session/session.py @@ -4,7 +4,7 @@ from injector import Injector, inject -from taskweaver.code_interpreter import CodeInterpreterPluginOnly +from taskweaver.code_interpreter import CodeInterpreter, CodeInterpreterPluginOnly from taskweaver.code_interpreter.code_executor import CodeExecutor from taskweaver.config.module_config import ModuleConfig from taskweaver.logging import TelemetryLogger @@ -19,6 +19,7 @@ def _configure(self) -> None: self.code_interpreter_only = self._get_bool("code_interpreter_only", False) self.max_internal_chat_round_num = self._get_int("max_internal_chat_round_num", 10) + self.plugin_only_mode = self._get_bool("plugin_only_mode", False) class Session: @@ -47,8 +48,6 @@ def __init__( self.session_var: Dict[str, str] = {} - # self.plugins = get_plugin_registry() - self.planner_config = self.session_injector.get(PlannerConfig) self.planner = self.session_injector.get(Planner) self.code_executor = self.session_injector.create_object( @@ -60,7 +59,10 @@ def __init__( }, ) self.session_injector.binder.bind(CodeExecutor, self.code_executor) - self.code_interpreter = self.session_injector.get(CodeInterpreterPluginOnly) + if self.config.plugin_only_mode: + self.code_interpreter = self.session_injector.get(CodeInterpreterPluginOnly) + else: + self.code_interpreter = self.session_injector.get(CodeInterpreter) self.max_internal_chat_round_num = self.config.max_internal_chat_round_num self.internal_chat_num = 0 From 7dcc3b29d367e4694b8d12db7774365fef52926f Mon Sep 17 00:00:00 2001 From: Liqun Li Date: Fri, 22 Dec 2023 12:31:03 +0800 Subject: [PATCH 07/13] add ut --- .../code_generator_plugin_only.py | 16 +++-- .../code_interpreter_plugin_only.py | 8 +-- tests/unit_tests/test_function_calling.py | 65 +++++++++++++++++++ 3 files changed, 79 insertions(+), 10 deletions(-) create mode 100644 tests/unit_tests/test_function_calling.py diff --git a/taskweaver/code_interpreter/code_generator/code_generator_plugin_only.py b/taskweaver/code_interpreter/code_generator/code_generator_plugin_only.py index 8bf46ef3..1a9abcd6 100644 --- a/taskweaver/code_interpreter/code_generator/code_generator_plugin_only.py +++ b/taskweaver/code_interpreter/code_generator/code_generator_plugin_only.py @@ -27,6 +27,8 @@ def _configure(self) -> None: ), ) self.prompt_compression = self._get_bool("prompt_compression", False) + assert self.prompt_compression is False, "Compression is not supported for plugin only mode." + self.compression_prompt_path = self._get_path( "compression_prompt_path", os.path.join( @@ -90,7 +92,7 @@ def reply(self, memory: Memory, event_handler: callable) -> Post: self.plugin_pool = self.select_plugins_for_prompt(user_query) # obtain the user query from the last round - prompt, tools = compose_prompt( + prompt, tools = _compose_prompt( system_instructions=self.instruction_template.format( ROLE_NAME=self.role_name, ), @@ -122,16 +124,18 @@ def reply(self, memory: Memory, event_handler: callable) -> Post: raise ValueError(f"Unexpected response from LLM: {llm_response}") -def compose_prompt(system_instructions: str, rounds: List[Round], plugin_pool: List[PluginEntry]) -> Tuple[List, List]: +def _compose_prompt( + system_instructions: str, + rounds: List[Round], + plugin_pool: List[PluginEntry], +) -> Tuple[List, List]: functions = [plugin.format_function_calling() for plugin in plugin_pool] prompt = [format_chat_message(role="system", message=system_instructions)] for _round in rounds: for post in _round.post_list: if post.send_from == "Planner" and post.send_to == "CodeInterpreter": - user_query = post.message - prompt.append(format_chat_message(role="user", message=user_query)) + prompt.append(format_chat_message(role="user", message=post.message)) elif post.send_from == "CodeInterpreter" and post.send_to == "Planner": - assistant_message = post.message - prompt.append(format_chat_message(role="assistant", message=assistant_message)) + prompt.append(format_chat_message(role="assistant", message=post.message)) return prompt, functions diff --git a/taskweaver/code_interpreter/code_interpreter_plugin_only.py b/taskweaver/code_interpreter/code_interpreter_plugin_only.py index 7ee3b379..4be9a4fe 100644 --- a/taskweaver/code_interpreter/code_interpreter_plugin_only.py +++ b/taskweaver/code_interpreter/code_interpreter_plugin_only.py @@ -33,7 +33,7 @@ def __init__( self.logger = logger self.config = config self.retry_count = 0 - self.return_id = 0 + self.return_index = 0 self.logger.info("CodeInterpreter initialized successfully.") @@ -59,7 +59,7 @@ def reply( function_name = f["name"] function_args = json.loads(f["arguments"]) function_call = ( - f"r{self.return_id + i}={function_name}(" + f"r{self.return_index + i}={function_name}(" + ", ".join( [ f'{key}="{value}"' if isinstance(value, str) else f"{key}={value}" @@ -69,8 +69,8 @@ def reply( + ")" ) code.append(function_call) - code.append(f'{", ".join([f"r{self.return_id + i}" for i in range(len(functions))])}') - self.return_id += len(functions) + code.append(f'{", ".join([f"r{self.return_index + i}" for i in range(len(functions))])}') + self.return_index += len(functions) event_handler("code", "\n".join(code)) exec_result = self.executor.execute_code( diff --git a/tests/unit_tests/test_function_calling.py b/tests/unit_tests/test_function_calling.py new file mode 100644 index 00000000..a719788d --- /dev/null +++ b/tests/unit_tests/test_function_calling.py @@ -0,0 +1,65 @@ +from taskweaver.memory.plugin import PluginEntry, PluginParameter, PluginSpec + + +def test_function_formatting(): + plugin = PluginEntry( + name="test", + impl="test", + spec=PluginSpec( + name="test", + description="test", + args=[ + PluginParameter( + name="arg1", + type="string", + description="arg1", + required=True, + ), + PluginParameter( + name="arg2", + type="integer", + description="arg2", + required=False, + ), + PluginParameter( + name="arg3", + type="float", + description="arg3", + required=False, + ), + PluginParameter( + name="arg4", + type="boolean", + description="arg4", + required=False, + ), + PluginParameter( + name="arg5", + type="none", + description="arg5", + required=False, + ), + ], + ), + config={"test_key": "test_val"}, + required=False, + enabled=True, + ) + assert plugin.format_function_calling() == { + "type": "function", + "function": { + "name": "test", + "description": "test", + "parameters": { + "type": "object", + "properties": { + "arg1": {"type": "string", "description": "arg1"}, + "arg2": {"type": "integer", "description": "arg2"}, + "arg3": {"type": "number", "description": "arg3"}, + "arg4": {"type": "boolean", "description": "arg4"}, + "arg5": {"type": "null", "description": "arg5"}, + }, + "required": ["arg1"], + }, + }, + } From 138cb6ce58ed6e99070045a5f7c097afaa2584f7 Mon Sep 17 00:00:00 2001 From: Liqun Li Date: Fri, 22 Dec 2023 12:34:36 +0800 Subject: [PATCH 08/13] handle corner case --- taskweaver/code_interpreter/code_interpreter_plugin_only.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/taskweaver/code_interpreter/code_interpreter_plugin_only.py b/taskweaver/code_interpreter/code_interpreter_plugin_only.py index 4be9a4fe..658e72af 100644 --- a/taskweaver/code_interpreter/code_interpreter_plugin_only.py +++ b/taskweaver/code_interpreter/code_interpreter_plugin_only.py @@ -84,5 +84,8 @@ def reply( use_local_uri=self.config.use_local_uri, ) event_handler("CodeInterpreter-> Planner", response.message) + else: + response.message = "No code is generated because no function is selected." + event_handler("CodeInterpreter-> Planner", response.message) return response From 2bc8aea463cd7139d4162d610a5dc83e70b5ff7a Mon Sep 17 00:00:00 2001 From: Liqun Li Date: Mon, 25 Dec 2023 12:15:25 +0800 Subject: [PATCH 09/13] fix ut and add plugin-only to planner --- project/plugins/anomaly_detection.yaml | 2 +- project/plugins/ascii_render.py | 5 +- project/plugins/ascii_render.yaml | 1 + project/plugins/klarna_search.yaml | 1 + project/plugins/paper_summary.yaml | 2 +- project/plugins/sql_pull_data.yaml | 2 +- project/plugins/tell_joke.py | 9 +- project/plugins/tell_joke.yaml | 7 +- .../code_generator_plugin_only.py | 2 +- .../code_generator/plugin_selection.py | 16 ++- taskweaver/memory/plugin.py | 4 + taskweaver/planner/planner.py | 10 +- taskweaver/session/session.py | 10 +- .../data/plugins/klarna_search.yaml | 1 + .../data/prompts/generator_plugin_only.yaml | 4 + tests/unit_tests/test_code_generator.py | 126 ++++++++---------- tests/unit_tests/test_function_calling.py | 1 + 17 files changed, 108 insertions(+), 95 deletions(-) create mode 100644 tests/unit_tests/data/prompts/generator_plugin_only.yaml diff --git a/project/plugins/anomaly_detection.yaml b/project/plugins/anomaly_detection.yaml index 2f8a3393..29c68cdc 100644 --- a/project/plugins/anomaly_detection.yaml +++ b/project/plugins/anomaly_detection.yaml @@ -1,5 +1,5 @@ name: anomaly_detection -enabled: false +enabled: true required: false description: >- anomaly_detection function identifies anomalies from an input DataFrame of diff --git a/project/plugins/ascii_render.py b/project/plugins/ascii_render.py index 23a8a752..dfd52c78 100644 --- a/project/plugins/ascii_render.py +++ b/project/plugins/ascii_render.py @@ -4,7 +4,10 @@ @register_plugin class AsciiRenderPlugin(Plugin): def __call__(self, text: str): - import pyfiglet + try: + import pyfiglet + except ImportError: + raise ImportError("Please install pyfiglet first.") ASCII_art_1 = pyfiglet.figlet_format(text, font="isometric1") result = ASCII_art_1 diff --git a/project/plugins/ascii_render.yaml b/project/plugins/ascii_render.yaml index 76f1dd9e..2e940c98 100644 --- a/project/plugins/ascii_render.yaml +++ b/project/plugins/ascii_render.yaml @@ -1,6 +1,7 @@ name: ascii_render enabled: true required: true +plugin_only: true description: >- This plugin renders the input text into ASCII art form. The input should be a string and the output is also a string in ASCII art. diff --git a/project/plugins/klarna_search.yaml b/project/plugins/klarna_search.yaml index 187e8c1e..ef0d1c9b 100644 --- a/project/plugins/klarna_search.yaml +++ b/project/plugins/klarna_search.yaml @@ -1,6 +1,7 @@ name: klarna_search enabled: true required: false +plugin_only: true description: >- Search and compare prices from thousands of online shops. Only available in the US. This plugin only takes user requests when searching for merchandise. diff --git a/project/plugins/paper_summary.yaml b/project/plugins/paper_summary.yaml index c73b1547..4a07306a 100644 --- a/project/plugins/paper_summary.yaml +++ b/project/plugins/paper_summary.yaml @@ -1,5 +1,5 @@ name: paper_summary -enabled: false +enabled: true required: false description: >- summarize_paper function iteratively summarizes a given paper page by page, diff --git a/project/plugins/sql_pull_data.yaml b/project/plugins/sql_pull_data.yaml index e643ba74..e2b343a3 100644 --- a/project/plugins/sql_pull_data.yaml +++ b/project/plugins/sql_pull_data.yaml @@ -1,5 +1,5 @@ name: sql_pull_data -enabled: false +enabled: true required: false description: >- Pull data from a SQL database. diff --git a/project/plugins/tell_joke.py b/project/plugins/tell_joke.py index d8b8866d..5207a80a 100644 --- a/project/plugins/tell_joke.py +++ b/project/plugins/tell_joke.py @@ -3,6 +3,11 @@ @register_plugin class TellJoke(Plugin): - def __call__(self, context: str): + def __call__(self, lan: str = "en"): + try: + import pyjokes + except ImportError: + raise ImportError("Please install pyjokes first.") + # Define the API endpoint and parameters - return " Why don't cats play poker in the jungle? Too many cheetahs!" + return pyjokes.get_joke(language=lan, category="neutral") diff --git a/project/plugins/tell_joke.yaml b/project/plugins/tell_joke.yaml index e563171e..b6aec949 100644 --- a/project/plugins/tell_joke.yaml +++ b/project/plugins/tell_joke.yaml @@ -1,14 +1,15 @@ name: tell_joke enabled: true required: false +plugin_only: true description: >- Call this plugin to tell a joke. parameters: - - name: context + - name: lan type: str - required: true - description: the context of the joke. + required: false + description: the language of the joke. Default is English. It can be en, de, es, it, gl, eu. returns: diff --git a/taskweaver/code_interpreter/code_generator/code_generator_plugin_only.py b/taskweaver/code_interpreter/code_generator/code_generator_plugin_only.py index 1a9abcd6..533627c1 100644 --- a/taskweaver/code_interpreter/code_generator/code_generator_plugin_only.py +++ b/taskweaver/code_interpreter/code_generator/code_generator_plugin_only.py @@ -57,7 +57,7 @@ def __init__( self.post_translator = PostTranslator(logger) self.prompt_data = read_yaml(self.config.prompt_file_path) - self.plugin_pool = plugin_registry.get_list() + self.plugin_pool = [p for p in plugin_registry.get_list() if p.plugin_only is True] self.instruction_template = self.prompt_data["content"] if self.config.enable_auto_plugin_selection: diff --git a/taskweaver/code_interpreter/code_generator/plugin_selection.py b/taskweaver/code_interpreter/code_generator/plugin_selection.py index 4580c51f..0e52a7f4 100644 --- a/taskweaver/code_interpreter/code_generator/plugin_selection.py +++ b/taskweaver/code_interpreter/code_generator/plugin_selection.py @@ -61,17 +61,21 @@ def __init__( self, plugin_registry: PluginRegistry, llm_api: LLMApi, + plugin_only: bool = False, ): - self.plugin_registry = plugin_registry + if plugin_only: + self.available_plugins = [p for p in plugin_registry.get_list() if p.plugin_only is True] + else: + self.available_plugins = plugin_registry.get_list() self.llm_api = llm_api self.plugin_embedding_dict: Dict[str, List[float]] = {} def generate_plugin_embeddings(self): plugin_intro_text_list: List[str] = [] - for p in self.plugin_registry.get_list(): + for p in self.available_plugins: plugin_intro_text_list.append(p.name + ": " + p.spec.description) plugin_embeddings = self.llm_api.get_embedding_list(plugin_intro_text_list) - for i, p in enumerate(self.plugin_registry.get_list()): + for i, p in enumerate(self.available_plugins): self.plugin_embedding_dict[p.name] = plugin_embeddings[i] def plugin_select(self, user_query: str, top_k: int = 5) -> List[PluginEntry]: @@ -79,10 +83,10 @@ def plugin_select(self, user_query: str, top_k: int = 5) -> List[PluginEntry]: similarities = [] - if top_k >= len(self.plugin_registry.get_list()): - return self.plugin_registry.get_list() + if top_k >= len(self.available_plugins): + return self.available_plugins - for p in self.plugin_registry.get_list(): + for p in self.available_plugins: similarity = cosine_similarity( user_query_embedding.reshape( 1, diff --git a/taskweaver/memory/plugin.py b/taskweaver/memory/plugin.py index 80b4070e..97b44ae5 100644 --- a/taskweaver/memory/plugin.py +++ b/taskweaver/memory/plugin.py @@ -114,6 +114,7 @@ def format_return_val(val: PluginParameter) -> str: @dataclass class PluginEntry: name: str + plugin_only: bool impl: str spec: PluginSpec config: Dict[str, Any] @@ -140,6 +141,7 @@ def from_yaml_content(content: Dict) -> Optional["PluginEntry"]: config=content.get("configurations", {}), required=content.get("required", False), enabled=content.get("enabled", True), + plugin_only=content.get("plugin_only", False), ) return None @@ -157,6 +159,8 @@ def to_dict(self): } def format_function_calling(self) -> Dict: + assert self.plugin_only is True, "Only `plugin_only` plugins can be called in this way." + def map_type(t: str) -> str: if t.lower() == "string" or t.lower() == "str" or t.lower() == "text": return "string" diff --git a/taskweaver/planner/planner.py b/taskweaver/planner/planner.py index f09c26a4..80353ce2 100644 --- a/taskweaver/planner/planner.py +++ b/taskweaver/planner/planner.py @@ -68,11 +68,15 @@ def __init__( llm_api: LLMApi, plugin_registry: PluginRegistry, round_compressor: Optional[RoundCompressor] = None, + plugin_only: bool = False, ): self.config = config self.logger = logger self.llm_api = llm_api - self.plugin_registry = plugin_registry + if plugin_only: + self.available_plugins = [p for p in plugin_registry.get_list() if p.plugin_only is True] + else: + self.available_plugins = plugin_registry.get_list() self.planner_post_translator = PostTranslator(logger) @@ -80,12 +84,12 @@ def __init__( if self.config.use_example: self.examples = self.get_examples() - if len(self.plugin_registry.get_list()) == 0: + if len(self.available_plugins) == 0: self.logger.warning("No plugin is loaded for Planner.") self.plugin_description = "No plugin functions loaded." else: self.plugin_description = "\t" + "\n\t".join( - [f"- {plugin.name}: " + f"{plugin.spec.description}" for plugin in self.plugin_registry.get_list()], + [f"- {plugin.name}: " + f"{plugin.spec.description}" for plugin in self.available_plugins], ) self.instruction_template = self.prompt_data["instruction_template"] self.code_interpreter_introduction = self.prompt_data["code_interpreter_introduction"].format( diff --git a/taskweaver/session/session.py b/taskweaver/session/session.py index 94e43928..d8c0f3ca 100644 --- a/taskweaver/session/session.py +++ b/taskweaver/session/session.py @@ -9,7 +9,7 @@ from taskweaver.config.module_config import ModuleConfig from taskweaver.logging import TelemetryLogger from taskweaver.memory import Memory, Post, Round -from taskweaver.planner.planner import Planner, PlannerConfig +from taskweaver.planner.planner import Planner from taskweaver.workspace.workspace import Workspace @@ -48,8 +48,12 @@ def __init__( self.session_var: Dict[str, str] = {} - self.planner_config = self.session_injector.get(PlannerConfig) - self.planner = self.session_injector.get(Planner) + self.planner = self.session_injector.create_object( + Planner, + { + "plugin_only": self.config.plugin_only_mode, + }, + ) self.code_executor = self.session_injector.create_object( CodeExecutor, { diff --git a/tests/unit_tests/data/plugins/klarna_search.yaml b/tests/unit_tests/data/plugins/klarna_search.yaml index 18907092..c7b86344 100644 --- a/tests/unit_tests/data/plugins/klarna_search.yaml +++ b/tests/unit_tests/data/plugins/klarna_search.yaml @@ -1,6 +1,7 @@ name: klarna_search enabled: true required: false +plugin_only: true description: >- Search and compare prices from thousands of online shops. Only available in the US. diff --git a/tests/unit_tests/data/prompts/generator_plugin_only.yaml b/tests/unit_tests/data/prompts/generator_plugin_only.yaml new file mode 100644 index 00000000..4f69f6cc --- /dev/null +++ b/tests/unit_tests/data/prompts/generator_plugin_only.yaml @@ -0,0 +1,4 @@ +version: 0.1 +content: |- + {ROLE_NAME} can understand the user request and leverage pre-defined tools to complete tasks. + diff --git a/tests/unit_tests/test_code_generator.py b/tests/unit_tests/test_code_generator.py index 9ed57806..c88bb931 100644 --- a/tests/unit_tests/test_code_generator.py +++ b/tests/unit_tests/test_code_generator.py @@ -2,8 +2,10 @@ from injector import Injector +from taskweaver.code_interpreter.code_generator.code_generator_plugin_only import _compose_prompt from taskweaver.config.config_mgt import AppConfigSource from taskweaver.logging import LoggingModule +from taskweaver.memory.attachment import AttachmentType from taskweaver.memory.plugin import PluginModule @@ -48,21 +50,21 @@ def test_compose_prompt(): ) post2.add_attachment( Attachment.create( - "thought", + AttachmentType.thought, "{ROLE_NAME} sees the user wants generate a DataFrame.", ), ) post2.add_attachment( Attachment.create( - "thought", + AttachmentType.thought, "{ROLE_NAME} sees all required Python libs have been imported, so will not generate import codes.", ), ) - post2.add_attachment(Attachment.create("python", code1)) - post2.add_attachment(Attachment.create("execution_status", "SUCCESS")) + post2.add_attachment(Attachment.create(AttachmentType.python, code1)) + post2.add_attachment(Attachment.create(AttachmentType.execution_status, "SUCCESS")) post2.add_attachment( Attachment.create( - "execution_result", + AttachmentType.execution_result, "A dataframe `df` with 10 rows and 2 columns: 'DATE' and 'VALUE' has been generated.", ), ) @@ -86,20 +88,20 @@ def test_compose_prompt(): ) post4.add_attachment( Attachment.create( - "thought", + AttachmentType.thought, "{ROLE_NAME} understands the user wants to find the data range for the DataFrame.", ), ) post4.add_attachment( Attachment.create( - "thought", + AttachmentType.thought, "{ROLE_NAME} will generate code to calculate the data range of the 'VALUE' column since it is the " "only numeric column.", ), ) post4.add_attachment( Attachment.create( - "python", + AttachmentType.python, ( "min_value = df['VALUE'].min()\n" "max_value = df['VALUE'].max()\n" @@ -112,10 +114,10 @@ def test_compose_prompt(): ), ), ) - post4.add_attachment(Attachment.create("execution_status", "SUCCESS")) + post4.add_attachment(Attachment.create(AttachmentType.execution_status, "SUCCESS")) post4.add_attachment( Attachment.create( - "execution_result", + AttachmentType.execution_result, "The minimum value in the 'VALUE' column is 0.05;The " "maximum value in the 'VALUE' column is 0.99;The " "data range for the 'VALUE' column is 0.94", @@ -246,21 +248,21 @@ def test_compose_prompt_with_plugin(): ) post2.add_attachment( Attachment.create( - "thought", + AttachmentType.thought, "{ROLE_NAME} sees the user wants generate a DataFrame.", ), ) post2.add_attachment( Attachment.create( - "thought", + AttachmentType.thought, "{ROLE_NAME} sees all required Python libs have been imported, so will not generate import codes.", ), ) - post2.add_attachment(Attachment.create("python", code1)) - post2.add_attachment(Attachment.create("execution_status", "SUCCESS")) + post2.add_attachment(Attachment.create(AttachmentType.python, code1)) + post2.add_attachment(Attachment.create(AttachmentType.execution_status, "SUCCESS")) post2.add_attachment( Attachment.create( - "execution_result", + AttachmentType.execution_result, "A dataframe `df` with 10 rows and 2 columns: 'DATE' and 'VALUE' has been generated.", ), ) @@ -292,95 +294,79 @@ def test_compose_prompt_with_plugin_only(): config={ "app_dir": os.path.dirname(os.path.abspath(__file__)), "llm.api_key": "test_key", # pragma: allowlist secret - "code_generator.prompt_compression": True, + "code_generator.prompt_compression": False, "code_generator.prompt_file_path": os.path.join( os.path.dirname(os.path.abspath(__file__)), - "data/prompts/generator_prompt.yaml", + "data/prompts/generator_plugin_only.yaml", ), "plugin.base_path": os.path.join( os.path.dirname(os.path.abspath(__file__)), "data/plugins", ), - "code_generator.example_base_path": os.path.join( - os.path.dirname(os.path.abspath(__file__)), - "data/examples/codeinterpreter_examples", - ), }, ) app_injector.binder.bind(AppConfigSource, to=app_config) - from taskweaver.code_interpreter.code_generator import CodeGenerator + from taskweaver.code_interpreter.code_generator import CodeGeneratorPluginOnly from taskweaver.memory import Attachment, Memory, Post, Round - code_generator = app_injector.get(CodeGenerator) - - code_generator.configure_verification( - code_verification_on=True, - plugin_only=True, - allowed_modules=[], - ) - code_generator.configure_verification(code_verification_on=True, plugin_only=True) + code_generator = app_injector.get(CodeGeneratorPluginOnly) - code1 = ( - "df = pd.DataFrame(np.random.rand(10, 2), columns=['DATE', 'VALUE'])\n" - 'descriptions = [("sample_code_description", "Sample code has been generated to get a dataframe `df` \n' - "with 10 rows and 2 columns: 'DATE' and 'VALUE'\")]" - ) + code1 = "r0 = klarna_search('iphone')\n" "r0" post1 = Post.create( - message="create a dataframe", + message="find iphones on sale", send_from="Planner", send_to="CodeInterpreter", attachment_list=[], ) post2 = Post.create( - message="A dataframe `df` with 10 rows and 2 columns: 'DATE' and 'VALUE' has been generated.", + message="The iphone 15 pro is on sale.", send_from="CodeInterpreter", send_to="Planner", attachment_list=[], ) post2.add_attachment( Attachment.create( - "thought", - "{ROLE_NAME} sees the user wants generate a DataFrame.", + AttachmentType.thought, + "{ROLE_NAME} sees the user wants to find iphones on sale.", ), ) post2.add_attachment( Attachment.create( - "thought", - "{ROLE_NAME} sees all required Python libs have been imported, so will not generate import codes.", + AttachmentType.thought, + "{ROLE_NAME} can use the `klarna_search` function to find iphones on sale.", ), ) - post2.add_attachment(Attachment.create("python", code1)) - post2.add_attachment(Attachment.create("execution_status", "SUCCESS")) + post2.add_attachment(Attachment.create(AttachmentType.python, code1)) + post2.add_attachment(Attachment.create(AttachmentType.execution_status, "SUCCESS")) post2.add_attachment( Attachment.create( - "execution_result", + AttachmentType.execution_result, "A dataframe `df` with 10 rows and 2 columns: 'DATE' and 'VALUE' has been generated.", ), ) - round1 = Round.create(user_query="hello", id="round-1") + round1 = Round.create(user_query="find iphones on sale", id="round-1") round1.add_post(post1) round1.add_post(post2) memory = Memory(session_id="session-1") memory.conversation.add_round(round1) - messages = code_generator.compose_prompt( + messages, functions = _compose_prompt( + system_instructions=code_generator.instruction_template.format( + ROLE_NAME=code_generator.role_name, + ), rounds=memory.conversation.rounds, - plugins=code_generator.get_plugin_pool(), + plugin_pool=code_generator.plugin_pool, ) - assert "read_csv" in messages[1]["content"] - assert "write_csv" in messages[1]["content"] - assert "This is the feedback" in messages[3]["content"] - assert "Execution" in messages[3]["content"] - assert "Verification" in messages[3]["content"] - - assert "sql_pull_data" in messages[4]["content"] - assert "anomaly_detection" in messages[4]["content"] - assert "klarna_search" in messages[4]["content"] - assert "paper_summary" in messages[4]["content"] + assert len(functions) == 1 + assert functions[0]["function"]["name"] == "klarna_search" + assert messages[1]["role"] == "user" + assert messages[1]["content"] == "find iphones on sale" + assert messages[2]["role"] == "assistant" + assert messages[2]["content"] == "The iphone 15 pro is on sale." def test_compose_prompt_with_not_plugin_only(): @@ -432,21 +418,21 @@ def test_compose_prompt_with_not_plugin_only(): ) post2.add_attachment( Attachment.create( - "thought", + AttachmentType.thought, "{ROLE_NAME} sees the user wants generate a DataFrame.", ), ) post2.add_attachment( Attachment.create( - "thought", + AttachmentType.thought, "{ROLE_NAME} sees all required Python libs have been imported, so will not generate import codes.", ), ) - post2.add_attachment(Attachment.create("python", code1)) - post2.add_attachment(Attachment.create("execution_status", "SUCCESS")) + post2.add_attachment(Attachment.create(AttachmentType.python, code1)) + post2.add_attachment(Attachment.create(AttachmentType.execution_status, "SUCCESS")) post2.add_attachment( Attachment.create( - "execution_result", + AttachmentType.execution_result, "A dataframe `df` with 10 rows and 2 columns: 'DATE' and 'VALUE' has been generated.", ), ) @@ -463,17 +449,11 @@ def test_compose_prompt_with_not_plugin_only(): plugins=code_generator.get_plugin_pool(), ) - assert "read_csv" not in messages[1]["content"] - assert "write_csv" not in messages[1]["content"] - assert "sql_pull_data" not in messages[1]["content"] - assert "anomaly_detection" not in messages[1]["content"] - assert "klarna_search" not in messages[1]["content"] - assert "paper_summary" not in messages[1]["content"] - - assert "sql_pull_data" in messages[13]["content"] - assert "anomaly_detection" in messages[13]["content"] - assert "klarna_search" in messages[13]["content"] - assert "paper_summary" in messages[13]["content"] + assert len(code_generator.plugin_pool) == 4 + assert "anomaly_detection" in messages[16]["content"] + assert "klarna_search" in messages[16]["content"] + assert "paper_summary" in messages[16]["content"] + assert "sql_pull_data" in messages[16]["content"] def test_code_correction_prompt(): diff --git a/tests/unit_tests/test_function_calling.py b/tests/unit_tests/test_function_calling.py index a719788d..a8701f17 100644 --- a/tests/unit_tests/test_function_calling.py +++ b/tests/unit_tests/test_function_calling.py @@ -44,6 +44,7 @@ def test_function_formatting(): config={"test_key": "test_val"}, required=False, enabled=True, + plugin_only=True, ) assert plugin.format_function_calling() == { "type": "function", From 388cb0db36c3c97ab670473e75761567c872cb1b Mon Sep 17 00:00:00 2001 From: Liqun Li Date: Mon, 25 Dec 2023 12:25:18 +0800 Subject: [PATCH 10/13] remove args from doc --- taskweaver/misc/example.py | 1 - 1 file changed, 1 deletion(-) diff --git a/taskweaver/misc/example.py b/taskweaver/misc/example.py index de70b60d..b006c5af 100644 --- a/taskweaver/misc/example.py +++ b/taskweaver/misc/example.py @@ -11,7 +11,6 @@ def load_examples(folder: str) -> List[Conversation]: Args: folder: the folder path. - plugin_only: whether to load only the plugin examples. """ example_file_list: List[str] = glob.glob(path.join(folder, "*.yaml")) example_conv_pool: List[Conversation] = [] From be81d51d03543a66b10f1ec9708e27a70f22c81c Mon Sep 17 00:00:00 2001 From: Liqun Li Date: Mon, 25 Dec 2023 12:33:45 +0800 Subject: [PATCH 11/13] remove function define args --- taskweaver/llm/base.py | 1 - 1 file changed, 1 deletion(-) diff --git a/taskweaver/llm/base.py b/taskweaver/llm/base.py index 2d5becb5..24a9781c 100644 --- a/taskweaver/llm/base.py +++ b/taskweaver/llm/base.py @@ -72,7 +72,6 @@ def chat_completion( max_tokens: Optional[int] = None, top_p: Optional[float] = None, stop: Optional[List[str]] = None, - tools: Optional[List] = None, **kwargs: Any, ) -> Generator[ChatMessageType, None, None]: """ From a5a618a075578e986d6659f3279e874d10d76fee Mon Sep 17 00:00:00 2001 From: Liqun Li Date: Mon, 25 Dec 2023 17:41:49 +0800 Subject: [PATCH 12/13] optimize prompt for installing packages --- project/plugins/anomaly_detection.yaml | 1 + project/plugins/ascii_render.yaml | 1 + project/plugins/klarna_search.yaml | 1 + project/plugins/paper_summary.yaml | 1 + project/plugins/sql_pull_data.yaml | 1 + project/plugins/tell_joke.yaml | 2 +- .../code_interpreter/code_generator/code_generator_prompt.yaml | 1 + taskweaver/planner/planner_prompt.yaml | 1 + 8 files changed, 8 insertions(+), 1 deletion(-) diff --git a/project/plugins/anomaly_detection.yaml b/project/plugins/anomaly_detection.yaml index 29c68cdc..378da913 100644 --- a/project/plugins/anomaly_detection.yaml +++ b/project/plugins/anomaly_detection.yaml @@ -4,6 +4,7 @@ required: false description: >- anomaly_detection function identifies anomalies from an input DataFrame of time series. It will add a new column "Is_Anomaly", where each entry will be marked with "True" if the value is an anomaly or "False" otherwise. + For example, result_df = anomaly_detection(df, "datetime", "value"). parameters: - name: df diff --git a/project/plugins/ascii_render.yaml b/project/plugins/ascii_render.yaml index 2e940c98..395c3dd7 100644 --- a/project/plugins/ascii_render.yaml +++ b/project/plugins/ascii_render.yaml @@ -4,6 +4,7 @@ required: true plugin_only: true description: >- This plugin renders the input text into ASCII art form. The input should be a string and the output is also a string in ASCII art. + For example, result = ascii_render("Hello World!"). parameters: - name: text diff --git a/project/plugins/klarna_search.yaml b/project/plugins/klarna_search.yaml index ef0d1c9b..d491271c 100644 --- a/project/plugins/klarna_search.yaml +++ b/project/plugins/klarna_search.yaml @@ -6,6 +6,7 @@ description: >- Search and compare prices from thousands of online shops. Only available in the US. This plugin only takes user requests when searching for merchandise. If not clear, confirm with the user if they want to search for merchandise from Klarna. + For example, result = klarna_search("laptop", 10, 1000, 2000). parameters: - name: query diff --git a/project/plugins/paper_summary.yaml b/project/plugins/paper_summary.yaml index 4a07306a..05f81d5a 100644 --- a/project/plugins/paper_summary.yaml +++ b/project/plugins/paper_summary.yaml @@ -5,6 +5,7 @@ description: >- summarize_paper function iteratively summarizes a given paper page by page, highlighting the key points, including the problem, main idea, contributions, experiments, results, and conclusions. + For example, result = summarize_paper("paper.pdf"). parameters: - name: paper_file_path diff --git a/project/plugins/sql_pull_data.yaml b/project/plugins/sql_pull_data.yaml index e2b343a3..d68705ef 100644 --- a/project/plugins/sql_pull_data.yaml +++ b/project/plugins/sql_pull_data.yaml @@ -6,6 +6,7 @@ description: >- This plugin takes user requests when obtaining data from database is explicitly mentioned. Otherwise, confirm with the user if they want to pull data from this database. The data from this database can only used for anomaly detection. + For example, df = sql_pull_data("pull data from time_series table"). parameters: - name: query diff --git a/project/plugins/tell_joke.yaml b/project/plugins/tell_joke.yaml index b6aec949..bfccd4c2 100644 --- a/project/plugins/tell_joke.yaml +++ b/project/plugins/tell_joke.yaml @@ -3,7 +3,7 @@ enabled: true required: false plugin_only: true description: >- - Call this plugin to tell a joke. + Call this plugin to tell a joke. For example, result = tell_joke("en"). parameters: - name: lan diff --git a/taskweaver/code_interpreter/code_generator/code_generator_prompt.yaml b/taskweaver/code_interpreter/code_generator/code_generator_prompt.yaml index cda67f5c..80d33ee1 100644 --- a/taskweaver/code_interpreter/code_generator/code_generator_prompt.yaml +++ b/taskweaver/code_interpreter/code_generator/code_generator_prompt.yaml @@ -58,4 +58,5 @@ requirements: |- - {ROLE_NAME} should not refer to any information from failed rounds, rounds that have not been executed, or previous Conversations. - {ROLE_NAME} put all the result variables in the last line of the code. - {ROLE_NAME} must not import the plugins and otherwise the code will be failed to execute. + - {ROLE_NAME} must try to directly import required modules without installing them, and only install the modules if the execution fails. {CODE_GENERATION_REQUIREMENTS} diff --git a/taskweaver/planner/planner_prompt.yaml b/taskweaver/planner/planner_prompt.yaml index f3a9d7f0..a3bfbd37 100644 --- a/taskweaver/planner/planner_prompt.yaml +++ b/taskweaver/planner/planner_prompt.yaml @@ -31,6 +31,7 @@ instruction_template: |- - If User has made any changes to the environment, Planner should inform CodeInterpreter accordingly. - Planner can ignore the permission or data access issues because CodeInterpreter can handle this kind of problem. - Planner must include 2 parts: description of the User's request and the current step that the Planner is executing. + - Planner must not ask CodeInterpreter to install any packages unless the User explicitly requests to do so. ## Planner's response format - Planner must strictly format the response into the following JSON object: From eef2cb83afc1777fc289fd1fc44fbe29e5924925 Mon Sep 17 00:00:00 2001 From: Liqun Li Date: Mon, 25 Dec 2023 17:49:30 +0800 Subject: [PATCH 13/13] correct examples --- project/plugins/anomaly_detection.yaml | 2 +- project/plugins/klarna_search.yaml | 2 +- project/plugins/paper_summary.yaml | 2 +- project/plugins/sql_pull_data.yaml | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/project/plugins/anomaly_detection.yaml b/project/plugins/anomaly_detection.yaml index 378da913..06fa8a56 100644 --- a/project/plugins/anomaly_detection.yaml +++ b/project/plugins/anomaly_detection.yaml @@ -4,7 +4,7 @@ required: false description: >- anomaly_detection function identifies anomalies from an input DataFrame of time series. It will add a new column "Is_Anomaly", where each entry will be marked with "True" if the value is an anomaly or "False" otherwise. - For example, result_df = anomaly_detection(df, "datetime", "value"). + For example, result_df, description = anomaly_detection(df, "datetime", "value"). parameters: - name: df diff --git a/project/plugins/klarna_search.yaml b/project/plugins/klarna_search.yaml index d491271c..9aef2920 100644 --- a/project/plugins/klarna_search.yaml +++ b/project/plugins/klarna_search.yaml @@ -6,7 +6,7 @@ description: >- Search and compare prices from thousands of online shops. Only available in the US. This plugin only takes user requests when searching for merchandise. If not clear, confirm with the user if they want to search for merchandise from Klarna. - For example, result = klarna_search("laptop", 10, 1000, 2000). + For example, result, description = klarna_search("laptop", 10, 1000, 2000). parameters: - name: query diff --git a/project/plugins/paper_summary.yaml b/project/plugins/paper_summary.yaml index 05f81d5a..523c2040 100644 --- a/project/plugins/paper_summary.yaml +++ b/project/plugins/paper_summary.yaml @@ -5,7 +5,7 @@ description: >- summarize_paper function iteratively summarizes a given paper page by page, highlighting the key points, including the problem, main idea, contributions, experiments, results, and conclusions. - For example, result = summarize_paper("paper.pdf"). + For example, result, description = summarize_paper("paper.pdf"). parameters: - name: paper_file_path diff --git a/project/plugins/sql_pull_data.yaml b/project/plugins/sql_pull_data.yaml index d68705ef..d756f958 100644 --- a/project/plugins/sql_pull_data.yaml +++ b/project/plugins/sql_pull_data.yaml @@ -6,7 +6,7 @@ description: >- This plugin takes user requests when obtaining data from database is explicitly mentioned. Otherwise, confirm with the user if they want to pull data from this database. The data from this database can only used for anomaly detection. - For example, df = sql_pull_data("pull data from time_series table"). + For example, df, description = sql_pull_data("pull data from time_series table"). parameters: - name: query