Skip to content

Commit

Permalink
Merge pull request #38 from debuggerone/examples
Browse files Browse the repository at this point in the history
Examples - I let the pipeline run on my fork - so the actions run through
  • Loading branch information
debuggerone authored Sep 11, 2024
2 parents c2372e4 + 2858ac6 commit 4d25309
Show file tree
Hide file tree
Showing 15 changed files with 281 additions and 254 deletions.
23 changes: 23 additions & 0 deletions examples/filter_list_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import asyncio
from core.filter_list_agent import FilterListAgent

async def run_filter_list_example():
goal = "Remove items that are unhealthy snacks."
items_to_filter = [
"Apple",
"Chocolate bar",
"Carrot",
"Chips",
"Orange"
]

agent = FilterListAgent(goal=goal, items_to_filter=items_to_filter)
filtered_results = await agent.filter()

print("Original list:", items_to_filter)
print("Filtered results:")
for result in filtered_results:
print(result)

if __name__ == "__main__":
asyncio.run(run_filter_list_example())
25 changes: 25 additions & 0 deletions examples/sort_list_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import asyncio
from core.sort_list_agent import SortListAgent

async def run_sort_list_example():
# Sample input list
items_to_sort = [
"Apple", "Orange", "Banana", "Grape", "Pineapple"
]

# Define a goal for sorting (this will influence the OpenAI model's decision)
goal = "Sort the fruits alphabetically."

# Create the sorting agent
agent = SortListAgent(goal=goal, list_to_sort=items_to_sort, log_explanations=True)

# Execute the sorting process
sorted_list = await agent.sort()

# Output the result
print("Original list:", items_to_sort)
print("Sorted list:", sorted_list)

# Run the example
if __name__ == "__main__":
asyncio.run(run_sort_list_example())
72 changes: 3 additions & 69 deletions install.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import sqlite3
import os
import json
import argparse
Expand All @@ -9,94 +8,29 @@ def create_settings(ci_mode=False):
if ci_mode:
# Use default values for CI
api_key = "sk-test-key"
tier = "tier-4"
log_path = './var/logs/error.log'
database_path = './var/data/agents.db'
debug = False
else:
# Prompt user for settings
api_key = input('Enter your OpenAI API key: ')
tier = input('Enter your OpenAI tier level (e.g., tier-1): ')
log_path = input('Enter the log directory path [default: ./var/logs/error.log]: ') or './var/logs/error.log'
database_path = input('Enter the database path [default: ./var/data/agents.db]: ') or './var/data/agents.db'
debug = input('Enable debug mode? [y/n]: ').lower() == 'y'

# Save settings to JSON file
settings = {
'openai_api_key': api_key,
'tier': tier,
'log_path': log_path,
'database_path': database_path
'debug': debug
}
os.makedirs('./config', exist_ok=True)
with open('./config/settings.json', 'w') as f:
json.dump(settings, f, indent=4)
print('Settings saved to config/settings.json')


# Function to create the database structure

def create_database(db_path):
os.makedirs(os.path.dirname(db_path), exist_ok=True)
conn = sqlite3.connect(db_path)
c = conn.cursor()

# Create tables
c.execute('''CREATE TABLE IF NOT EXISTS models (
id INTEGER PRIMARY KEY,
model TEXT NOT NULL,
price_per_prompt_token REAL NOT NULL,
price_per_completion_token REAL NOT NULL)''')

c.execute('''CREATE TABLE IF NOT EXISTS rate_limits (
id INTEGER PRIMARY KEY,
model TEXT NOT NULL,
tier TEXT NOT NULL,
rpm_limit INTEGER NOT NULL,
tpm_limit INTEGER NOT NULL,
rpd_limit INTEGER NOT NULL)''')

c.execute('''CREATE TABLE IF NOT EXISTS api_usage (
id INTEGER PRIMARY KEY,
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
session_id TEXT NOT NULL,
model TEXT NOT NULL,
prompt_tokens INTEGER NOT NULL,
completion_tokens INTEGER NOT NULL,
total_tokens INTEGER NOT NULL,
price_per_prompt_token REAL NOT NULL,
price_per_completion_token REAL NOT NULL,
total_cost REAL NOT NULL)''')

c.execute('''CREATE TABLE IF NOT EXISTS chat_sessions (
id INTEGER PRIMARY KEY,
session_id TEXT NOT NULL,
start_time DATETIME DEFAULT CURRENT_TIMESTAMP,
end_time DATETIME)''')

c.execute('''CREATE TABLE IF NOT EXISTS chats (
id INTEGER PRIMARY KEY,
session_id TEXT NOT NULL,
chat_id TEXT NOT NULL,
message TEXT NOT NULL,
role TEXT NOT NULL,
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP)''')

# Insert default models and rate limits
c.execute("INSERT INTO models (model, price_per_prompt_token, price_per_completion_token) VALUES ('gpt-4o-mini', 0.03, 0.06)")
c.execute("INSERT INTO rate_limits (model, tier, rpm_limit, tpm_limit, rpd_limit) VALUES ('gpt-4o-mini', 'tier-1', 60, 50000, 1000)")

