Skip to content

Commit

Permalink
add gile list example
Browse files Browse the repository at this point in the history
  • Loading branch information
createreadupdate committed Sep 11, 2024
1 parent af9633f commit 2858ac6
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 0 deletions.
23 changes: 23 additions & 0 deletions examples/filter_list_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import asyncio
from core.filter_list_agent import FilterListAgent

async def run_filter_list_example():
goal = "Remove items that are unhealthy snacks."
items_to_filter = [
"Apple",
"Chocolate bar",
"Carrot",
"Chips",
"Orange"
]

agent = FilterListAgent(goal=goal, items_to_filter=items_to_filter)
filtered_results = await agent.filter()

print("Original list:", items_to_filter)
print("Filtered results:")
for result in filtered_results:
print(result)

if __name__ == "__main__":
asyncio.run(run_filter_list_example())
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@ tiktoken
anyio
trio
openai
jsonschema
76 changes: 76 additions & 0 deletions src/core/filter_list_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import asyncio
import json
import jsonschema
from typing import List, Dict
from .openai_api import OpenAIClient

class FilterListAgent:
def __init__(self, goal: str, items_to_filter: List[str], max_tokens: int = 500, temperature: float = 0.0):
self.goal = goal
self.items = items_to_filter
self.max_tokens = max_tokens
self.temperature = temperature
self.openai_client = OpenAIClient()

# JSON schema for validation
schema = {
"type": "object",
"properties": {
"explanation": {"type": "string"},
"remove_item": {"type": "boolean"}
},
"required": ["explanation", "remove_item"]
}

async def filter(self) -> List[Dict]:
return await self.filter_list(self.items)

async def filter_list(self, items: List[str]) -> List[Dict]:
# System prompt with multi-shot examples to guide the model
system_prompt = (
"You are an assistant tasked with filtering a list of items. The goal is: "
f"{self.goal}. For each item, decide if it should be removed based on whether it is a healthy snack.\n"
"Respond in the following structured format:\n\n"
"Example:\n"
"{\"explanation\": \"The apple is a healthy snack option, as it is low in calories...\",\n"
" \"remove_item\": false}\n\n"
"Example:\n"
"{\"explanation\": \"A chocolate bar is generally considered an unhealthy snack...\",\n"
" \"remove_item\": true}\n\n"
)

tasks = []
for index, item in enumerate(items):
user_prompt = f"Item {index+1}: {item}. Should it be removed? Answer with explanation and 'remove_item': true/false."
tasks.append(self.filter_item(system_prompt, user_prompt))

# Run all tasks in parallel
results = await asyncio.gather(*tasks)

# Show the final list of items that were kept
filtered_items = [self.items[i] for i, result in enumerate(results) if not result.get('remove_item', False)]
print("\nFinal Filtered List:", filtered_items)

return results

async def filter_item(self, system_prompt: str, user_prompt: str) -> Dict:
response = await self.openai_client.complete_chat([
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
], max_tokens=self.max_tokens)

return await self.process_response(response, system_prompt, user_prompt)

async def process_response(self, response: str, system_prompt: str, user_prompt: str, retry: bool = True) -> Dict:
try:
# Parse the response as JSON
result = json.loads(response)
# Validate against the schema
jsonschema.validate(instance=result, schema=self.schema)
return result
except (json.JSONDecodeError, jsonschema.ValidationError) as e:
if retry:
# Retry once if validation fails
return await self.filter_item(system_prompt, user_prompt)
else:
return {"error": f"Failed to parse response after retry: {str(e)}", "response": response, "item": user_prompt}

0 comments on commit 2858ac6

Please sign in to comment.