-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #38 from debuggerone/examples
Examples - I let the pipeline run on my fork - so the actions run through
- Loading branch information
Showing
15 changed files
with
281 additions
and
254 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
import asyncio | ||
from core.filter_list_agent import FilterListAgent | ||
|
||
async def run_filter_list_example(): | ||
goal = "Remove items that are unhealthy snacks." | ||
items_to_filter = [ | ||
"Apple", | ||
"Chocolate bar", | ||
"Carrot", | ||
"Chips", | ||
"Orange" | ||
] | ||
|
||
agent = FilterListAgent(goal=goal, items_to_filter=items_to_filter) | ||
filtered_results = await agent.filter() | ||
|
||
print("Original list:", items_to_filter) | ||
print("Filtered results:") | ||
for result in filtered_results: | ||
print(result) | ||
|
||
if __name__ == "__main__": | ||
asyncio.run(run_filter_list_example()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
import asyncio | ||
from core.sort_list_agent import SortListAgent | ||
|
||
async def run_sort_list_example(): | ||
# Sample input list | ||
items_to_sort = [ | ||
"Apple", "Orange", "Banana", "Grape", "Pineapple" | ||
] | ||
|
||
# Define a goal for sorting (this will influence the OpenAI model's decision) | ||
goal = "Sort the fruits alphabetically." | ||
|
||
# Create the sorting agent | ||
agent = SortListAgent(goal=goal, list_to_sort=items_to_sort, log_explanations=True) | ||
|
||
# Execute the sorting process | ||
sorted_list = await agent.sort() | ||
|
||
# Output the result | ||
print("Original list:", items_to_sort) | ||
print("Sorted list:", sorted_list) | ||
|
||
# Run the example | ||
if __name__ == "__main__": | ||
asyncio.run(run_sort_list_example()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,3 +17,5 @@ flake8 | |
tiktoken | ||
anyio | ||
trio | ||
openai | ||
jsonschema |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
import asyncio | ||
import json | ||
import jsonschema | ||
from typing import List, Dict | ||
from .openai_api import OpenAIClient | ||
|
||
class FilterListAgent: | ||
def __init__(self, goal: str, items_to_filter: List[str], max_tokens: int = 500, temperature: float = 0.0): | ||
self.goal = goal | ||
self.items = items_to_filter | ||
self.max_tokens = max_tokens | ||
self.temperature = temperature | ||
self.openai_client = OpenAIClient() | ||
|
||
# JSON schema for validation | ||
schema = { | ||
"type": "object", | ||
"properties": { | ||
"explanation": {"type": "string"}, | ||
"remove_item": {"type": "boolean"} | ||
}, | ||
"required": ["explanation", "remove_item"] | ||
} | ||
|
||
async def filter(self) -> List[Dict]: | ||
return await self.filter_list(self.items) | ||
|
||
async def filter_list(self, items: List[str]) -> List[Dict]: | ||
# System prompt with multi-shot examples to guide the model | ||
system_prompt = ( | ||
"You are an assistant tasked with filtering a list of items. The goal is: " | ||
f"{self.goal}. For each item, decide if it should be removed based on whether it is a healthy snack.\n" | ||
"Respond in the following structured format:\n\n" | ||
"Example:\n" | ||
"{\"explanation\": \"The apple is a healthy snack option, as it is low in calories...\",\n" | ||
" \"remove_item\": false}\n\n" | ||
"Example:\n" | ||
"{\"explanation\": \"A chocolate bar is generally considered an unhealthy snack...\",\n" | ||
" \"remove_item\": true}\n\n" | ||
) | ||
|
||
tasks = [] | ||
for index, item in enumerate(items): | ||
user_prompt = f"Item {index+1}: {item}. Should it be removed? Answer with explanation and 'remove_item': true/false." | ||
tasks.append(self.filter_item(system_prompt, user_prompt)) | ||
|
||
# Run all tasks in parallel | ||
results = await asyncio.gather(*tasks) | ||
|
||
# Show the final list of items that were kept | ||
filtered_items = [self.items[i] for i, result in enumerate(results) if not result.get('remove_item', False)] | ||
print("\nFinal Filtered List:", filtered_items) | ||
|
||
return results | ||
|
||
async def filter_item(self, system_prompt: str, user_prompt: str) -> Dict: | ||
response = await self.openai_client.complete_chat([ | ||
{"role": "system", "content": system_prompt}, | ||
{"role": "user", "content": user_prompt} | ||
], max_tokens=self.max_tokens) | ||
|
||
return await self.process_response(response, system_prompt, user_prompt) | ||
|
||
async def process_response(self, response: str, system_prompt: str, user_prompt: str, retry: bool = True) -> Dict: | ||
try: | ||
# Parse the response as JSON | ||
result = json.loads(response) | ||
# Validate against the schema | ||
jsonschema.validate(instance=result, schema=self.schema) | ||
return result | ||
except (json.JSONDecodeError, jsonschema.ValidationError) as e: | ||
if retry: | ||
# Retry once if validation fails | ||
return await self.filter_item(system_prompt, user_prompt) | ||
else: | ||
return {"error": f"Failed to parse response after retry: {str(e)}", "response": response, "item": user_prompt} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,30 +1,47 @@ | ||
import logging | ||
import http.client | ||
import json | ||
import os | ||
|
||
|
||
class Logger: | ||
def __init__(self, settings_path="../config/settings.json"): | ||
self.settings = self.load_settings(settings_path) | ||
self.log_path = self.settings["log_path"] | ||
os.makedirs(os.path.dirname(self.log_path), exist_ok=True) | ||
logging.basicConfig( | ||
filename=self.log_path, | ||
level=logging.INFO, | ||
format="%(asctime)s - %(levelname)s - %(message)s", | ||
) | ||
|
||
# Create a logger instance | ||
self.logger = logging.getLogger("AgentMLogger") | ||
self.logger.setLevel(logging.DEBUG if self.settings.get("debug", False) else logging.INFO) | ||
|
||
# File handler for logging to a file | ||
file_handler = logging.FileHandler(self.log_path) | ||
file_handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")) | ||
self.logger.addHandler(file_handler) | ||
|
||
# Console handler for output to the console | ||
console_handler = logging.StreamHandler() | ||
console_handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")) | ||
self.logger.addHandler(console_handler) | ||
|
||
# Enable HTTP-level logging if debug is enabled | ||
if self.settings.get("debug", False): | ||
self.enable_http_debug() | ||
|
||
def load_settings(self, settings_path): | ||
try: | ||
with open(settings_path, "r") as f: | ||
return json.load(f) | ||
except FileNotFoundError: | ||
raise Exception(f"Settings file not found at {settings_path}") | ||
except KeyError as e: | ||
raise Exception(f"Missing key in settings: {e}") | ||
if not os.path.exists(settings_path): | ||
raise FileNotFoundError(f"Settings file not found at {settings_path}") | ||
with open(settings_path, "r") as f: | ||
return json.load(f) | ||
|
||
def enable_http_debug(self): | ||
"""Enable HTTP-level logging for API communication.""" | ||
http.client.HTTPConnection.debuglevel = 1 | ||
logging.getLogger("http.client").setLevel(logging.DEBUG) | ||
logging.getLogger("http.client").propagate = True | ||
|
||
def info(self, message): | ||
logging.info(message) | ||
self.logger.info(message) | ||
|
||
def error(self, message): | ||
logging.error(message) | ||
self.logger.error(message) | ||
print(f"ERROR: {message}") |
Oops, something went wrong.