conn.commit()
conn.close()
print(f"Database created at {db_path}")


if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Setup script for installation.')
parser.add_argument('--ci', action='store_true', help='Use default values for CI without prompting.')
args = parser.parse_args()

create_settings(ci_mode=args.ci)

with open('./config/settings.json', 'r') as f:
settings = json.load(f)

create_database(settings['database_path'])
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,5 @@ flake8
tiktoken
anyio
trio
openai
jsonschema
62 changes: 0 additions & 62 deletions src/core/database.py

This file was deleted.

76 changes: 76 additions & 0 deletions src/core/filter_list_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import asyncio
import json
import jsonschema
from typing import List, Dict
from .openai_api import OpenAIClient

class FilterListAgent:
def __init__(self, goal: str, items_to_filter: List[str], max_tokens: int = 500, temperature: float = 0.0):
self.goal = goal
self.items = items_to_filter
self.max_tokens = max_tokens
self.temperature = temperature
self.openai_client = OpenAIClient()

# JSON schema for validation
schema = {
"type": "object",
"properties": {
"explanation": {"type": "string"},
"remove_item": {"type": "boolean"}
},
"required": ["explanation", "remove_item"]
}

async def filter(self) -> List[Dict]:
return await self.filter_list(self.items)

async def filter_list(self, items: List[str]) -> List[Dict]:
# System prompt with multi-shot examples to guide the model
system_prompt = (
"You are an assistant tasked with filtering a list of items. The goal is: "
f"{self.goal}. For each item, decide if it should be removed based on whether it is a healthy snack.\n"
"Respond in the following structured format:\n\n"
"Example:\n"
"{\"explanation\": \"The apple is a healthy snack option, as it is low in calories...\",\n"
" \"remove_item\": false}\n\n"
"Example:\n"
"{\"explanation\": \"A chocolate bar is generally considered an unhealthy snack...\",\n"
" \"remove_item\": true}\n\n"
)

tasks = []
for index, item in enumerate(items):
user_prompt = f"Item {index+1}: {item}. Should it be removed? Answer with explanation and 'remove_item': true/false."
tasks.append(self.filter_item(system_prompt, user_prompt))

# Run all tasks in parallel
results = await asyncio.gather(*tasks)

# Show the final list of items that were kept
filtered_items = [self.items[i] for i, result in enumerate(results) if not result.get('remove_item', False)]
print("\nFinal Filtered List:", filtered_items)

return results

async def filter_item(self, system_prompt: str, user_prompt: str) -> Dict:
response = await self.openai_client.complete_chat([
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
], max_tokens=self.max_tokens)

return await self.process_response(response, system_prompt, user_prompt)

async def process_response(self, response: str, system_prompt: str, user_prompt: str, retry: bool = True) -> Dict:
try:
# Parse the response as JSON
result = json.loads(response)
# Validate against the schema
jsonschema.validate(instance=result, schema=self.schema)
return result
except (json.JSONDecodeError, jsonschema.ValidationError) as e:
if retry:
# Retry once if validation fails
return await self.filter_item(system_prompt, user_prompt)
else:
return {"error": f"Failed to parse response after retry: {str(e)}", "response": response, "item": user_prompt}
47 changes: 32 additions & 15 deletions src/core/logging.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,47 @@
import logging
import http.client
import json
import os


class Logger:
def __init__(self, settings_path="../config/settings.json"):
self.settings = self.load_settings(settings_path)
self.log_path = self.settings["log_path"]
os.makedirs(os.path.dirname(self.log_path), exist_ok=True)
logging.basicConfig(
filename=self.log_path,
level=logging.INFO,
format="%(asctime)s - %(levelname)s - %(message)s",
)

# Create a logger instance
self.logger = logging.getLogger("AgentMLogger")
self.logger.setLevel(logging.DEBUG if self.settings.get("debug", False) else logging.INFO)

# File handler for logging to a file
file_handler = logging.FileHandler(self.log_path)
file_handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s"))
self.logger.addHandler(file_handler)

# Console handler for output to the console
console_handler = logging.StreamHandler()
console_handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s"))
self.logger.addHandler(console_handler)

# Enable HTTP-level logging if debug is enabled
if self.settings.get("debug", False):
self.enable_http_debug()

def load_settings(self, settings_path):
try:
with open(settings_path, "r") as f:
return json.load(f)
except FileNotFoundError:
raise Exception(f"Settings file not found at {settings_path}")
except KeyError as e:
raise Exception(f"Missing key in settings: {e}")
if not os.path.exists(settings_path):
raise FileNotFoundError(f"Settings file not found at {settings_path}")
with open(settings_path, "r") as f:
return json.load(f)

def enable_http_debug(self):
"""Enable HTTP-level logging for API communication."""
http.client.HTTPConnection.debuglevel = 1
logging.getLogger("http.client").setLevel(logging.DEBUG)
logging.getLogger("http.client").propagate = True

def info(self, message):
logging.info(message)
self.logger.info(message)

def error(self, message):
logging.error(message)
self.logger.error(message)
print(f"ERROR: {message}")
Loading

0 comments on commit 4d25309

Please sign in to comment.