Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Liqun/plugin only #91

Merged
merged 18 commits into from
Dec 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
enabled: True
plugin_only: False
rounds:
- user_query: hello
state: finished
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
enabled: True
plugin_only: False
rounds:
- user_query: read file /abc/def.txt
state: finished
Expand Down
70 changes: 0 additions & 70 deletions project/codeinterpreter_examples/example3-codeinterpreter.yaml

This file was deleted.

1 change: 1 addition & 0 deletions project/plugins/anomaly_detection.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ required: false
description: >-
anomaly_detection function identifies anomalies from an input DataFrame of
time series. It will add a new column "Is_Anomaly", where each entry will be marked with "True" if the value is an anomaly or "False" otherwise.
For example, result_df, description = anomaly_detection(df, "datetime", "value").

parameters:
- name: df
Expand Down
15 changes: 15 additions & 0 deletions project/plugins/ascii_render.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from taskweaver.plugin import Plugin, register_plugin


@register_plugin
class AsciiRenderPlugin(Plugin):
def __call__(self, text: str):
try:
import pyfiglet
except ImportError:
raise ImportError("Please install pyfiglet first.")

ASCII_art_1 = pyfiglet.figlet_format(text, font="isometric1")
result = ASCII_art_1

return result
20 changes: 20 additions & 0 deletions project/plugins/ascii_render.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
name: ascii_render
enabled: true
required: true
plugin_only: true
description: >-
This plugin renders the input text into ASCII art form. The input should be a string and the output is also a string in ASCII art.
For example, result = ascii_render("Hello World!").

parameters:
- name: text
type: str
required: true
description: >-
This is the input text to be rendered into ASCII art form.

returns:
- name: result
type: str
description: >-
The rendered text in ASCII art.
2 changes: 2 additions & 0 deletions project/plugins/klarna_search.yaml
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
name: klarna_search
enabled: true
required: false
plugin_only: true
description: >-
Search and compare prices from thousands of online shops. Only available in the US.
This plugin only takes user requests when searching for merchandise.
If not clear, confirm with the user if they want to search for merchandise from Klarna.
For example, result, description = klarna_search("laptop", 10, 1000, 2000).

parameters:
- name: query
Expand Down
1 change: 1 addition & 0 deletions project/plugins/paper_summary.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ description: >-
summarize_paper function iteratively summarizes a given paper page by page,
highlighting the key points, including the problem, main idea, contributions,
experiments, results, and conclusions.
For example, result, description = summarize_paper("paper.pdf").

parameters:
- name: paper_file_path
Expand Down
1 change: 1 addition & 0 deletions project/plugins/sql_pull_data.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ description: >-
This plugin takes user requests when obtaining data from database is explicitly mentioned.
Otherwise, confirm with the user if they want to pull data from this database.
The data from this database can only used for anomaly detection.
For example, df, description = sql_pull_data("pull data from time_series table").

parameters:
- name: query
Expand Down
13 changes: 13 additions & 0 deletions project/plugins/tell_joke.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from taskweaver.plugin import Plugin, register_plugin


@register_plugin
class TellJoke(Plugin):
def __call__(self, lan: str = "en"):
try:
import pyjokes
except ImportError:
raise ImportError("Please install pyjokes first.")

# Define the API endpoint and parameters
return pyjokes.get_joke(language=lan, category="neutral")
18 changes: 18 additions & 0 deletions project/plugins/tell_joke.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
name: tell_joke
enabled: true
required: false
plugin_only: true
description: >-
Call this plugin to tell a joke. For example, result = tell_joke("en").

parameters:
- name: lan
type: str
required: false
description: the language of the joke. Default is English. It can be en, de, es, it, gl, eu.


returns:
- name: joke
type: str
description: the joke.
3 changes: 2 additions & 1 deletion taskweaver/code_interpreter/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .code_interpreter import CodeInterpreter, CodeInterpreterConfig
from .code_interpreter import CodeInterpreter
from .code_interpreter_plugin_only import CodeInterpreterPluginOnly
1 change: 1 addition & 0 deletions taskweaver/code_interpreter/code_generator/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .code_generator import CodeGenerator, CodeGeneratorConfig, format_code_revision_message
from .code_generator_plugin_only import CodeGeneratorPluginOnly
19 changes: 2 additions & 17 deletions taskweaver/code_interpreter/code_generator/code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ def __init__(
self.examples = None
self.code_verification_on: bool = False
self.allowed_modules: List[str] = []
self.plugin_only: bool = False

self.instruction = self.instruction_template.format(
ROLE_NAME=self.role_name,
Expand All @@ -98,10 +97,8 @@ def __init__(
def configure_verification(
self,
code_verification_on: bool,
plugin_only: bool,
allowed_modules: Optional[List[str]] = None,
):
self.plugin_only = plugin_only
self.allowed_modules = allowed_modules if allowed_modules is not None else []
self.code_verification_on = code_verification_on

Expand All @@ -113,23 +110,13 @@ def compose_verification_requirements(
if not self.code_verification_on:
return ""

if self.plugin_only:
requirements.append(
f"- {self.role_name} should only use the following plugins and"
+ " Python built-in functions to complete the task: "
+ ", ".join([f"{plugin.name}" for plugin in plugin_list]),
)
requirements.append(
f"- {self.role_name} cannot define new functions or plugins.",
)

if len(self.allowed_modules) > 0:
requirements.append(
f"- {self.role_name} can only import the following Python modules: "
+ ", ".join([f"{module}" for module in self.allowed_modules]),
)

if len(self.allowed_modules) == 0 and self.plugin_only:
if len(self.allowed_modules) == 0:
requirements.append(f"- {self.role_name} cannot import any Python modules.")
return "\n".join(requirements)

Expand All @@ -141,7 +128,7 @@ def compose_prompt(
chat_history = [format_chat_message(role="system", message=self.instruction)]

if self.examples is None:
self.examples = self.load_examples(plugin_only=self.plugin_only)
self.examples = self.load_examples()
for i, example in enumerate(self.examples):
chat_history.extend(
self.compose_conversation(example.rounds, example.plugins, add_requirements=False),
Expand Down Expand Up @@ -366,12 +353,10 @@ def format_plugins(

def load_examples(
self,
plugin_only: bool,
) -> List[Conversation]:
if self.config.load_example:
return load_examples(
folder=self.config.example_base_path,
plugin_only=plugin_only,
)
return []

Expand Down
Loading