From 1ea1d7bcd8026ae4da8855c54552573b8fdc3212 Mon Sep 17 00:00:00 2001
From: Jochen Schultz
Date: Wed, 11 Sep 2024 00:08:51 +0000
Subject: [PATCH 1/4] fixed some linter problems

---
 src/core/database.py                   | 1 -
 src/core/openai_api.py                 | 2 --
 tests/test_compose_prompt.py           | 1 -
 tests/test_database.py                 | 1 -
 tests/test_log_complete_prompt.py      | 2 +-
 tests/test_parallel_complete_prompt.py | 2 +-
 tests/test_prompt_generation.py        | 1 -
 tests/test_token_counter.py            | 1 -
 8 files changed, 2 insertions(+), 9 deletions(-)

diff --git a/src/core/database.py b/src/core/database.py
index f5019c1..b58fc44 100644
--- a/src/core/database.py
+++ b/src/core/database.py
@@ -1,5 +1,4 @@
 import sqlite3
-from datetime import datetime


 class Database:
diff --git a/src/core/openai_api.py b/src/core/openai_api.py
index cd0b462..b3b75c7 100644
--- a/src/core/openai_api.py
+++ b/src/core/openai_api.py
@@ -1,7 +1,5 @@
 import openai
 import json
-import sqlite3
-from datetime import datetime
 from .token_counter import TokenCounter
 from .database import Database

diff --git a/tests/test_compose_prompt.py b/tests/test_compose_prompt.py
index 307edef..deaa412 100644
--- a/tests/test_compose_prompt.py
+++ b/tests/test_compose_prompt.py
@@ -1,4 +1,3 @@
-import pytest
 from core.compose_prompt import compose_prompt


diff --git a/tests/test_database.py b/tests/test_database.py
index 9ace82a..93a30c8 100644
--- a/tests/test_database.py
+++ b/tests/test_database.py
@@ -1,4 +1,3 @@
-import pytest
 from unittest import mock
 from core.database import Database

diff --git a/tests/test_log_complete_prompt.py b/tests/test_log_complete_prompt.py
index 8d2719e..889f716 100644
--- a/tests/test_log_complete_prompt.py
+++ b/tests/test_log_complete_prompt.py
@@ -23,4 +23,4 @@ async def mock_complete_prompt(*args, **kwargs):
     result = await log_prompt.complete_prompt()

     # Assert the completion result
-    assert result["completed"] == True
+    assert result["completed"] is True
diff --git a/tests/test_parallel_complete_prompt.py b/tests/test_parallel_complete_prompt.py
index 6442edf..6fc2ba4 100644
--- a/tests/test_parallel_complete_prompt.py
+++ b/tests/test_parallel_complete_prompt.py
@@ -10,4 +10,4 @@ async def mock_complete_prompt(*args, **kwargs):
     parallel_prompt = ParallelCompletePrompt(mock_complete_prompt)
     result = await parallel_prompt.complete_prompt()

-    assert result["completed"] == True
+    assert result["completed"] is True
diff --git a/tests/test_prompt_generation.py b/tests/test_prompt_generation.py
index d9d449d..8fb84d1 100644
--- a/tests/test_prompt_generation.py
+++ b/tests/test_prompt_generation.py
@@ -1,4 +1,3 @@
-import pytest
 from core.prompt_generation import PromptGenerator


diff --git a/tests/test_token_counter.py b/tests/test_token_counter.py
index 6668020..db1f6b7 100644
--- a/tests/test_token_counter.py
+++ b/tests/test_token_counter.py
@@ -1,4 +1,3 @@
-import pytest
 from core.token_counter import TokenCounter


From 14979227315313ba0521949a5be4da4c69b963bb Mon Sep 17 00:00:00 2001
From: Jochen Schultz
Date: Wed, 11 Sep 2024 08:57:23 +0000
Subject: [PATCH 2/4] remove database functionality to handle rate limits;
 users need to implement that themselves - logging is now only supplementary
 to the printed error output

---
 install.py             | 71 +------------------------------------------
 src/core/database.py   | 61 ------------------------------------------
 src/core/logging.py    |  2 ++
 src/core/openai_api.py | 11 -------
 tests/test_database.py | 58 ------------------------------------------
 5 files changed, 3 insertions(+), 200 deletions(-)
 delete mode 100644 src/core/database.py
 delete mode 100644 tests/test_database.py

diff --git a/install.py b/install.py
index aae27fa..a2a9dc2 100644
--- a/install.py
+++ b/install.py
@@ -1,4 +1,3 @@
-import sqlite3
 import os
 import json
 import argparse
@@ -9,22 +8,16 @@ def create_settings(ci_mode=False):
     if ci_mode:
         # Use default values for CI
         api_key = "sk-test-key"
-        tier = "tier-4"
         log_path = './var/logs/error.log'
-        database_path = './var/data/agents.db'
     else:
         # Prompt user for settings
         api_key = input('Enter your OpenAI API key: ')
-        tier = input('Enter your OpenAI tier level (e.g., tier-1): ')
         log_path = input('Enter the log directory path [default: ./var/logs/error.log]: ') or './var/logs/error.log'
-        database_path = input('Enter the database path [default: ./var/data/agents.db]: ') or './var/data/agents.db'

     # Save settings to JSON file
     settings = {
         'openai_api_key': api_key,
-        'tier': tier,
-        'log_path': log_path,
-        'database_path': database_path
+        'log_path': log_path
     }
     os.makedirs('./config', exist_ok=True)
     with open('./config/settings.json', 'w') as f:
     print('Settings saved to config/settings.json')
-# Function to create the database structure
-
-def create_database(db_path):
-    os.makedirs(os.path.dirname(db_path), exist_ok=True)
-    conn = sqlite3.connect(db_path)
-    c = conn.cursor()
-
-    # Create tables
-    c.execute('''CREATE TABLE IF NOT EXISTS models (
-        id INTEGER PRIMARY KEY,
-        model TEXT NOT NULL,
-        price_per_prompt_token REAL NOT NULL,
-        price_per_completion_token REAL NOT NULL)''')
-
-    c.execute('''CREATE TABLE IF NOT EXISTS rate_limits (
-        id INTEGER PRIMARY KEY,
-        model TEXT NOT NULL,
-        tier TEXT NOT NULL,
-        rpm_limit INTEGER NOT NULL,
-        tpm_limit INTEGER NOT NULL,
-        rpd_limit INTEGER NOT NULL)''')
-
-    c.execute('''CREATE TABLE IF NOT EXISTS api_usage (
-        id INTEGER PRIMARY KEY,
-        timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
-        session_id TEXT NOT NULL,
-        model TEXT NOT NULL,
-        prompt_tokens INTEGER NOT NULL,
-        completion_tokens INTEGER NOT NULL,
-        total_tokens INTEGER NOT NULL,
-        price_per_prompt_token REAL NOT NULL,
-        price_per_completion_token REAL NOT NULL,
-        total_cost REAL NOT NULL)''')
-
-    c.execute('''CREATE TABLE IF NOT EXISTS chat_sessions (
-        id INTEGER PRIMARY KEY,
-        session_id TEXT NOT NULL,
-        start_time DATETIME DEFAULT CURRENT_TIMESTAMP,
-        end_time DATETIME)''')
-
-    c.execute('''CREATE TABLE IF NOT EXISTS chats (
-        id INTEGER PRIMARY KEY,
-        session_id TEXT NOT NULL,
-        chat_id TEXT NOT NULL,
-        message TEXT NOT NULL,
-        role TEXT NOT NULL,
-        timestamp DATETIME DEFAULT CURRENT_TIMESTAMP)''')
-
-    # Insert default models and rate limits
-    c.execute("INSERT INTO models (model, price_per_prompt_token, price_per_completion_token) VALUES ('gpt-4o-mini', 0.03, 0.06)")
-    c.execute("INSERT INTO rate_limits (model, tier, rpm_limit, tpm_limit, rpd_limit) VALUES ('gpt-4o-mini', 'tier-1', 60, 50000, 1000)")
-
-    conn.commit()
-    conn.close()
-    print(f"Database created at {db_path}")
-
-
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='Setup script for installation.')
     parser.add_argument('--ci', action='store_true', help='Use default values for CI without prompting.')
     args = parser.parse_args()

     create_settings(ci_mode=args.ci)
-
-    with open('./config/settings.json', 'r') as f:
-        settings = json.load(f)
-
-    create_database(settings['database_path'])
diff --git a/src/core/database.py b/src/core/database.py
deleted file mode 100644
index b58fc44..0000000
--- a/src/core/database.py
+++ /dev/null
@@ -1,61 +0,0 @@
-import sqlite3
-
-
-class Database:
-    def __init__(self, db_path):
-        self.db_path = db_path
-
-    def connect(self):
-        return sqlite3.connect(self.db_path)
-
-    def check_rate_limits(self, model):
-        conn = self.connect()
-        c = conn.cursor()
-
-        # Check current API usage (RPM, TPM, RPD)
-        c.execute(
-            "SELECT SUM(total_tokens) FROM api_usage WHERE model = ? AND timestamp >= datetime('now', '-1 minute')",
-            (model,),
-        )
-        tokens_last_minute = c.fetchone()[0] or 0
-
-        c.execute("SELECT tpm_limit FROM rate_limits WHERE model = ?", (model,))
-        tpm_limit = c.fetchone()[0]
-
-        conn.close()
-        return tokens_last_minute < tpm_limit
-
-    def log_api_usage(
-        self, session_id, model, prompt_tokens, completion_tokens, total_tokens
-    ):
-        conn = self.connect()
-        c = conn.cursor()
-
-        # Fetch token prices
-        c.execute(
-            "SELECT price_per_prompt_token, price_per_completion_token FROM models WHERE model = ?",
-            (model,),
-        )
-        prices = c.fetchone()
-        prompt_price = prices[0]
-        completion_price = prices[1]
-        total_cost = (prompt_tokens * prompt_price) + (
-            completion_tokens * completion_price
-        )
-
-        c.execute(
-            "INSERT INTO api_usage (session_id, model, prompt_tokens, completion_tokens, total_tokens, price_per_prompt_token, price_per_completion_token, total_cost) VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
-            (
-                session_id,
-                model,
-                prompt_tokens,
-                completion_tokens,
-                total_tokens,
-                prompt_price,
-                completion_price,
-                total_cost,
-            ),
-        )
-
-        conn.commit()
-        conn.close()
diff --git a/src/core/logging.py b/src/core/logging.py
index 229cb9a..ce7a5c0 100644
--- a/src/core/logging.py
+++ b/src/core/logging.py
@@ -25,6 +25,8 @@ def load_settings(self, settings_path):

     def info(self, message):
         logging.info(message)
+        print(message)

     def error(self, message):
         logging.error(message)
+        print(f"ERROR: {message}")
diff --git a/src/core/openai_api.py b/src/core/openai_api.py
index b3b75c7..8176181 100644
--- a/src/core/openai_api.py
+++ b/src/core/openai_api.py
@@ -1,7 +1,6 @@
 import openai
 import json
 from .token_counter import TokenCounter
-from .database import Database


 class OpenAIClient:
@@ -9,7 +8,6 @@ def __init__(self, settings_path="../config/settings.json"):
         settings = self.load_settings(settings_path)
         self.api_key = settings["openai_api_key"]
         openai.api_key = self.api_key
-        self.db = Database(settings["database_path"])
         self.token_counter = TokenCounter()

     def load_settings(self, settings_path):
@@ -22,10 +20,6 @@ def load_settings(self, settings_path):
             raise Exception(f"Missing key in settings: {e}")

     def complete_chat(self, messages, model="gpt-4o-mini", max_tokens=1500):
-        # Check rate limits
-        if not self.db.check_rate_limits(model):
-            raise Exception(f"Rate limit exceeded for model {model}")
-
         prompt_tokens = self.token_counter.count_tokens(messages)

         try:
@@ -38,11 +32,6 @@ def complete_chat(self, messages, model="gpt-4o-mini", max_tokens=1500):
             )
             total_tokens = prompt_tokens + completion_tokens

-            # Log token usage and cost in the database
-            self.db.log_api_usage(
-                "session-1", model, prompt_tokens, completion_tokens, total_tokens
-            )
-
             return response.choices[0].message["content"]
         except openai.error.OpenAIError as e:
             raise Exception(f"Error with OpenAI API: {str(e)}")
diff --git a/tests/test_database.py b/tests/test_database.py
deleted file mode 100644
index 93a30c8..0000000
--- a/tests/test_database.py
+++ /dev/null
@@ -1,58 +0,0 @@
-from unittest import mock
-from core.database import Database
-
-
-def test_check_rate_limits():
-    # Mock the database connection and cursor
-    mock_conn = mock.MagicMock()
-    mock_cursor = mock_conn.cursor.return_value
-
-    # Mock the cursor's return values for token limits and usage
-    mock_cursor.fetchone.side_effect = [
-        (100,),
-        (50000,),
-    ]  # Return tuples for token usage and limit
-
-    # Patch the connect method to return the mock connection
-    with mock.patch("sqlite3.connect", return_value=mock_conn):
-        db = Database(db_path=":memory:")  # Use an in-memory database for testing
-
-        # Call the method and assert it returns True (since usage < limit)
-        assert db.check_rate_limits("gpt-4o-mini") == True
-
-
-def test_log_and_delete_api_usage():
-    # Mock the database connection and cursor
-    mock_conn = mock.MagicMock()
-    mock_cursor = mock_conn.cursor.return_value
-
-    # Patch the connect method to return the mock connection
-    with mock.patch("sqlite3.connect", return_value=mock_conn):
-        db = Database(db_path=":memory:")
-
-        # Log the API usage
-        db.log_api_usage("session-1", "gpt-4o-mini", 100, 200, 300)
-
-        # Assert the INSERT query was executed
-        mock_cursor.execute.assert_any_call(
-            "INSERT INTO api_usage (session_id, model, prompt_tokens, completion_tokens, total_tokens, price_per_prompt_token, price_per_completion_token, total_cost) VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
-            ("session-1", "gpt-4o-mini", 100, 200, 300, mock.ANY, mock.ANY, mock.ANY),
-        )
-
-        # Mock the SELECT query to return the inserted values
-        mock_cursor.fetchone.side_effect = [("session-1", "gpt-4o-mini", 100, 200, 300)]
-
-        # Simulate fetching the logged API usage
-        db.connect().cursor().execute(
-            "SELECT * FROM api_usage WHERE session_id = ?", ("session-1",)
-        )
-
-        # Now simulate deleting the entry
-        db.connect().cursor().execute(
-            "DELETE FROM api_usage WHERE session_id = ?", ("session-1",)
-        )
-
-        # Assert the DELETE query was executed
-        mock_cursor.execute.assert_any_call(
-            "DELETE FROM api_usage WHERE session_id = ?", ("session-1",)
-        )

From af9633f25d1008f4e4e86455a8fc99751c3617c7 Mon Sep 17 00:00:00 2001
From: Jochen Schultz
Date: Wed, 11 Sep 2024 12:40:17 +0000
Subject: [PATCH 3/4] add first example on how to use agentm-py to sort a list
 and fix the core and installer to make it actually work

---
 examples/sort_list_example.py |  25 +++++++++
 install.py                    |   5 +-
 requirements.txt              |   1 +
 src/core/logging.py           |  47 ++++++++------
 src/core/openai_api.py        |  49 ++++++-----------
 src/core/sort_list_agent.py   | 100 ++++++++++++++++++++++++++++++++++
 6 files changed, 179 insertions(+), 48 deletions(-)
 create mode 100644 examples/sort_list_example.py
 create mode 100644 src/core/sort_list_agent.py

diff --git a/examples/sort_list_example.py b/examples/sort_list_example.py
new file mode 100644
index 0000000..62bcee0
--- /dev/null
+++ b/examples/sort_list_example.py
+import asyncio
+from core.sort_list_agent import SortListAgent
+
+async def run_sort_list_example():
+    # Sample input list
+    items_to_sort = [
+        "Apple", "Orange", "Banana", "Grape", "Pineapple"
+    ]
+
+    # Define a goal for sorting (this will influence the OpenAI model's decision)
+    goal = "Sort the fruits alphabetically."
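+    # The goal above is interpolated into the agent's system prompt for every pairwise comparison.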
+
+    # Create the sorting agent
+    agent = SortListAgent(goal=goal, list_to_sort=items_to_sort, log_explanations=True)
+
+    # Execute the sorting process
+    sorted_list = await agent.sort()
+
+    # Output the result
+    print("Original list:", items_to_sort)
+    print("Sorted list:", sorted_list)
+
+# Run the example
+if __name__ == "__main__":
+    asyncio.run(run_sort_list_example())
\ No newline at end of file
diff --git a/install.py b/install.py
index a2a9dc2..478bcb9 100644
--- a/install.py
+++ b/install.py
@@ -9,15 +9,18 @@ def create_settings(ci_mode=False):
         # Use default values for CI
         api_key = "sk-test-key"
         log_path = './var/logs/error.log'
+        debug = False
     else:
         # Prompt user for settings
         api_key = input('Enter your OpenAI API key: ')
         log_path = input('Enter the log directory path [default: ./var/logs/error.log]: ') or './var/logs/error.log'
+        debug = input('Enable debug mode? [y/n]: ').lower() == 'y'

     # Save settings to JSON file
     settings = {
         'openai_api_key': api_key,
-        'log_path': log_path
+        'log_path': log_path,
+        'debug': debug
     }
     os.makedirs('./config', exist_ok=True)
     with open('./config/settings.json', 'w') as f:
diff --git a/requirements.txt b/requirements.txt
index d86ccb4..93cbcbd 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -17,3 +17,4 @@ flake8
 tiktoken
 anyio
 trio
+openai
diff --git a/src/core/logging.py b/src/core/logging.py
index ce7a5c0..e022732 100644
--- a/src/core/logging.py
+++ b/src/core/logging.py
@@ -1,32 +1,47 @@
 import logging
+import http.client
 import json
 import os
-
 class Logger:
     def __init__(self, settings_path="../config/settings.json"):
         self.settings = self.load_settings(settings_path)
         self.log_path = self.settings["log_path"]
         os.makedirs(os.path.dirname(self.log_path), exist_ok=True)
-        logging.basicConfig(
-            filename=self.log_path,
-            level=logging.INFO,
-            format="%(asctime)s - %(levelname)s - %(message)s",
-        )
+
+        # Create a logger instance
+        self.logger = logging.getLogger("AgentMLogger")
+        self.logger.setLevel(logging.DEBUG if self.settings.get("debug", False) else logging.INFO)
+
+        # File handler for logging to a file
+        file_handler = logging.FileHandler(self.log_path)
+        file_handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s"))
+        self.logger.addHandler(file_handler)
+
+        # Console handler for output to the console
+        console_handler = logging.StreamHandler()
+        console_handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s"))
+        self.logger.addHandler(console_handler)
+
+        # Enable HTTP-level logging if debug is enabled
+        if self.settings.get("debug", False):
+            self.enable_http_debug()

     def load_settings(self, settings_path):
-        try:
-            with open(settings_path, "r") as f:
-                return json.load(f)
-        except FileNotFoundError:
-            raise Exception(f"Settings file not found at {settings_path}")
-        except KeyError as e:
-            raise Exception(f"Missing key in settings: {e}")
+        if not os.path.exists(settings_path):
+            raise FileNotFoundError(f"Settings file not found at {settings_path}")
+        with open(settings_path, "r") as f:
+            return json.load(f)
+
+    def enable_http_debug(self):
+        """Enable HTTP-level logging for API communication."""
+        http.client.HTTPConnection.debuglevel = 1
+        logging.getLogger("http.client").setLevel(logging.DEBUG)
+        logging.getLogger("http.client").propagate = True

     def info(self, message):
-        logging.info(message)
-        print(message)
+        self.logger.info(message)

     def error(self, message):
-        logging.error(message)
+        self.logger.error(message)
         print(f"ERROR: {message}")
diff --git a/src/core/openai_api.py b/src/core/openai_api.py
index 8176181..09b8d30 100644
--- a/src/core/openai_api.py
+++ b/src/core/openai_api.py
@@ -1,37 +1,24 @@
-import openai
-import json
-from .token_counter import TokenCounter
-
+from openai import OpenAI, BadRequestError
+from .logging import Logger  # Use the logger abstraction
+import os

 class OpenAIClient:
-    def __init__(self, settings_path="../config/settings.json"):
-        settings = self.load_settings(settings_path)
-        self.api_key = settings["openai_api_key"]
-        openai.api_key = self.api_key
-        self.token_counter = TokenCounter()
-
-    def load_settings(self, settings_path):
-        try:
-            with open(settings_path, "r") as f:
-                return json.load(f)
-        except FileNotFoundError:
-            raise Exception(f"Settings file not found at {settings_path}")
-        except KeyError as e:
-            raise Exception(f"Missing key in settings: {e}")
+    def __init__(self, settings_path=None):
+        if settings_path is None:
+            settings_path = os.path.join(os.path.dirname(__file__), '../../config/settings.json')

-    def complete_chat(self, messages, model="gpt-4o-mini", max_tokens=1500):
-        prompt_tokens = self.token_counter.count_tokens(messages)
+        self.logger = Logger(settings_path)
+        settings = self.logger.load_settings(settings_path)
+        self.client = OpenAI(api_key=settings["openai_api_key"])

+    async def complete_chat(self, messages, model="gpt-4o-mini", max_tokens=1500):
         try:
-            response = openai.ChatCompletion.create(
-                model=model, messages=messages, max_tokens=max_tokens
+            response = self.client.chat.completions.create(
+                model=model,
+                messages=messages,
+                max_tokens=max_tokens
             )
-
-            completion_tokens = self.token_counter.count_tokens(
-                response.choices[0].message["content"]
-            )
-            total_tokens = prompt_tokens + completion_tokens
-
-            return response.choices[0].message["content"]
-        except openai.error.OpenAIError as e:
-            raise Exception(f"Error with OpenAI API: {str(e)}")
+            return response.choices[0].message.content
+        except BadRequestError as e:
+            self.logger.error(f"Error with OpenAI API: {str(e)}")
+            raise
diff --git a/src/core/sort_list_agent.py b/src/core/sort_list_agent.py
new file mode 100644
index 0000000..5aee7de
--- /dev/null
+++ b/src/core/sort_list_agent.py
+import asyncio
+from typing import List
+from .openai_api import OpenAIClient
+
+class SortListAgent:
+    def __init__(self, goal: str, list_to_sort: List[str], max_tokens: int = 1000, temperature: float = 0.0, log_explanations: bool = False):
+        self.goal = goal
+        self.list = list_to_sort
+        self.max_tokens = max_tokens
+        self.temperature = temperature
+        self.log_explanations = log_explanations
+        self.openai_client = OpenAIClient()
+
+    async def sort(self):
+        return await self.merge_sort(self.list)
+
+    async def batch_compare(self, pairs):
+        """
+        Send multiple comparison pairs to the API in one request to reduce API calls.
+        """
+        batch_prompt = "\n".join([f"Compare {a} and {b} and return the items in the correct order as 'item1,item2'." for a, b in pairs])
+        system_prompt = f"You are tasked with sorting items. Goal: {self.goal}.\nCompare the following pairs and return the correct order."
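+        # Each reply line should carry one "item1,item2" answer; malformed entries are skipped by the parser below.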
+
+        # Log the request we're sending
+        self.openai_client.logger.info(f"Sending batch comparison request with prompt: {batch_prompt}")
+
+        response = await self.openai_client.complete_chat([
+            {"role": "system", "content": system_prompt},
+            {"role": "user", "content": batch_prompt}
+        ], max_tokens=self.max_tokens)
+
+        # Log the response we receive
+        self.openai_client.logger.info(f"Received response: {response}")
+
+        comparisons = response.split("\n")  # Assuming API returns comparisons in batch order
+
+        # Check for empty response and log an error
+        if not comparisons:
+            self.openai_client.logger.error("Empty response received from API.")
+
+        # Parse responses and filter out empty or malformed comparisons
+        parsed_comparisons = []
+        for comparison in comparisons:
+            individual_comparisons = comparison.split(" ")  # Split individual results
+            for comp in individual_comparisons:
+                comp = comp.strip()
+                if not comp:  # Ignore empty results
+                    continue
+                if len(parsed_comparisons) >= len(pairs):
+                    break  # More answers than pairs requested; ignore the extras
+                try:
+                    first, second = comp.split(",")
+                    # Match this answer against the pair it belongs to, in request order
+                    if first.strip() == pairs[len(parsed_comparisons)][0]:
+                        parsed_comparisons.append("BEFORE")
+                    else:
+                        parsed_comparisons.append("AFTER")
+                except ValueError:
+                    self.openai_client.logger.info(f"Ignoring unexpected comparison result: {comp}")
+
+        return parsed_comparisons
+
+    async def merge_sort(self, items):
+        if len(items) < 2:
+            return items
+
+        mid = len(items) // 2
+        left_half, right_half = await asyncio.gather(self.merge_sort(items[:mid]), self.merge_sort(items[mid:]))
+        return await self.merge(left_half, right_half)
+
+    async def merge(self, left, right):
+        result = []
+        i, j = 0, 0
+        comparisons_to_make = []
+
+        while i < len(left) and j < len(right):
+            comparisons_to_make.append((left[i], right[j]))
+            i += 1
+            j += 1
+
+        # Batch process comparisons
+        comparison_results = await self.batch_compare(comparisons_to_make)
+
+        # Safely ignore last comparison if there are no more results
+        if not comparison_results:
+            self.openai_client.logger.info("Final comparison complete.")
+            result.extend(left[i:])
+            result.extend(right[j:])
+            return result
+
+        i = 0
+        j = 0
+        while i < len(left) and j < len(right) and comparison_results:
+            if comparison_results.pop(0) == "BEFORE":
+                result.append(left[i])
+                i += 1
+            else:
+                result.append(right[j])
+                j += 1
+
+        result.extend(left[i:])
+        result.extend(right[j:])
+        return result

From 2858ac6384a9641be95f85bcc42e51d85aba7a99 Mon Sep 17 00:00:00 2001
From: Jochen Schultz
Date: Wed, 11 Sep 2024 13:54:18 +0000
Subject: [PATCH 4/4] add filter list example

---
 examples/filter_list_example.py | 23 ++++++++++
 requirements.txt                |  1 +
 src/core/filter_list_agent.py   | 76 +++++++++++++++++++++++++++++++++
 3 files changed, 100 insertions(+)
 create mode 100644 examples/filter_list_example.py
 create mode 100644 src/core/filter_list_agent.py

diff --git a/examples/filter_list_example.py b/examples/filter_list_example.py
new file mode 100644
index 0000000..745ba2f
--- /dev/null
+++ b/examples/filter_list_example.py
+import asyncio
+from core.filter_list_agent import FilterListAgent
+
+async def run_filter_list_example():
+    goal = "Remove items that are unhealthy snacks."
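+    # The goal is embedded verbatim in the agent's system prompt and steers each keep/remove decision.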
+    items_to_filter = [
+        "Apple",
+        "Chocolate bar",
+        "Carrot",
+        "Chips",
+        "Orange"
+    ]
+
+    agent = FilterListAgent(goal=goal, items_to_filter=items_to_filter)
+    filtered_results = await agent.filter()
+
+    print("Original list:", items_to_filter)
+    print("Filtered results:")
+    for result in filtered_results:
+        print(result)
+
+if __name__ == "__main__":
+    asyncio.run(run_filter_list_example())
diff --git a/requirements.txt b/requirements.txt
index 93cbcbd..bd5b8c3 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -18,3 +18,4 @@ tiktoken
 anyio
 trio
 openai
+jsonschema
diff --git a/src/core/filter_list_agent.py b/src/core/filter_list_agent.py
new file mode 100644
index 0000000..ecdb522
--- /dev/null
+++ b/src/core/filter_list_agent.py
+import asyncio
+import json
+import jsonschema
+from typing import List, Dict
+from .openai_api import OpenAIClient
+
+class FilterListAgent:
+    def __init__(self, goal: str, items_to_filter: List[str], max_tokens: int = 500, temperature: float = 0.0):
+        self.goal = goal
+        self.items = items_to_filter
+        self.max_tokens = max_tokens
+        self.temperature = temperature
+        self.openai_client = OpenAIClient()
+
+        # JSON schema for validation (stored on the instance so process_response can use it)
+        self.schema = {
+            "type": "object",
+            "properties": {
+                "explanation": {"type": "string"},
+                "remove_item": {"type": "boolean"}
+            },
+            "required": ["explanation", "remove_item"]
+        }
+
+    async def filter(self) -> List[Dict]:
+        return await self.filter_list(self.items)
+
+    async def filter_list(self, items: List[str]) -> List[Dict]:
+        # System prompt with multi-shot examples to guide the model
+        system_prompt = (
+            "You are an assistant tasked with filtering a list of items. The goal is: "
+            f"{self.goal}. For each item, decide if it should be removed based on whether it is a healthy snack.\n"
+            "Respond in the following structured format:\n\n"
+            "Example:\n"
+            "{\"explanation\": \"The apple is a healthy snack option, as it is low in calories...\",\n"
+            " \"remove_item\": false}\n\n"
+            "Example:\n"
+            "{\"explanation\": \"A chocolate bar is generally considered an unhealthy snack...\",\n"
+            " \"remove_item\": true}\n\n"
+        )
+
+        tasks = []
+        for index, item in enumerate(items):
+            user_prompt = f"Item {index+1}: {item}. Should it be removed? Answer with explanation and 'remove_item': true/false."
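+            # One concurrent task per item; each JSON reply is validated against self.schema before use.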
+            tasks.append(self.filter_item(system_prompt, user_prompt))
+
+        # Run all tasks in parallel
+        results = await asyncio.gather(*tasks)
+
+        # Show the final list of items that were kept
+        filtered_items = [self.items[i] for i, result in enumerate(results) if not result.get('remove_item', False)]
+        print("\nFinal Filtered List:", filtered_items)
+
+        return results
+
+    async def filter_item(self, system_prompt: str, user_prompt: str, retry: bool = True) -> Dict:
+        response = await self.openai_client.complete_chat([
+            {"role": "system", "content": system_prompt},
+            {"role": "user", "content": user_prompt}
+        ], max_tokens=self.max_tokens)
+
+        return await self.process_response(response, system_prompt, user_prompt, retry)
+
+    async def process_response(self, response: str, system_prompt: str, user_prompt: str, retry: bool = True) -> Dict:
+        try:
+            # Parse the response as JSON
+            result = json.loads(response)
+            # Validate against the schema
+            jsonschema.validate(instance=result, schema=self.schema)
+            return result
+        except (json.JSONDecodeError, jsonschema.ValidationError) as e:
+            if retry:
+                # Retry once if validation fails; disable further retries on the second attempt
+                return await self.filter_item(system_prompt, user_prompt, retry=False)
+            else:
+                return {"error": f"Failed to parse response after retry: {str(e)}", "response": response, "item": user_prompt}