diff --git a/project/codeinterpreter_examples/example1-codeinterpreter.yaml b/project/codeinterpreter_examples/example1-codeinterpreter.yaml index 97a7c717..21e4a1f6 100644 --- a/project/codeinterpreter_examples/example1-codeinterpreter.yaml +++ b/project/codeinterpreter_examples/example1-codeinterpreter.yaml @@ -1,5 +1,4 @@ enabled: True -plugin_only: False rounds: - user_query: hello state: finished diff --git a/project/codeinterpreter_examples/example2-codeinterpreter.yaml b/project/codeinterpreter_examples/example2-codeinterpreter.yaml index e17108c7..f2d77f8f 100644 --- a/project/codeinterpreter_examples/example2-codeinterpreter.yaml +++ b/project/codeinterpreter_examples/example2-codeinterpreter.yaml @@ -1,5 +1,4 @@ enabled: True -plugin_only: False rounds: - user_query: read file /abc/def.txt state: finished diff --git a/project/codeinterpreter_examples/example3-codeinterpreter.yaml b/project/codeinterpreter_examples/example3-codeinterpreter.yaml deleted file mode 100644 index 5f6a53af..00000000 --- a/project/codeinterpreter_examples/example3-codeinterpreter.yaml +++ /dev/null @@ -1,70 +0,0 @@ -enabled: True -plugin_only: True -plugins: - - name: read_csv - enabled: true - required: false - description: read the content of a csv file - - parameters: - - name: file_path - type: string - required: true - description: the path of the file - - returns: - - name: df - type: DataFrame - description: This DataFrame contains the content of the csv file. - - name: description - type: str - description: This is a string describing the csv schema. - - - name: write_csv - enabled: true - required: false - description: write the content of a DataFrame to a csv file - - parameters: - - name: df - type: DataFrame - required: true - description: the DataFrame to be written to the csv file - - name: file_path - type: string - required: true - description: the path of the file - - returns: - - name: description - type: str - description: This is a string describing success or failure of the write operation. -rounds: - - user_query: read file /abc/def.csv and write to /abc/backup.csv - state: finished - post_list: - - message: read file /abc/def.csv and write to /abc/backup.csv. You can only use the pre-defined plugins. - send_from: Planner - send_to: CodeInterpreter - attachment_list: [] - - message: I have read the file /abc/def.csv and written the content to /abc/backup.csv. - send_from: CodeInterpreter - send_to: Planner - attachment_list: - - type: thought - content: "{ROLE_NAME} will generate a code snippet to read the file /abc/def.csv and add 1 to each value in column \"value\"." - - type: thought - content: "{ROLE_NAME} is prohibited to generate any code other than variable assignments and plugin calls." - - type: python - content: |- - df = read_csv("/abc/def.csv") - status = write_csv(df, "/abc/backup.csv") - status - - type: verification - content: CORRECT - - type: code_error - content: No code error. - - type: execution_status - content: SUCCESS - - type: execution_result - content: The file /abc/def.csv has been read and the content has been written to /abc/backup.csv. diff --git a/project/plugins/anomaly_detection.yaml b/project/plugins/anomaly_detection.yaml index 29c68cdc..06fa8a56 100644 --- a/project/plugins/anomaly_detection.yaml +++ b/project/plugins/anomaly_detection.yaml @@ -4,6 +4,7 @@ required: false description: >- anomaly_detection function identifies anomalies from an input DataFrame of time series. It will add a new column "Is_Anomaly", where each entry will be marked with "True" if the value is an anomaly or "False" otherwise. + For example, result_df, description = anomaly_detection(df, "datetime", "value"). parameters: - name: df diff --git a/project/plugins/ascii_render.py b/project/plugins/ascii_render.py new file mode 100644 index 00000000..dfd52c78 --- /dev/null +++ b/project/plugins/ascii_render.py @@ -0,0 +1,15 @@ +from taskweaver.plugin import Plugin, register_plugin + + +@register_plugin +class AsciiRenderPlugin(Plugin): + def __call__(self, text: str): + try: + import pyfiglet + except ImportError: + raise ImportError("Please install pyfiglet first.") + + ASCII_art_1 = pyfiglet.figlet_format(text, font="isometric1") + result = ASCII_art_1 + + return result diff --git a/project/plugins/ascii_render.yaml b/project/plugins/ascii_render.yaml new file mode 100644 index 00000000..395c3dd7 --- /dev/null +++ b/project/plugins/ascii_render.yaml @@ -0,0 +1,20 @@ +name: ascii_render +enabled: true +required: true +plugin_only: true +description: >- + This plugin renders the input text into ASCII art form. The input should be a string and the output is also a string in ASCII art. + For example, result = ascii_render("Hello World!"). + +parameters: + - name: text + type: str + required: true + description: >- + This is the input text to be rendered into ASCII art form. + +returns: + - name: result + type: str + description: >- + The rendered text in ASCII art. \ No newline at end of file diff --git a/project/plugins/klarna_search.yaml b/project/plugins/klarna_search.yaml index 187e8c1e..9aef2920 100644 --- a/project/plugins/klarna_search.yaml +++ b/project/plugins/klarna_search.yaml @@ -1,10 +1,12 @@ name: klarna_search enabled: true required: false +plugin_only: true description: >- Search and compare prices from thousands of online shops. Only available in the US. This plugin only takes user requests when searching for merchandise. If not clear, confirm with the user if they want to search for merchandise from Klarna. + For example, result, description = klarna_search("laptop", 10, 1000, 2000). parameters: - name: query diff --git a/project/plugins/paper_summary.yaml b/project/plugins/paper_summary.yaml index 4a07306a..523c2040 100644 --- a/project/plugins/paper_summary.yaml +++ b/project/plugins/paper_summary.yaml @@ -5,6 +5,7 @@ description: >- summarize_paper function iteratively summarizes a given paper page by page, highlighting the key points, including the problem, main idea, contributions, experiments, results, and conclusions. + For example, result, description = summarize_paper("paper.pdf"). parameters: - name: paper_file_path diff --git a/project/plugins/sql_pull_data.yaml b/project/plugins/sql_pull_data.yaml index e2b343a3..d756f958 100644 --- a/project/plugins/sql_pull_data.yaml +++ b/project/plugins/sql_pull_data.yaml @@ -6,6 +6,7 @@ description: >- This plugin takes user requests when obtaining data from database is explicitly mentioned. Otherwise, confirm with the user if they want to pull data from this database. The data from this database can only used for anomaly detection. + For example, df, description = sql_pull_data("pull data from time_series table"). parameters: - name: query diff --git a/project/plugins/tell_joke.py b/project/plugins/tell_joke.py new file mode 100644 index 00000000..5207a80a --- /dev/null +++ b/project/plugins/tell_joke.py @@ -0,0 +1,13 @@ +from taskweaver.plugin import Plugin, register_plugin + + +@register_plugin +class TellJoke(Plugin): + def __call__(self, lan: str = "en"): + try: + import pyjokes + except ImportError: + raise ImportError("Please install pyjokes first.") + + # Define the API endpoint and parameters + return pyjokes.get_joke(language=lan, category="neutral") diff --git a/project/plugins/tell_joke.yaml b/project/plugins/tell_joke.yaml new file mode 100644 index 00000000..bfccd4c2 --- /dev/null +++ b/project/plugins/tell_joke.yaml @@ -0,0 +1,18 @@ +name: tell_joke +enabled: true +required: false +plugin_only: true +description: >- + Call this plugin to tell a joke. For example, result = tell_joke("en"). + +parameters: + - name: lan + type: str + required: false + description: the language of the joke. Default is English. It can be en, de, es, it, gl, eu. + + +returns: + - name: joke + type: str + description: the joke. diff --git a/taskweaver/code_interpreter/__init__.py b/taskweaver/code_interpreter/__init__.py index 0e08e476..34ef27d8 100644 --- a/taskweaver/code_interpreter/__init__.py +++ b/taskweaver/code_interpreter/__init__.py @@ -1 +1,2 @@ -from .code_interpreter import CodeInterpreter, CodeInterpreterConfig +from .code_interpreter import CodeInterpreter +from .code_interpreter_plugin_only import CodeInterpreterPluginOnly diff --git a/taskweaver/code_interpreter/code_generator/__init__.py b/taskweaver/code_interpreter/code_generator/__init__.py index 2e59bad0..c415448c 100644 --- a/taskweaver/code_interpreter/code_generator/__init__.py +++ b/taskweaver/code_interpreter/code_generator/__init__.py @@ -1 +1,2 @@ from .code_generator import CodeGenerator, CodeGeneratorConfig, format_code_revision_message +from .code_generator_plugin_only import CodeGeneratorPluginOnly diff --git a/taskweaver/code_interpreter/code_generator/code_generator.py b/taskweaver/code_interpreter/code_generator/code_generator.py index 878d47d9..c11d15e7 100644 --- a/taskweaver/code_interpreter/code_generator/code_generator.py +++ b/taskweaver/code_interpreter/code_generator/code_generator.py @@ -80,7 +80,6 @@ def __init__( self.examples = None self.code_verification_on: bool = False self.allowed_modules: List[str] = [] - self.plugin_only: bool = False self.instruction = self.instruction_template.format( ROLE_NAME=self.role_name, @@ -98,10 +97,8 @@ def __init__( def configure_verification( self, code_verification_on: bool, - plugin_only: bool, allowed_modules: Optional[List[str]] = None, ): - self.plugin_only = plugin_only self.allowed_modules = allowed_modules if allowed_modules is not None else [] self.code_verification_on = code_verification_on @@ -113,23 +110,13 @@ def compose_verification_requirements( if not self.code_verification_on: return "" - if self.plugin_only: - requirements.append( - f"- {self.role_name} should only use the following plugins and" - + " Python built-in functions to complete the task: " - + ", ".join([f"{plugin.name}" for plugin in plugin_list]), - ) - requirements.append( - f"- {self.role_name} cannot define new functions or plugins.", - ) - if len(self.allowed_modules) > 0: requirements.append( f"- {self.role_name} can only import the following Python modules: " + ", ".join([f"{module}" for module in self.allowed_modules]), ) - if len(self.allowed_modules) == 0 and self.plugin_only: + if len(self.allowed_modules) == 0: requirements.append(f"- {self.role_name} cannot import any Python modules.") return "\n".join(requirements) @@ -141,7 +128,7 @@ def compose_prompt( chat_history = [format_chat_message(role="system", message=self.instruction)] if self.examples is None: - self.examples = self.load_examples(plugin_only=self.plugin_only) + self.examples = self.load_examples() for i, example in enumerate(self.examples): chat_history.extend( self.compose_conversation(example.rounds, example.plugins, add_requirements=False), @@ -366,12 +353,10 @@ def format_plugins( def load_examples( self, - plugin_only: bool, ) -> List[Conversation]: if self.config.load_example: return load_examples( folder=self.config.example_base_path, - plugin_only=plugin_only, ) return [] diff --git a/taskweaver/code_interpreter/code_generator/code_generator_plugin_only.py b/taskweaver/code_interpreter/code_generator/code_generator_plugin_only.py new file mode 100644 index 00000000..533627c1 --- /dev/null +++ b/taskweaver/code_interpreter/code_generator/code_generator_plugin_only.py @@ -0,0 +1,141 @@ +import os +from typing import List, Tuple + +from injector import inject + +from taskweaver.code_interpreter.code_generator.plugin_selection import PluginSelector, SelectedPluginPool +from taskweaver.config.module_config import ModuleConfig +from taskweaver.llm import LLMApi, format_chat_message +from taskweaver.logging import TelemetryLogger +from taskweaver.memory import Attachment, Memory, Post, Round +from taskweaver.memory.attachment import AttachmentType +from taskweaver.memory.plugin import PluginEntry, PluginRegistry +from taskweaver.role import PostTranslator, Role +from taskweaver.utils import read_yaml + + +class CodeGeneratorPluginOnlyConfig(ModuleConfig): + def _configure(self) -> None: + self._set_name("code_generator") + self.role_name = self._get_str("role_name", "ProgramApe") + + self.prompt_file_path = self._get_path( + "prompt_file_path", + os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "code_generator_prompt_plugin_only.yaml", + ), + ) + self.prompt_compression = self._get_bool("prompt_compression", False) + assert self.prompt_compression is False, "Compression is not supported for plugin only mode." + + self.compression_prompt_path = self._get_path( + "compression_prompt_path", + os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "compression_prompt.yaml", + ), + ) + self.enable_auto_plugin_selection = self._get_bool("enable_auto_plugin_selection", False) + self.auto_plugin_selection_topk = self._get_int("auto_plugin_selection_topk", 3) + + +class CodeGeneratorPluginOnly(Role): + @inject + def __init__( + self, + config: CodeGeneratorPluginOnlyConfig, + plugin_registry: PluginRegistry, + logger: TelemetryLogger, + llm_api: LLMApi, + ): + self.config = config + self.logger = logger + self.llm_api = llm_api + + self.role_name = self.config.role_name + + self.post_translator = PostTranslator(logger) + self.prompt_data = read_yaml(self.config.prompt_file_path) + self.plugin_pool = [p for p in plugin_registry.get_list() if p.plugin_only is True] + self.instruction_template = self.prompt_data["content"] + + if self.config.enable_auto_plugin_selection: + self.plugin_selector = PluginSelector(plugin_registry, self.llm_api) + self.plugin_selector.generate_plugin_embeddings() + logger.info("Plugin embeddings generated") + self.selected_plugin_pool = SelectedPluginPool() + + def select_plugins_for_prompt( + self, + user_query, + ) -> List[PluginEntry]: + selected_plugins = self.plugin_selector.plugin_select( + user_query, + self.config.auto_plugin_selection_topk, + ) + self.selected_plugin_pool.add_selected_plugins(selected_plugins) + self.logger.info(f"Selected plugins: {[p.name for p in selected_plugins]}") + self.logger.info(f"Selected plugin pool: {[p.name for p in self.selected_plugin_pool.get_plugins()]}") + + return self.selected_plugin_pool.get_plugins() + + def reply(self, memory: Memory, event_handler: callable) -> Post: + # extract all rounds from memory + rounds = memory.get_role_rounds( + role="CodeInterpreter", + include_failure_rounds=False, + ) + + user_query = rounds[-1].user_query + if self.config.enable_auto_plugin_selection: + self.plugin_pool = self.select_plugins_for_prompt(user_query) + + # obtain the user query from the last round + prompt, tools = _compose_prompt( + system_instructions=self.instruction_template.format( + ROLE_NAME=self.role_name, + ), + rounds=rounds, + plugin_pool=self.plugin_pool, + ) + post = Post.create(message=None, send_from="CodeInterpreter", send_to="Planner") + + llm_response = self.llm_api.chat_completion( + messages=prompt, + tools=tools, + tool_choice="auto", + response_format=None, + stream=False, + ) + if llm_response["role"] == "assistant": + post.message = llm_response["content"] + event_handler("CodeInterpreter->Planner", post.message) + return post + elif llm_response["role"] == "function": + post.add_attachment(Attachment.create(type=AttachmentType.function, content=llm_response["content"])) + event_handler("function", llm_response["content"]) + + if self.config.enable_auto_plugin_selection: + # here the code is in json format, not really code + self.selected_plugin_pool.filter_unused_plugins(code=llm_response["content"]) + return post + else: + raise ValueError(f"Unexpected response from LLM: {llm_response}") + + +def _compose_prompt( + system_instructions: str, + rounds: List[Round], + plugin_pool: List[PluginEntry], +) -> Tuple[List, List]: + functions = [plugin.format_function_calling() for plugin in plugin_pool] + prompt = [format_chat_message(role="system", message=system_instructions)] + for _round in rounds: + for post in _round.post_list: + if post.send_from == "Planner" and post.send_to == "CodeInterpreter": + prompt.append(format_chat_message(role="user", message=post.message)) + elif post.send_from == "CodeInterpreter" and post.send_to == "Planner": + prompt.append(format_chat_message(role="assistant", message=post.message)) + + return prompt, functions diff --git a/taskweaver/code_interpreter/code_generator/code_generator_prompt.yaml b/taskweaver/code_interpreter/code_generator/code_generator_prompt.yaml index cda67f5c..80d33ee1 100644 --- a/taskweaver/code_interpreter/code_generator/code_generator_prompt.yaml +++ b/taskweaver/code_interpreter/code_generator/code_generator_prompt.yaml @@ -58,4 +58,5 @@ requirements: |- - {ROLE_NAME} should not refer to any information from failed rounds, rounds that have not been executed, or previous Conversations. - {ROLE_NAME} put all the result variables in the last line of the code. - {ROLE_NAME} must not import the plugins and otherwise the code will be failed to execute. + - {ROLE_NAME} must try to directly import required modules without installing them, and only install the modules if the execution fails. {CODE_GENERATION_REQUIREMENTS} diff --git a/taskweaver/code_interpreter/code_generator/code_generator_prompt_plugin_only.yaml b/taskweaver/code_interpreter/code_generator/code_generator_prompt_plugin_only.yaml new file mode 100644 index 00000000..4f69f6cc --- /dev/null +++ b/taskweaver/code_interpreter/code_generator/code_generator_prompt_plugin_only.yaml @@ -0,0 +1,4 @@ +version: 0.1 +content: |- + {ROLE_NAME} can understand the user request and leverage pre-defined tools to complete tasks. + diff --git a/taskweaver/code_interpreter/code_generator/plugin_selection.py b/taskweaver/code_interpreter/code_generator/plugin_selection.py index 4580c51f..0e52a7f4 100644 --- a/taskweaver/code_interpreter/code_generator/plugin_selection.py +++ b/taskweaver/code_interpreter/code_generator/plugin_selection.py @@ -61,17 +61,21 @@ def __init__( self, plugin_registry: PluginRegistry, llm_api: LLMApi, + plugin_only: bool = False, ): - self.plugin_registry = plugin_registry + if plugin_only: + self.available_plugins = [p for p in plugin_registry.get_list() if p.plugin_only is True] + else: + self.available_plugins = plugin_registry.get_list() self.llm_api = llm_api self.plugin_embedding_dict: Dict[str, List[float]] = {} def generate_plugin_embeddings(self): plugin_intro_text_list: List[str] = [] - for p in self.plugin_registry.get_list(): + for p in self.available_plugins: plugin_intro_text_list.append(p.name + ": " + p.spec.description) plugin_embeddings = self.llm_api.get_embedding_list(plugin_intro_text_list) - for i, p in enumerate(self.plugin_registry.get_list()): + for i, p in enumerate(self.available_plugins): self.plugin_embedding_dict[p.name] = plugin_embeddings[i] def plugin_select(self, user_query: str, top_k: int = 5) -> List[PluginEntry]: @@ -79,10 +83,10 @@ def plugin_select(self, user_query: str, top_k: int = 5) -> List[PluginEntry]: similarities = [] - if top_k >= len(self.plugin_registry.get_list()): - return self.plugin_registry.get_list() + if top_k >= len(self.available_plugins): + return self.available_plugins - for p in self.plugin_registry.get_list(): + for p in self.available_plugins: similarity = cosine_similarity( user_query_embedding.reshape( 1, diff --git a/taskweaver/code_interpreter/code_interpreter.py b/taskweaver/code_interpreter/code_interpreter.py index f649c554..7153a61e 100644 --- a/taskweaver/code_interpreter/code_interpreter.py +++ b/taskweaver/code_interpreter/code_interpreter.py @@ -22,16 +22,11 @@ def _configure(self): # for verification self.code_verification_on = self._get_bool("code_verification_on", False) - self.plugin_only = self._get_bool("plugin_only", False) self.allowed_modules = self._get_list( "allowed_modules", ["pandas", "matplotlib", "numpy", "sklearn", "scipy", "seaborn", "datetime", "typing"], ) - if self.plugin_only: - self.code_verification_on = True - self.allowed_modules = [] - def update_verification( response: Post, @@ -69,7 +64,6 @@ def __init__( self.generator = generator self.generator.configure_verification( code_verification_on=self.config.code_verification_on, - plugin_only=self.config.plugin_only, allowed_modules=self.config.allowed_modules, ) @@ -130,8 +124,8 @@ def reply( code.content, [plugin.name for plugin in self.generator.get_plugin_pool()], self.config.code_verification_on, - self.config.plugin_only, - self.config.allowed_modules, + plugin_only=False, + allowed_modules=self.config.allowed_modules, ) if code_verify_errors is None: diff --git a/taskweaver/code_interpreter/code_interpreter_plugin_only.py b/taskweaver/code_interpreter/code_interpreter_plugin_only.py new file mode 100644 index 00000000..658e72af --- /dev/null +++ b/taskweaver/code_interpreter/code_interpreter_plugin_only.py @@ -0,0 +1,91 @@ +import json +from typing import Optional + +from injector import inject + +from taskweaver.code_interpreter.code_executor import CodeExecutor +from taskweaver.code_interpreter.code_generator import CodeGeneratorPluginOnly +from taskweaver.config.module_config import ModuleConfig +from taskweaver.logging import TelemetryLogger +from taskweaver.memory import Memory, Post +from taskweaver.memory.attachment import AttachmentType +from taskweaver.role import Role + + +class CodeInterpreterConfig(ModuleConfig): + def _configure(self): + self._set_name("code_interpreter_plugin_only") + self.use_local_uri = self._get_bool("use_local_uri", False) + self.max_retry_count = self._get_int("max_retry_count", 3) + + +class CodeInterpreterPluginOnly(Role): + @inject + def __init__( + self, + generator: CodeGeneratorPluginOnly, + executor: CodeExecutor, + logger: TelemetryLogger, + config: CodeInterpreterConfig, + ): + self.generator = generator + self.executor = executor + self.logger = logger + self.config = config + self.retry_count = 0 + self.return_index = 0 + + self.logger.info("CodeInterpreter initialized successfully.") + + def reply( + self, + memory: Memory, + event_handler: callable, + prompt_log_path: Optional[str] = None, + use_back_up_engine: Optional[bool] = False, + ) -> Post: + response: Post = self.generator.reply( + memory, + event_handler, + ) + + if response.message is not None: + return response + + functions = json.loads(response.get_attachment(type=AttachmentType.function)[0]) + if len(functions) > 0: + code = [] + for i, f in enumerate(functions): + function_name = f["name"] + function_args = json.loads(f["arguments"]) + function_call = ( + f"r{self.return_index + i}={function_name}(" + + ", ".join( + [ + f'{key}="{value}"' if isinstance(value, str) else f"{key}={value}" + for key, value in function_args.items() + ], + ) + + ")" + ) + code.append(function_call) + code.append(f'{", ".join([f"r{self.return_index + i}" for i in range(len(functions))])}') + self.return_index += len(functions) + + event_handler("code", "\n".join(code)) + exec_result = self.executor.execute_code( + exec_id=response.id, + code="\n".join(code), + ) + + response.message = self.executor.format_code_output( + exec_result, + with_code=True, + use_local_uri=self.config.use_local_uri, + ) + event_handler("CodeInterpreter-> Planner", response.message) + else: + response.message = "No code is generated because no function is selected." + event_handler("CodeInterpreter-> Planner", response.message) + + return response diff --git a/taskweaver/llm/__init__.py b/taskweaver/llm/__init__.py index 36b19e45..03167d51 100644 --- a/taskweaver/llm/__init__.py +++ b/taskweaver/llm/__init__.py @@ -99,6 +99,8 @@ def chat_completion( ): msg["role"] = msg_chunk["role"] msg["content"] += msg_chunk["content"] + if "name" in msg_chunk: + msg["name"] = msg_chunk["name"] return msg def chat_completion_stream( diff --git a/taskweaver/llm/openai.py b/taskweaver/llm/openai.py index ae8119ec..4e1d3e2a 100644 --- a/taskweaver/llm/openai.py +++ b/taskweaver/llm/openai.py @@ -148,6 +148,18 @@ def chat_completion( try: if use_backup_engine: engine = backup_engine + + tools_kwargs = {} + if "tools" in kwargs and "tool_choice" in kwargs: + tools_kwargs["tools"] = kwargs["tools"] + tools_kwargs["tool_choice"] = kwargs["tool_choice"] + if "response_format" in kwargs: + response_format = kwargs["response_format"] + elif self.config.response_format == "json_object": + response_format = {"type": "json_object"} + else: + response_format = None + res: Any = self.client.chat.completions.create( model=engine, messages=messages, # type: ignore @@ -159,9 +171,8 @@ def chat_completion( stop=stop, stream=stream, seed=seed, - response_format=( - {"type": "json_object"} if self.config.response_format == "json_object" else None # type: ignore - ), + response_format=response_format, + **tools_kwargs, ) if stream: role: Any = None @@ -185,6 +196,15 @@ def chat_completion( role=oai_response.role if oai_response.role is not None else "assistant", message=oai_response.content if oai_response.content is not None else "", ) + if oai_response.tool_calls is not None: + response["role"] = "function" + response["content"] = ( + "[" + + ",".join( + [t.function.model_dump_json() for t in oai_response.tool_calls], + ) + + "]" + ) yield response except openai.APITimeoutError as e: diff --git a/taskweaver/llm/util.py b/taskweaver/llm/util.py index 3610bb5d..dc7a346c 100644 --- a/taskweaver/llm/util.py +++ b/taskweaver/llm/util.py @@ -1,6 +1,6 @@ from typing import Dict, Literal, Optional -ChatMessageRoleType = Literal["system", "user", "assistant"] +ChatMessageRoleType = Literal["system", "user", "assistant", "function"] ChatMessageType = Dict[Literal["role", "name", "content"], str] diff --git a/taskweaver/memory/attachment.py b/taskweaver/memory/attachment.py index 4c3f658a..941253a1 100644 --- a/taskweaver/memory/attachment.py +++ b/taskweaver/memory/attachment.py @@ -35,6 +35,9 @@ class AttachmentType(Enum): # CodeInterpreter - revise code revise_message = "revise_message" + # function calling + function = "function" + # Misc invalid_response = "invalid_response" diff --git a/taskweaver/memory/plugin.py b/taskweaver/memory/plugin.py index 386094a4..97b44ae5 100644 --- a/taskweaver/memory/plugin.py +++ b/taskweaver/memory/plugin.py @@ -114,6 +114,7 @@ def format_return_val(val: PluginParameter) -> str: @dataclass class PluginEntry: name: str + plugin_only: bool impl: str spec: PluginSpec config: Dict[str, Any] @@ -140,6 +141,7 @@ def from_yaml_content(content: Dict) -> Optional["PluginEntry"]: config=content.get("configurations", {}), required=content.get("required", False), enabled=content.get("enabled", True), + plugin_only=content.get("plugin_only", False), ) return None @@ -156,6 +158,38 @@ def to_dict(self): "enabled": self.enabled, } + def format_function_calling(self) -> Dict: + assert self.plugin_only is True, "Only `plugin_only` plugins can be called in this way." + + def map_type(t: str) -> str: + if t.lower() == "string" or t.lower() == "str" or t.lower() == "text": + return "string" + if t.lower() == "integer" or t.lower() == "int": + return "integer" + if t.lower() == "float" or t.lower() == "double" or t.lower() == "number": + return "number" + if t.lower() == "boolean" or t.lower() == "bool": + return "boolean" + if t.lower() == "null" or t.lower() == "none": + return "null" + raise Exception(f"unknown type {t}") + + function = {"type": "function", "function": {}} + required_params = [] + function["function"]["name"] = self.name + function["function"]["description"] = self.spec.description + function["function"]["parameters"] = {"type": "object", "properties": {}} + for arg in self.spec.args: + function["function"]["parameters"]["properties"][arg.name] = { + "type": map_type(arg.type), + "description": arg.description, + } + if arg.required: + required_params.append(arg.name) + function["function"]["parameters"]["required"] = required_params + + return function + class PluginRegistry(ComponentRegistry[PluginEntry]): def __init__( diff --git a/taskweaver/misc/example.py b/taskweaver/misc/example.py index c4f76672..b006c5af 100644 --- a/taskweaver/misc/example.py +++ b/taskweaver/misc/example.py @@ -5,20 +5,16 @@ from taskweaver.memory.conversation import Conversation -def load_examples(folder: str, plugin_only: bool = False) -> List[Conversation]: +def load_examples(folder: str) -> List[Conversation]: """ Load all the examples from a folder. Args: folder: the folder path. - plugin_only: whether to load only the plugin examples. """ example_file_list: List[str] = glob.glob(path.join(folder, "*.yaml")) example_conv_pool: List[Conversation] = [] for yaml_path in example_file_list: conversation = Conversation.from_yaml(yaml_path) - if plugin_only and conversation.plugin_only: - example_conv_pool.append(conversation) - elif not plugin_only and not conversation.plugin_only: - example_conv_pool.append(conversation) + example_conv_pool.append(conversation) return example_conv_pool diff --git a/taskweaver/planner/planner.py b/taskweaver/planner/planner.py index f09c26a4..80353ce2 100644 --- a/taskweaver/planner/planner.py +++ b/taskweaver/planner/planner.py @@ -68,11 +68,15 @@ def __init__( llm_api: LLMApi, plugin_registry: PluginRegistry, round_compressor: Optional[RoundCompressor] = None, + plugin_only: bool = False, ): self.config = config self.logger = logger self.llm_api = llm_api - self.plugin_registry = plugin_registry + if plugin_only: + self.available_plugins = [p for p in plugin_registry.get_list() if p.plugin_only is True] + else: + self.available_plugins = plugin_registry.get_list() self.planner_post_translator = PostTranslator(logger) @@ -80,12 +84,12 @@ def __init__( if self.config.use_example: self.examples = self.get_examples() - if len(self.plugin_registry.get_list()) == 0: + if len(self.available_plugins) == 0: self.logger.warning("No plugin is loaded for Planner.") self.plugin_description = "No plugin functions loaded." else: self.plugin_description = "\t" + "\n\t".join( - [f"- {plugin.name}: " + f"{plugin.spec.description}" for plugin in self.plugin_registry.get_list()], + [f"- {plugin.name}: " + f"{plugin.spec.description}" for plugin in self.available_plugins], ) self.instruction_template = self.prompt_data["instruction_template"] self.code_interpreter_introduction = self.prompt_data["code_interpreter_introduction"].format( diff --git a/taskweaver/planner/planner_prompt.yaml b/taskweaver/planner/planner_prompt.yaml index f3a9d7f0..a3bfbd37 100644 --- a/taskweaver/planner/planner_prompt.yaml +++ b/taskweaver/planner/planner_prompt.yaml @@ -31,6 +31,7 @@ instruction_template: |- - If User has made any changes to the environment, Planner should inform CodeInterpreter accordingly. - Planner can ignore the permission or data access issues because CodeInterpreter can handle this kind of problem. - Planner must include 2 parts: description of the User's request and the current step that the Planner is executing. + - Planner must not ask CodeInterpreter to install any packages unless the User explicitly requests to do so. ## Planner's response format - Planner must strictly format the response into the following JSON object: diff --git a/taskweaver/session/session.py b/taskweaver/session/session.py index 87c6c927..d8c0f3ca 100644 --- a/taskweaver/session/session.py +++ b/taskweaver/session/session.py @@ -4,12 +4,12 @@ from injector import Injector, inject -from taskweaver.code_interpreter import CodeInterpreter +from taskweaver.code_interpreter import CodeInterpreter, CodeInterpreterPluginOnly from taskweaver.code_interpreter.code_executor import CodeExecutor from taskweaver.config.module_config import ModuleConfig from taskweaver.logging import TelemetryLogger from taskweaver.memory import Memory, Post, Round -from taskweaver.planner.planner import Planner, PlannerConfig +from taskweaver.planner.planner import Planner from taskweaver.workspace.workspace import Workspace @@ -19,6 +19,7 @@ def _configure(self) -> None: self.code_interpreter_only = self._get_bool("code_interpreter_only", False) self.max_internal_chat_round_num = self._get_int("max_internal_chat_round_num", 10) + self.plugin_only_mode = self._get_bool("plugin_only_mode", False) class Session: @@ -47,10 +48,12 @@ def __init__( self.session_var: Dict[str, str] = {} - # self.plugins = get_plugin_registry() - - self.planner_config = self.session_injector.get(PlannerConfig) - self.planner = self.session_injector.get(Planner) + self.planner = self.session_injector.create_object( + Planner, + { + "plugin_only": self.config.plugin_only_mode, + }, + ) self.code_executor = self.session_injector.create_object( CodeExecutor, { @@ -60,7 +63,10 @@ def __init__( }, ) self.session_injector.binder.bind(CodeExecutor, self.code_executor) - self.code_interpreter = self.session_injector.get(CodeInterpreter) + if self.config.plugin_only_mode: + self.code_interpreter = self.session_injector.get(CodeInterpreterPluginOnly) + else: + self.code_interpreter = self.session_injector.get(CodeInterpreter) self.max_internal_chat_round_num = self.config.max_internal_chat_round_num self.internal_chat_num = 0 diff --git a/tests/unit_tests/data/plugins/klarna_search.yaml b/tests/unit_tests/data/plugins/klarna_search.yaml index 18907092..c7b86344 100644 --- a/tests/unit_tests/data/plugins/klarna_search.yaml +++ b/tests/unit_tests/data/plugins/klarna_search.yaml @@ -1,6 +1,7 @@ name: klarna_search enabled: true required: false +plugin_only: true description: >- Search and compare prices from thousands of online shops. Only available in the US. diff --git a/tests/unit_tests/data/prompts/generator_plugin_only.yaml b/tests/unit_tests/data/prompts/generator_plugin_only.yaml new file mode 100644 index 00000000..4f69f6cc --- /dev/null +++ b/tests/unit_tests/data/prompts/generator_plugin_only.yaml @@ -0,0 +1,4 @@ +version: 0.1 +content: |- + {ROLE_NAME} can understand the user request and leverage pre-defined tools to complete tasks. + diff --git a/tests/unit_tests/test_code_generator.py b/tests/unit_tests/test_code_generator.py index 9ed57806..c88bb931 100644 --- a/tests/unit_tests/test_code_generator.py +++ b/tests/unit_tests/test_code_generator.py @@ -2,8 +2,10 @@ from injector import Injector +from taskweaver.code_interpreter.code_generator.code_generator_plugin_only import _compose_prompt from taskweaver.config.config_mgt import AppConfigSource from taskweaver.logging import LoggingModule +from taskweaver.memory.attachment import AttachmentType from taskweaver.memory.plugin import PluginModule @@ -48,21 +50,21 @@ def test_compose_prompt(): ) post2.add_attachment( Attachment.create( - "thought", + AttachmentType.thought, "{ROLE_NAME} sees the user wants generate a DataFrame.", ), ) post2.add_attachment( Attachment.create( - "thought", + AttachmentType.thought, "{ROLE_NAME} sees all required Python libs have been imported, so will not generate import codes.", ), ) - post2.add_attachment(Attachment.create("python", code1)) - post2.add_attachment(Attachment.create("execution_status", "SUCCESS")) + post2.add_attachment(Attachment.create(AttachmentType.python, code1)) + post2.add_attachment(Attachment.create(AttachmentType.execution_status, "SUCCESS")) post2.add_attachment( Attachment.create( - "execution_result", + AttachmentType.execution_result, "A dataframe `df` with 10 rows and 2 columns: 'DATE' and 'VALUE' has been generated.", ), ) @@ -86,20 +88,20 @@ def test_compose_prompt(): ) post4.add_attachment( Attachment.create( - "thought", + AttachmentType.thought, "{ROLE_NAME} understands the user wants to find the data range for the DataFrame.", ), ) post4.add_attachment( Attachment.create( - "thought", + AttachmentType.thought, "{ROLE_NAME} will generate code to calculate the data range of the 'VALUE' column since it is the " "only numeric column.", ), ) post4.add_attachment( Attachment.create( - "python", + AttachmentType.python, ( "min_value = df['VALUE'].min()\n" "max_value = df['VALUE'].max()\n" @@ -112,10 +114,10 @@ def test_compose_prompt(): ), ), ) - post4.add_attachment(Attachment.create("execution_status", "SUCCESS")) + post4.add_attachment(Attachment.create(AttachmentType.execution_status, "SUCCESS")) post4.add_attachment( Attachment.create( - "execution_result", + AttachmentType.execution_result, "The minimum value in the 'VALUE' column is 0.05;The " "maximum value in the 'VALUE' column is 0.99;The " "data range for the 'VALUE' column is 0.94", @@ -246,21 +248,21 @@ def test_compose_prompt_with_plugin(): ) post2.add_attachment( Attachment.create( - "thought", + AttachmentType.thought, "{ROLE_NAME} sees the user wants generate a DataFrame.", ), ) post2.add_attachment( Attachment.create( - "thought", + AttachmentType.thought, "{ROLE_NAME} sees all required Python libs have been imported, so will not generate import codes.", ), ) - post2.add_attachment(Attachment.create("python", code1)) - post2.add_attachment(Attachment.create("execution_status", "SUCCESS")) + post2.add_attachment(Attachment.create(AttachmentType.python, code1)) + post2.add_attachment(Attachment.create(AttachmentType.execution_status, "SUCCESS")) post2.add_attachment( Attachment.create( - "execution_result", + AttachmentType.execution_result, "A dataframe `df` with 10 rows and 2 columns: 'DATE' and 'VALUE' has been generated.", ), ) @@ -292,95 +294,79 @@ def test_compose_prompt_with_plugin_only(): config={ "app_dir": os.path.dirname(os.path.abspath(__file__)), "llm.api_key": "test_key", # pragma: allowlist secret - "code_generator.prompt_compression": True, + "code_generator.prompt_compression": False, "code_generator.prompt_file_path": os.path.join( os.path.dirname(os.path.abspath(__file__)), - "data/prompts/generator_prompt.yaml", + "data/prompts/generator_plugin_only.yaml", ), "plugin.base_path": os.path.join( os.path.dirname(os.path.abspath(__file__)), "data/plugins", ), - "code_generator.example_base_path": os.path.join( - os.path.dirname(os.path.abspath(__file__)), - "data/examples/codeinterpreter_examples", - ), }, ) app_injector.binder.bind(AppConfigSource, to=app_config) - from taskweaver.code_interpreter.code_generator import CodeGenerator + from taskweaver.code_interpreter.code_generator import CodeGeneratorPluginOnly from taskweaver.memory import Attachment, Memory, Post, Round - code_generator = app_injector.get(CodeGenerator) - - code_generator.configure_verification( - code_verification_on=True, - plugin_only=True, - allowed_modules=[], - ) - code_generator.configure_verification(code_verification_on=True, plugin_only=True) + code_generator = app_injector.get(CodeGeneratorPluginOnly) - code1 = ( - "df = pd.DataFrame(np.random.rand(10, 2), columns=['DATE', 'VALUE'])\n" - 'descriptions = [("sample_code_description", "Sample code has been generated to get a dataframe `df` \n' - "with 10 rows and 2 columns: 'DATE' and 'VALUE'\")]" - ) + code1 = "r0 = klarna_search('iphone')\n" "r0" post1 = Post.create( - message="create a dataframe", + message="find iphones on sale", send_from="Planner", send_to="CodeInterpreter", attachment_list=[], ) post2 = Post.create( - message="A dataframe `df` with 10 rows and 2 columns: 'DATE' and 'VALUE' has been generated.", + message="The iphone 15 pro is on sale.", send_from="CodeInterpreter", send_to="Planner", attachment_list=[], ) post2.add_attachment( Attachment.create( - "thought", - "{ROLE_NAME} sees the user wants generate a DataFrame.", + AttachmentType.thought, + "{ROLE_NAME} sees the user wants to find iphones on sale.", ), ) post2.add_attachment( Attachment.create( - "thought", - "{ROLE_NAME} sees all required Python libs have been imported, so will not generate import codes.", + AttachmentType.thought, + "{ROLE_NAME} can use the `klarna_search` function to find iphones on sale.", ), ) - post2.add_attachment(Attachment.create("python", code1)) - post2.add_attachment(Attachment.create("execution_status", "SUCCESS")) + post2.add_attachment(Attachment.create(AttachmentType.python, code1)) + post2.add_attachment(Attachment.create(AttachmentType.execution_status, "SUCCESS")) post2.add_attachment( Attachment.create( - "execution_result", + AttachmentType.execution_result, "A dataframe `df` with 10 rows and 2 columns: 'DATE' and 'VALUE' has been generated.", ), ) - round1 = Round.create(user_query="hello", id="round-1") + round1 = Round.create(user_query="find iphones on sale", id="round-1") round1.add_post(post1) round1.add_post(post2) memory = Memory(session_id="session-1") memory.conversation.add_round(round1) - messages = code_generator.compose_prompt( + messages, functions = _compose_prompt( + system_instructions=code_generator.instruction_template.format( + ROLE_NAME=code_generator.role_name, + ), rounds=memory.conversation.rounds, - plugins=code_generator.get_plugin_pool(), + plugin_pool=code_generator.plugin_pool, ) - assert "read_csv" in messages[1]["content"] - assert "write_csv" in messages[1]["content"] - assert "This is the feedback" in messages[3]["content"] - assert "Execution" in messages[3]["content"] - assert "Verification" in messages[3]["content"] - - assert "sql_pull_data" in messages[4]["content"] - assert "anomaly_detection" in messages[4]["content"] - assert "klarna_search" in messages[4]["content"] - assert "paper_summary" in messages[4]["content"] + assert len(functions) == 1 + assert functions[0]["function"]["name"] == "klarna_search" + assert messages[1]["role"] == "user" + assert messages[1]["content"] == "find iphones on sale" + assert messages[2]["role"] == "assistant" + assert messages[2]["content"] == "The iphone 15 pro is on sale." def test_compose_prompt_with_not_plugin_only(): @@ -432,21 +418,21 @@ def test_compose_prompt_with_not_plugin_only(): ) post2.add_attachment( Attachment.create( - "thought", + AttachmentType.thought, "{ROLE_NAME} sees the user wants generate a DataFrame.", ), ) post2.add_attachment( Attachment.create( - "thought", + AttachmentType.thought, "{ROLE_NAME} sees all required Python libs have been imported, so will not generate import codes.", ), ) - post2.add_attachment(Attachment.create("python", code1)) - post2.add_attachment(Attachment.create("execution_status", "SUCCESS")) + post2.add_attachment(Attachment.create(AttachmentType.python, code1)) + post2.add_attachment(Attachment.create(AttachmentType.execution_status, "SUCCESS")) post2.add_attachment( Attachment.create( - "execution_result", + AttachmentType.execution_result, "A dataframe `df` with 10 rows and 2 columns: 'DATE' and 'VALUE' has been generated.", ), ) @@ -463,17 +449,11 @@ def test_compose_prompt_with_not_plugin_only(): plugins=code_generator.get_plugin_pool(), ) - assert "read_csv" not in messages[1]["content"] - assert "write_csv" not in messages[1]["content"] - assert "sql_pull_data" not in messages[1]["content"] - assert "anomaly_detection" not in messages[1]["content"] - assert "klarna_search" not in messages[1]["content"] - assert "paper_summary" not in messages[1]["content"] - - assert "sql_pull_data" in messages[13]["content"] - assert "anomaly_detection" in messages[13]["content"] - assert "klarna_search" in messages[13]["content"] - assert "paper_summary" in messages[13]["content"] + assert len(code_generator.plugin_pool) == 4 + assert "anomaly_detection" in messages[16]["content"] + assert "klarna_search" in messages[16]["content"] + assert "paper_summary" in messages[16]["content"] + assert "sql_pull_data" in messages[16]["content"] def test_code_correction_prompt(): diff --git a/tests/unit_tests/test_function_calling.py b/tests/unit_tests/test_function_calling.py new file mode 100644 index 00000000..a8701f17 --- /dev/null +++ b/tests/unit_tests/test_function_calling.py @@ -0,0 +1,66 @@ +from taskweaver.memory.plugin import PluginEntry, PluginParameter, PluginSpec + + +def test_function_formatting(): + plugin = PluginEntry( + name="test", + impl="test", + spec=PluginSpec( + name="test", + description="test", + args=[ + PluginParameter( + name="arg1", + type="string", + description="arg1", + required=True, + ), + PluginParameter( + name="arg2", + type="integer", + description="arg2", + required=False, + ), + PluginParameter( + name="arg3", + type="float", + description="arg3", + required=False, + ), + PluginParameter( + name="arg4", + type="boolean", + description="arg4", + required=False, + ), + PluginParameter( + name="arg5", + type="none", + description="arg5", + required=False, + ), + ], + ), + config={"test_key": "test_val"}, + required=False, + enabled=True, + plugin_only=True, + ) + assert plugin.format_function_calling() == { + "type": "function", + "function": { + "name": "test", + "description": "test", + "parameters": { + "type": "object", + "properties": { + "arg1": {"type": "string", "description": "arg1"}, + "arg2": {"type": "integer", "description": "arg2"}, + "arg3": {"type": "number", "description": "arg3"}, + "arg4": {"type": "boolean", "description": "arg4"}, + "arg5": {"type": "null", "description": "arg5"}, + }, + "required": ["arg1"], + }, + }, + }