diff --git a/docs/Makefile b/docs/Makefile
new file mode 100644
index 0000000..d0c3cbf
--- /dev/null
+++ b/docs/Makefile
@@ -0,0 +1,20 @@
+# Minimal makefile for Sphinx documentation
+#
+
+# You can set these variables from the command line, and also
+# from the environment for the first two.
+SPHINXOPTS    ?=
+SPHINXBUILD   ?= sphinx-build
+SOURCEDIR     = source
+BUILDDIR      = build
+
+# Put it first so that "make" without argument is like "make help".
+help:
+	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
+
+.PHONY: help Makefile
+
+# Catch-all target: route all unknown targets to Sphinx using the new
+# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
+%: Makefile
+	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
diff --git a/docs/make.bat b/docs/make.bat
new file mode 100644
index 0000000..747ffb7
--- /dev/null
+++ b/docs/make.bat
@@ -0,0 +1,35 @@
+@ECHO OFF
+
+pushd %~dp0
+
+REM Command file for Sphinx documentation
+
+if "%SPHINXBUILD%" == "" (
+	set SPHINXBUILD=sphinx-build
+)
+set SOURCEDIR=source
+set BUILDDIR=build
+
+%SPHINXBUILD% >NUL 2>NUL
+if errorlevel 9009 (
+	echo.
+	echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
+	echo.installed, then set the SPHINXBUILD environment variable to point
+	echo.to the full path of the 'sphinx-build' executable. Alternatively you
+	echo.may add the Sphinx directory to PATH.
+	echo.
+	echo.If you don't have Sphinx installed, grab it from
+	echo.https://www.sphinx-doc.org/
+	exit /b 1
+)
+
+if "%1" == "" goto help
+
+%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
+goto end
+
+:help
+%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
+
+:end
+popd
diff --git a/docs/source/conf.py b/docs/source/conf.py
new file mode 100644
index 0000000..14f72bb
--- /dev/null
+++ b/docs/source/conf.py
@@ -0,0 +1,28 @@
+# Configuration file for the Sphinx documentation builder.
+#
+# For the full list of built-in configuration values, see the documentation:
+# https://www.sphinx-doc.org/en/master/usage/configuration.html
+
+# -- Project information -----------------------------------------------------
+# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
+
+project = 'agentm-py'
+copyright = '2024, Steven Ickman, Jochen Schultz'
+author = 'Steven Ickman, Jochen Schultz'
+release = '0.1'
+
+# -- General configuration ---------------------------------------------------
+# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
+
+extensions = []
+
+templates_path = ['_templates']
+exclude_patterns = []
+
+
+
+# -- Options for HTML output -------------------------------------------------
+# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
+
+html_theme = 'alabaster'
+html_static_path = ['_static']
diff --git a/docs/source/index.rst b/docs/source/index.rst
new file mode 100644
index 0000000..8f393ce
--- /dev/null
+++ b/docs/source/index.rst
@@ -0,0 +1,17 @@
+.. agentm-py documentation master file, created by
+   sphinx-quickstart on Sat Sep 21 01:49:41 2024.
+   You can adapt this file completely to your liking, but it should at least
+   contain the root `toctree` directive.
+
+agentm-py documentation
+=======================
+
+Add your content using ``reStructuredText`` syntax. See the
+`reStructuredText <https://www.sphinx-doc.org/en/master/usage/restructuredtext/index.html>`_
+documentation for details.
+
+
+..
toctree:: + :maxdepth: 2 + :caption: Contents: + diff --git a/examples/binary_classify_list_example.py b/examples/binary_classify_list_example.py index 8bd5ff9..0f69350 100644 --- a/examples/binary_classify_list_example.py +++ b/examples/binary_classify_list_example.py @@ -1,13 +1,18 @@ import asyncio -from core.binary_classify_list_agent import BinaryClassifyListAgent +from core.binary_classify_list_agent import BinaryClassifyListAgent, BinaryClassifyListInput async def run_binary_classify_list_example(): - items_to_classify = ['Apple', 'Chocolate', 'Carrot'] - criteria = 'Classify each item as either healthy (true) or unhealthy (false)' - agent = BinaryClassifyListAgent(list_to_classify=items_to_classify, criteria=criteria) + input_data = BinaryClassifyListInput( + list_to_classify=['Apple', 'Chocolate', 'Carrot'], + criteria='Classify each item as either healthy (true) or unhealthy (false)', + max_tokens=1000, + temperature=0.0 + ) + + agent = BinaryClassifyListAgent(input_data) classified_items = await agent.classify_list() - print("Original list:", items_to_classify) + print("Original list:", input_data.list_to_classify) print("Binary classified results:", classified_items) if __name__ == "__main__": diff --git a/examples/chain_of_thought_example.py b/examples/chain_of_thought_example.py index c06e3f7..0ffa219 100644 --- a/examples/chain_of_thought_example.py +++ b/examples/chain_of_thought_example.py @@ -1,12 +1,17 @@ import asyncio -from core.chain_of_thought_agent import ChainOfThoughtAgent +from core.chain_of_thought_agent import ChainOfThoughtAgent, ChainOfThoughtInput async def run_chain_of_thought_example(): - question = 'What is the square root of 144?' - agent = ChainOfThoughtAgent(question=question) + input_data = ChainOfThoughtInput( + question='What is the square root of 144?', + max_tokens=1000, + temperature=0.0 + ) + + agent = ChainOfThoughtAgent(input_data) result = await agent.chain_of_thought() - print("Question:", question) + print("Question:", input_data.question) print("Chain of Thought Reasoning:", result) if __name__ == "__main__": diff --git a/examples/classify_list_example.py b/examples/classify_list_example.py index 8505795..12f04c1 100644 --- a/examples/classify_list_example.py +++ b/examples/classify_list_example.py @@ -1,14 +1,18 @@ import asyncio -from core.classify_list_agent import ClassifyListAgent +from core.classify_list_agent import ClassifyListAgent, ClassifyListInput async def run_classify_list_example(): - items_to_classify = ['Apple', 'Chocolate', 'Carrot'] - classification_criteria = 'Classify each item as healthy or unhealthy snack' - agent = ClassifyListAgent(list_to_classify=items_to_classify, classification_criteria=classification_criteria) - classified_items = await agent.classify_list() + input_data = ClassifyListInput( + list_to_classify=["Apple", "Banana", "Carrot"], + classification_criteria="Classify each item as a fruit or vegetable.", + max_tokens=1000 + ) + + agent = ClassifyListAgent(input_data) + classifications = await agent.classify_list() - print("Original list:", items_to_classify) - print("Classified results:", classified_items) + print("Original list:", input_data.list_to_classify) + print("Classified results:", classifications) if __name__ == "__main__": - asyncio.run(run_classify_list_example()) \ No newline at end of file + asyncio.run(run_classify_list_example()) diff --git a/examples/filter_list_example.py b/examples/filter_list_example.py index 745ba2f..9a04b7c 100644 --- a/examples/filter_list_example.py +++ 
b/examples/filter_list_example.py @@ -1,23 +1,27 @@ import asyncio -from core.filter_list_agent import FilterListAgent +from core.filter_list_agent import FilterListAgent, FilterListInput async def run_filter_list_example(): - goal = "Remove items that are unhealthy snacks." - items_to_filter = [ - "Apple", - "Chocolate bar", - "Carrot", - "Chips", - "Orange" - ] - - agent = FilterListAgent(goal=goal, items_to_filter=items_to_filter) + input_data = FilterListInput( + goal="Remove items that are unhealthy snacks.", + items_to_filter=[ + "Apple", + "Chocolate bar", + "Carrot", + "Chips", + "Orange" + ], + max_tokens=500, + temperature=0.0 + ) + + agent = FilterListAgent(input_data) filtered_results = await agent.filter() - print("Original list:", items_to_filter) + print("Original list:", input_data.items_to_filter) print("Filtered results:") for result in filtered_results: print(result) if __name__ == "__main__": - asyncio.run(run_filter_list_example()) + asyncio.run(run_filter_list_example()) \ No newline at end of file diff --git a/examples/generate_object_example.py b/examples/generate_object_example.py index a1f5aaa..268a8c0 100644 --- a/examples/generate_object_example.py +++ b/examples/generate_object_example.py @@ -1,13 +1,17 @@ import asyncio -from core.generate_object_agent import GenerateObjectAgent +from core.generate_object_agent import GenerateObjectAgent, ObjectGenerationInput async def run_generate_object_example(): - description = "A machine that can sort fruits." - goal = "Generate a high-level design of the machine." - agent = GenerateObjectAgent(object_description=description, goal=goal) + input_data = ObjectGenerationInput( + object_description="A machine that can sort fruits.", + goal="Generate a high-level design of the machine.", + max_tokens=1000 + ) + + agent = GenerateObjectAgent(input_data) generated_object = await agent.generate_object() - print("Object description:", description) + print("Object description:", input_data.object_description) print("Generated object:", generated_object) if __name__ == "__main__": diff --git a/examples/grounded_answer_example.py b/examples/grounded_answer_example.py index 181c991..26d12d7 100644 --- a/examples/grounded_answer_example.py +++ b/examples/grounded_answer_example.py @@ -1,15 +1,19 @@ import asyncio -from core.grounded_answer_agent import GroundedAnswerAgent +from core.grounded_answer_agent import GroundedAnswerAgent, GroundedAnswerInput async def run_grounded_answer_example(): - question = "What is the capital of France?" - context = "France is a country in Western Europe. Paris is its capital and largest city." - instructions = "Ensure the answer is grounded only in the provided context." - agent = GroundedAnswerAgent(question=question, context=context, instructions=instructions) - result = await agent.answer() + input_data = GroundedAnswerInput( + question="What is the capital of France?", + context="France is a country in Western Europe known for its wine and cuisine. 
The capital is a major global center for art, fashion, and culture.", + instructions="", + max_tokens=1000 + ) + + agent = GroundedAnswerAgent(input_data) + answer = await agent.answer() - print("Question:", question) - print("Result:", result) + print("Question:", input_data.question) + print("Answer:", answer) if __name__ == "__main__": - asyncio.run(run_grounded_answer_example()) \ No newline at end of file + asyncio.run(run_grounded_answer_example()) diff --git a/examples/map_list_example.py b/examples/map_list_example.py index 9dc3fe5..32d5bb9 100644 --- a/examples/map_list_example.py +++ b/examples/map_list_example.py @@ -1,13 +1,17 @@ import asyncio -from core.map_list_agent import MapListAgent +from core.map_list_agent import MapListAgent, MapListInput async def run_map_list_example(): - items_to_map = ['Apple', 'Banana', 'Carrot'] - transformation = 'Convert all items to uppercase' - agent = MapListAgent(list_to_map=items_to_map, transformation=transformation) + input_data = MapListInput( + list_to_map=['Apple', 'Banana', 'Carrot'], + transformation='Convert all items to uppercase', + max_tokens=1000 + ) + + agent = MapListAgent(input_data) transformed_items = await agent.map_list() - print("Original list:", items_to_map) + print("Original list:", input_data.list_to_map) print("Transformed list:", transformed_items) if __name__ == "__main__": diff --git a/examples/project_list_example.py b/examples/project_list_example.py index 27aab4d..965082a 100644 --- a/examples/project_list_example.py +++ b/examples/project_list_example.py @@ -1,13 +1,17 @@ import asyncio -from core.project_list_agent import ProjectListAgent +from core.project_list_agent import ProjectListAgent, ProjectListInput async def run_project_list_example(): - items_to_project = ['Apple', 'Banana', 'Carrot'] - projection_rule = 'Project these items as their vitamin content' - agent = ProjectListAgent(list_to_project=items_to_project, projection_rule=projection_rule) + input_data = ProjectListInput( + list_to_project=['Apple', 'Banana', 'Carrot'], + projection_rule='Project these items as their vitamin content', + max_tokens=1000 + ) + + agent = ProjectListAgent(input_data) projected_items = await agent.project_list() - print("Original list:", items_to_project) + print("Original list:", input_data.list_to_project) print("Projected results:", projected_items) if __name__ == "__main__": diff --git a/examples/reduce_list_example.py b/examples/reduce_list_example.py index 049f339..01b070b 100644 --- a/examples/reduce_list_example.py +++ b/examples/reduce_list_example.py @@ -1,14 +1,18 @@ import asyncio -from core.reduce_list_agent import ReduceListAgent +from core.reduce_list_agent import ReduceListAgent, ReduceListInput async def run_reduce_list_example(): - items_to_reduce = ['Banana', 'Apple', 'Carrot'] - reduction_goal = 'Reduce these items to a single word representing their nutritional value' - agent = ReduceListAgent(list_to_reduce=items_to_reduce, reduction_goal=reduction_goal) + input_data = ReduceListInput( + list_to_reduce=["Apple", "Banana", "Carrot"], + reduction_goal="Reduce each item to its first letter.", + max_tokens=1000 + ) + + agent = ReduceListAgent(input_data) reduced_items = await agent.reduce_list() - print("Original list:", items_to_reduce) + print("Original list:", input_data.list_to_reduce) print("Reduced results:", reduced_items) if __name__ == "__main__": - asyncio.run(run_reduce_list_example()) \ No newline at end of file + asyncio.run(run_reduce_list_example()) diff --git 
a/examples/summarize_list_example.py b/examples/summarize_list_example.py index f96b16a..f972435 100644 --- a/examples/summarize_list_example.py +++ b/examples/summarize_list_example.py @@ -1,12 +1,19 @@ import asyncio -from core.summarize_list_agent import SummarizeListAgent +from core.summarize_list_agent import SummarizeListAgent, SummarizeListInput async def run_summarize_list_example(): - items_to_summarize = ['The quick brown fox jumps over the lazy dog.', 'Python is a popular programming language.'] - agent = SummarizeListAgent(list_to_summarize=items_to_summarize) + input_data = SummarizeListInput( + list_to_summarize=[ + 'The quick brown fox jumps over the lazy dog.', + 'Python is a popular programming language.' + ], + max_tokens=1000 + ) + + agent = SummarizeListAgent(input_data) summaries = await agent.summarize_list() - print("Original list:", items_to_summarize) + print("Original list:", input_data.list_to_summarize) print("Summarized results:", summaries) if __name__ == "__main__": diff --git a/requirements.txt b/requirements.txt index bd5b8c3..e472602 100644 --- a/requirements.txt +++ b/requirements.txt @@ -19,3 +19,7 @@ anyio trio openai jsonschema +sphinx +sphinx-rtd-theme +myst-parser +pydantic \ No newline at end of file diff --git a/src/core/binary_classify_list_agent.py b/src/core/binary_classify_list_agent.py index 87bbc6b..a3c879e 100644 --- a/src/core/binary_classify_list_agent.py +++ b/src/core/binary_classify_list_agent.py @@ -1,18 +1,54 @@ +from pydantic import BaseModel, Field import asyncio from typing import List, Dict from .openai_api import OpenAIClient from .logging import Logger # Using correct logging abstraction +class BinaryClassifyListInput(BaseModel): + list_to_classify: List[str] = Field(..., description="The list of items to classify") + criteria: str = Field(..., description="The criteria for binary classification") + max_tokens: int = Field(1000, description="The maximum number of tokens to generate") + temperature: float = Field(0.0, description="Sampling temperature for the OpenAI model") + class BinaryClassifyListAgent: - def __init__(self, list_to_classify: List[str], criteria: str, max_tokens: int = 1000, temperature: float = 0.0): - self.list_to_classify = list_to_classify - self.criteria = criteria - self.max_tokens = max_tokens - self.temperature = temperature + """ + A class to classify items in a list based on binary criteria using the OpenAI API. + + Attributes: + list_to_classify (List[str]): The list of items to classify. + criteria (str): The criteria for binary classification. + max_tokens (int): The maximum number of tokens to generate. + temperature (float): Sampling temperature for the OpenAI model. + openai_client (OpenAIClient): An instance of OpenAIClient to interact with the API. + logger (Logger): An instance of Logger to log classification requests and responses. + + Methods: + classify_list(): Classifies the entire list of items. + classify_item(user_prompt): Classifies a single item based on the criteria. + """ + + def __init__(self, data: BinaryClassifyListInput): + """ + Constructs all the necessary attributes for the BinaryClassifyListAgent object. + + Args: + data (BinaryClassifyListInput): An instance of BinaryClassifyListInput containing + the list of items, criteria, max_tokens, and temperature. 
+ """ + self.list_to_classify = data.list_to_classify + self.criteria = data.criteria + self.max_tokens = data.max_tokens + self.temperature = data.temperature self.openai_client = OpenAIClient() self.logger = Logger() async def classify_list(self) -> List[Dict]: + """ + Classifies the entire list based on the provided items and criteria. + + Returns: + List[Dict]: A list of dictionaries with the classification results. + """ tasks = [] for item in self.list_to_classify: user_prompt = f"Based on the following criteria '{self.criteria}', classify the item '{item}' as true or false." @@ -22,6 +58,15 @@ async def classify_list(self) -> List[Dict]: return results async def classify_item(self, user_prompt: str) -> Dict: + """ + Classifies a single item based on the criteria. + + Args: + user_prompt (str): The prompt describing the classification criteria and item. + + Returns: + Dict: A dictionary with the classification result. + """ system_prompt = "You are an assistant tasked with binary classification of items." self.logger.info(f"Classifying item: {user_prompt}") # Logging the classification request diff --git a/src/core/chain_of_thought_agent.py b/src/core/chain_of_thought_agent.py index 2047c46..f99c57a 100644 --- a/src/core/chain_of_thought_agent.py +++ b/src/core/chain_of_thought_agent.py @@ -1,15 +1,46 @@ -import asyncio +from pydantic import BaseModel, Field from typing import List from .openai_api import OpenAIClient +class ChainOfThoughtInput(BaseModel): + question: str = Field(..., description="The question to solve using chain of thought reasoning") + max_tokens: int = Field(1000, description="The maximum number of tokens to generate") + temperature: float = Field(0.0, description="Sampling temperature for the OpenAI model") + class ChainOfThoughtAgent: - def __init__(self, question: str, max_tokens: int = 1000, temperature: float = 0.0): - self.question = question - self.max_tokens = max_tokens - self.temperature = temperature + """ + A class to solve problems using the 'chain of thought' reasoning process via the OpenAI API. + + Attributes: + question (str): The question to solve. + max_tokens (int): The maximum number of tokens to generate. + temperature (float): Sampling temperature for the OpenAI model. + openai_client (OpenAIClient): An instance of OpenAIClient to interact with the API. + + Methods: + chain_of_thought(): Solves the question using chain of thought reasoning. + """ + + def __init__(self, data: ChainOfThoughtInput): + """ + Constructs all the necessary attributes for the ChainOfThoughtAgent object. + + Args: + data (ChainOfThoughtInput): An instance of ChainOfThoughtInput containing + the question, max_tokens, and temperature. + """ + self.question = data.question + self.max_tokens = data.max_tokens + self.temperature = data.temperature self.openai_client = OpenAIClient() async def chain_of_thought(self) -> str: + """ + Solves the question using chain of thought reasoning. + + Returns: + str: The step-by-step reasoning process and solution. + """ system_prompt = "You are an assistant tasked with solving problems using the 'chain of thought' reasoning process." 
user_prompt = f"Solve the following problem step-by-step: {self.question}" diff --git a/src/core/classify_list_agent.py b/src/core/classify_list_agent.py index a46799c..a9f1ecd 100644 --- a/src/core/classify_list_agent.py +++ b/src/core/classify_list_agent.py @@ -1,15 +1,48 @@ +from pydantic import BaseModel, Field import asyncio from typing import List, Dict from .openai_api import OpenAIClient +class ClassifyListInput(BaseModel): + list_to_classify: List[str] = Field(..., description="The list of items to classify") + classification_criteria: str = Field(..., description="The criteria for classifying the items") + max_tokens: int = Field(1000, description="The maximum number of tokens to generate") + class ClassifyListAgent: - def __init__(self, list_to_classify: List[str], classification_criteria: str, max_tokens: int = 1000): - self.list_to_classify = list_to_classify - self.classification_criteria = classification_criteria - self.max_tokens = max_tokens + """ + A class to classify items in a list based on given criteria using the OpenAI API. + + Attributes: + list_to_classify (List[str]): The list of items to classify. + classification_criteria (str): The criteria for classifying the items. + max_tokens (int): The maximum number of tokens to generate. + openai_client (OpenAIClient): An instance of OpenAIClient to interact with the API. + + Methods: + classify_list(): Classifies the entire list of items. + classify_item(user_prompt): Classifies a single item based on the classification criteria. + """ + + def __init__(self, data: ClassifyListInput): + """ + Constructs all the necessary attributes for the ClassifyListAgent object. + + Args: + data (ClassifyListInput): An instance of ClassifyListInput containing + the list of items, classification criteria, and max_tokens. + """ + self.list_to_classify = data.list_to_classify + self.classification_criteria = data.classification_criteria + self.max_tokens = data.max_tokens self.openai_client = OpenAIClient() async def classify_list(self) -> List[Dict]: + """ + Classifies the entire list based on the provided items and classification criteria. + + Returns: + List[Dict]: A list of dictionaries with the classification results. + """ tasks = [] for item in self.list_to_classify: user_prompt = f"Classify the item '{item}' according to the following criteria: {self.classification_criteria}." @@ -19,7 +52,16 @@ async def classify_list(self) -> List[Dict]: return results async def classify_item(self, user_prompt: str) -> Dict: - system_prompt = f"You are an assistant tasked with classifying items based on the given criteria." + """ + Classifies a single item based on the classification criteria. + + Args: + user_prompt (str): The prompt describing the item and classification criteria. + + Returns: + Dict: A dictionary with the classification result. + """ + system_prompt = "You are an assistant tasked with classifying items based on the given criteria." response = await self.openai_client.complete_chat([ {"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt} diff --git a/src/core/compose_prompt.py b/src/core/compose_prompt.py index 248de04..3cbd1e4 100644 --- a/src/core/compose_prompt.py +++ b/src/core/compose_prompt.py @@ -1,7 +1,16 @@ import re - def compose_prompt(template: str, variables: dict) -> str: + """ + Composes a prompt by substituting variables in a template string. + + Args: + template (str): The template string containing placeholders in the form of {{variable_name}}. 
+ variables (dict): A dictionary where keys are variable names and values are the replacements. + + Returns: + str: The composed string with all placeholders replaced by their corresponding values. + """ return re.sub( r"{{\s*([^}\s]+)\s*}}", lambda match: str(variables.get(match.group(1), "")), diff --git a/src/core/concurrency.py b/src/core/concurrency.py index a691e54..421718a 100644 --- a/src/core/concurrency.py +++ b/src/core/concurrency.py @@ -1,16 +1,50 @@ import asyncio - class Semaphore: + """ + A class that implements an asynchronous semaphore for controlling access to a limited number of concurrent tasks. + + Attributes: + semaphore (asyncio.Semaphore): An asyncio semaphore to limit concurrent tasks. + + Methods: + __aenter__(): Acquires the semaphore. + __aexit__(): Releases the semaphore. + call_function(func, *args, **kwargs): Calls a function while respecting the semaphore limits. + """ + def __init__(self, max_concurrent_tasks): + """ + Constructs the Semaphore object with a maximum number of concurrent tasks. + + Args: + max_concurrent_tasks (int): The maximum number of tasks that can run concurrently. + """ self.semaphore = asyncio.Semaphore(max_concurrent_tasks) async def __aenter__(self): + """ + Acquires the semaphore to enter a protected code block. + """ await self.semaphore.acquire() async def __aexit__(self, exc_type, exc_val, exc_tb): + """ + Releases the semaphore after leaving a protected code block. + """ self.semaphore.release() async def call_function(self, func, *args, **kwargs): + """ + Calls a function while respecting the semaphore limits. + + Args: + func (Callable): The function to call. + *args: Positional arguments for the function. + **kwargs: Keyword arguments for the function. + + Returns: + Any: The result of the function call. + """ async with self.semaphore: return await func(*args, **kwargs) diff --git a/src/core/filter_list_agent.py b/src/core/filter_list_agent.py index ecdb522..636cc5e 100644 --- a/src/core/filter_list_agent.py +++ b/src/core/filter_list_agent.py @@ -1,18 +1,35 @@ +from pydantic import BaseModel, Field import asyncio import json import jsonschema from typing import List, Dict from .openai_api import OpenAIClient +class FilterListInput(BaseModel): + goal: str = Field(..., description="The goal for filtering the list") + items_to_filter: List[str] = Field(..., description="The list of items to filter") + max_tokens: int = Field(500, description="The maximum number of tokens to generate") + temperature: float = Field(0.0, description="Sampling temperature for the OpenAI model") + class FilterListAgent: - def __init__(self, goal: str, items_to_filter: List[str], max_tokens: int = 500, temperature: float = 0.0): - self.goal = goal - self.items = items_to_filter - self.max_tokens = max_tokens - self.temperature = temperature - self.openai_client = OpenAIClient() + """ + A class to filter items in a list based on a given goal using the OpenAI API. + + Attributes: + goal (str): The goal for filtering the list. + items (List[str]): The list of items to filter. + max_tokens (int): The maximum number of tokens to generate. + temperature (float): Sampling temperature for the OpenAI model. + openai_client (OpenAIClient): An instance of OpenAIClient to interact with the API. + schema (dict): JSON schema to validate the API's response format. + + Methods: + filter(): Filters the entire list of items. + filter_list(items): Filters a given list of items. + filter_item(system_prompt, user_prompt): Filters a single item. 
+ process_response(response, system_prompt, user_prompt, retry): Processes and validates the API response. + """ - # JSON schema for validation schema = { "type": "object", "properties": { @@ -22,11 +39,39 @@ def __init__(self, goal: str, items_to_filter: List[str], max_tokens: int = 500, "required": ["explanation", "remove_item"] } + def __init__(self, data: FilterListInput): + """ + Constructs all the necessary attributes for the FilterListAgent object. + + Args: + data (FilterListInput): An instance of FilterListInput containing + the goal, items to filter, max_tokens, and temperature. + """ + self.goal = data.goal + self.items = data.items_to_filter + self.max_tokens = data.max_tokens + self.temperature = data.temperature + self.openai_client = OpenAIClient() + async def filter(self) -> List[Dict]: + """ + Filters the entire list based on the provided items and goal. + + Returns: + List[Dict]: A list of dictionaries with the filtering results. + """ return await self.filter_list(self.items) async def filter_list(self, items: List[str]) -> List[Dict]: - # System prompt with multi-shot examples to guide the model + """ + Filters a given list of items based on the goal. + + Args: + items (List[str]): The list of items to filter. + + Returns: + List[Dict]: A list of dictionaries with the filtering results. + """ system_prompt = ( "You are an assistant tasked with filtering a list of items. The goal is: " f"{self.goal}. For each item, decide if it should be removed based on whether it is a healthy snack.\n" @@ -44,16 +89,24 @@ async def filter_list(self, items: List[str]) -> List[Dict]: user_prompt = f"Item {index+1}: {item}. Should it be removed? Answer with explanation and 'remove_item': true/false." tasks.append(self.filter_item(system_prompt, user_prompt)) - # Run all tasks in parallel results = await asyncio.gather(*tasks) - # Show the final list of items that were kept filtered_items = [self.items[i] for i, result in enumerate(results) if not result.get('remove_item', False)] print("\nFinal Filtered List:", filtered_items) return results async def filter_item(self, system_prompt: str, user_prompt: str) -> Dict: + """ + Filters a single item based on the goal. + + Args: + system_prompt (str): The system prompt to guide the API. + user_prompt (str): The user prompt to describe the item to be filtered. + + Returns: + Dict: A dictionary with the filtering result. + """ response = await self.openai_client.complete_chat([ {"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt} @@ -62,15 +115,24 @@ async def filter_item(self, system_prompt: str, user_prompt: str) -> Dict: return await self.process_response(response, system_prompt, user_prompt) async def process_response(self, response: str, system_prompt: str, user_prompt: str, retry: bool = True) -> Dict: + """ + Processes and validates the API response. + + Args: + response (str): The API's response to process. + system_prompt (str): The system prompt used for the API request. + user_prompt (str): The user prompt used for the API request. + retry (bool): Whether to retry the request if validation fails. + + Returns: + Dict: A dictionary containing the validated response or an error. 
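+
+            A response that fails JSON parsing or schema validation is
+            retried (see the retry flag) before an error dictionary is returned.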
+ """ try: - # Parse the response as JSON result = json.loads(response) - # Validate against the schema jsonschema.validate(instance=result, schema=self.schema) return result except (json.JSONDecodeError, jsonschema.ValidationError) as e: if retry: - # Retry once if validation fails return await self.filter_item(system_prompt, user_prompt) else: return {"error": f"Failed to parse response after retry: {str(e)}", "response": response, "item": user_prompt} diff --git a/src/core/generate_object_agent.py b/src/core/generate_object_agent.py index 52c61df..5a7e6c4 100644 --- a/src/core/generate_object_agent.py +++ b/src/core/generate_object_agent.py @@ -1,15 +1,46 @@ -import asyncio +from pydantic import BaseModel, Field from typing import Dict from .openai_api import OpenAIClient +class ObjectGenerationInput(BaseModel): + object_description: str = Field(..., description="A description of the object to generate") + goal: str = Field(..., description="The goal of the generation process") + max_tokens: int = Field(1000, description="The maximum number of tokens to generate") + class GenerateObjectAgent: - def __init__(self, object_description: str, goal: str, max_tokens: int = 1000): - self.object_description = object_description - self.goal = goal - self.max_tokens = max_tokens + """ + A class to generate objects based on a given description and goal using the OpenAI API. + + Attributes: + object_description (str): A description of the object to generate. + goal (str): The goal of the generation process. + max_tokens (int): The maximum number of tokens to generate. + openai_client (OpenAIClient): An instance of OpenAIClient to interact with the API. + + Methods: + generate_object(): Generates an object based on the description and goal. + """ + + def __init__(self, data: ObjectGenerationInput): + """ + Constructs all the necessary attributes for the GenerateObjectAgent object. + + Args: + data (ObjectGenerationInput): An instance of ObjectGenerationInput containing + the object description, goal, and max_tokens. + """ + self.object_description = data.object_description + self.goal = data.goal + self.max_tokens = data.max_tokens self.openai_client = OpenAIClient() async def generate_object(self) -> Dict: + """ + Generates an object based on the given description and goal. + + Returns: + dict: A dictionary containing the original object description and the generated object. + """ system_prompt = f"You are an assistant tasked with generating objects based on a given description. The goal is: {self.goal}." user_prompt = f"Generate an object based on the following description: {self.object_description}." 
diff --git a/src/core/grounded_answer_agent.py b/src/core/grounded_answer_agent.py
index 7763b52..f6d6a98 100644
--- a/src/core/grounded_answer_agent.py
+++ b/src/core/grounded_answer_agent.py
@@ -1,16 +1,33 @@
+from pydantic import BaseModel, Field
 import asyncio
 import json
 import jsonschema
 from typing import Dict
 from .openai_api import OpenAIClient
 
+class GroundedAnswerInput(BaseModel):
+    question: str = Field(..., description="The question to answer based on the provided context")
+    context: str = Field(..., description="The context information to base the answer on")
+    instructions: str = Field('', description="Additional instructions for answering the question")
+    max_tokens: int = Field(1000, description="The maximum number of tokens to generate")
+
 class GroundedAnswerAgent:
-    def __init__(self, question: str, context: str, instructions: str = '', max_tokens: int = 1000):
-        self.question = question
-        self.context = context
-        self.instructions = instructions
-        self.max_tokens = max_tokens
-        self.openai_client = OpenAIClient()
+    """
+    A class to provide grounded answers based on a given context using the OpenAI API.
+
+    Attributes:
+        question (str): The question to answer based on the provided context.
+        context (str): The context information to base the answer on.
+        instructions (str): Additional instructions for answering the question.
+        max_tokens (int): The maximum number of tokens to generate.
+        openai_client (OpenAIClient): An instance of OpenAIClient to interact with the API.
+        schema (dict): JSON schema to validate the API's response format.
+
+    Methods:
+        answer(): Provides a grounded answer based on the context.
+        grounded_answer(): Generates the grounded answer using the API.
+        process_response(response): Processes and validates the API response.
+    """
 
     # JSON schema for validation
     schema = {
@@ -23,11 +40,43 @@ def __init__(self, question: str, context: str, instructions: str = '', max_toke
         "additionalProperties": False
     }
 
+    def __init__(self, data: GroundedAnswerInput):
+        """
+        Constructs all the necessary attributes for the GroundedAnswerAgent object.
+
+        Args:
+            data (GroundedAnswerInput): An instance of GroundedAnswerInput containing
+                the question, context, instructions, and max_tokens.
+        """
+        self.question = data.question
+        self.context = data.context
+        self.instructions = data.instructions
+        self.max_tokens = data.max_tokens
+        self.openai_client = OpenAIClient()
+
     async def answer(self) -> Dict:
+        """
+        Provides a grounded answer based on the provided context.
+
+        Returns:
+            Dict: The grounded answer and explanation.
+        """
         return await self.grounded_answer()
 
     async def grounded_answer(self) -> Dict:
-        system_prompt = f"<CONTEXT>\n{self.context}\n</CONTEXT>\n\n<INSTRUCTIONS>\nBase your answer only on the information provided in the above <CONTEXT>.\nReturn your answer using the JSON <OUTPUT> below. \nDo not directly mention that you're using the context in your answer.\n</INSTRUCTIONS>\n\n<OUTPUT>\n{{\"explanation\": \"<explanation supporting your answer>\", \"answer\": \"<your answer>\"}}\n</OUTPUT>{self.instructions}"
+        """
+        Generates the grounded answer using the API.
+
+        Returns:
+            Dict: The grounded answer and explanation.
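+
+            The reply is parsed as JSON and validated against the class-level
+            schema in process_response before being returned.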
+ """ + system_prompt = ( + f"\n{self.context}\n\n" + "\nBase your answer only on the information provided in the above .\n" + "Return your answer using the JSON below.\n" + "Do not directly mention that you're using the context in your answer.\n\n" + f"\n{{\"explanation\": \"\", \"answer\": \"\"}}{self.instructions}" + ) user_prompt = self.question @@ -39,6 +88,15 @@ async def grounded_answer(self) -> Dict: return await self.process_response(response) async def process_response(self, response: str) -> Dict: + """ + Processes and validates the API response. + + Args: + response (str): The API's response to process. + + Returns: + Dict: The validated response or an error. + """ try: result = json.loads(response) jsonschema.validate(instance=result, schema=self.schema) diff --git a/src/core/log_complete_prompt.py b/src/core/log_complete_prompt.py index 99a8d86..2248f0d 100644 --- a/src/core/log_complete_prompt.py +++ b/src/core/log_complete_prompt.py @@ -1,12 +1,34 @@ from core.logging import Logger - class LogCompletePrompt: + """ + A class that logs the completion status of a prompt. + + Attributes: + complete_prompt_func (Callable): The function that completes the prompt. + logger (Logger): An instance of Logger to handle logging. + + Methods: + complete_prompt(): Executes the prompt completion and logs the result. + """ + def __init__(self, complete_prompt_func): + """ + Constructs all the necessary attributes for the LogCompletePrompt object. + + Args: + complete_prompt_func (Callable): The function that completes the prompt. + """ self.complete_prompt_func = complete_prompt_func self.logger = Logger() async def complete_prompt(self, *args, **kwargs): + """ + Executes the prompt completion and logs whether it was successful or not. + + Returns: + dict: The result from the prompt completion function. + """ result = await self.complete_prompt_func(*args, **kwargs) if result["completed"]: diff --git a/src/core/logging.py b/src/core/logging.py index 48a06e7..a0e2e72 100644 --- a/src/core/logging.py +++ b/src/core/logging.py @@ -4,7 +4,26 @@ from datetime import datetime class Logger: + """ + A logger class that handles logging messages to a file and the console. + + Attributes: + settings (dict): A dictionary containing settings loaded from a JSON file. + logger (logging.Logger): An instance of Python's standard logging.Logger. + + Methods: + load_settings(settings_path): Loads settings from a specified JSON file. + info(message): Logs an informational message. + error(message): Logs an error message. + """ + def __init__(self, settings_path=None): + """ + Constructs the Logger object and initializes the logging configuration. + + Args: + settings_path (str): The path to the settings JSON file. + """ if settings_path is None: settings_path = os.path.join(os.path.dirname(__file__), '../../config/settings.json') self.settings = self.load_settings(settings_path) @@ -15,15 +34,36 @@ def __init__(self, settings_path=None): self.logger = logging.getLogger() def load_settings(self, settings_path): + """ + Loads settings from a JSON file. + + Args: + settings_path (str): The path to the settings JSON file. + + Returns: + dict: A dictionary containing the settings. + """ if not os.path.exists(settings_path): raise FileNotFoundError(f"Settings file not found at {settings_path}") with open(settings_path, 'r') as f: return json.load(f) def info(self, message): + """ + Logs an informational message. + + Args: + message (str): The message to log. 
+ """ print(message) self.logger.info(message) def error(self, message): + """ + Logs an error message. + + Args: + message (str): The message to log. + """ print(message) self.logger.error(message) diff --git a/src/core/map_list_agent.py b/src/core/map_list_agent.py index 452ac55..c5a101c 100644 --- a/src/core/map_list_agent.py +++ b/src/core/map_list_agent.py @@ -1,26 +1,67 @@ +from pydantic import BaseModel, Field import asyncio from typing import List from .openai_api import OpenAIClient +class MapListInput(BaseModel): + list_to_map: List[str] = Field(..., description="The list of items to transform") + transformation: str = Field(..., description="The transformation rule to apply to each item") + max_tokens: int = Field(1000, description="The maximum number of tokens to generate") + class MapListAgent: - def __init__(self, list_to_map: List[str], transformation: str, max_tokens: int = 1000): - self.list_to_map = list_to_map - self.transformation = transformation - self.max_tokens = max_tokens + """ + A class to apply a transformation to each item in a list using the OpenAI API. + + Attributes: + list_to_map (List[str]): The list of items to transform. + transformation (str): The transformation rule to apply. + max_tokens (int): The maximum number of tokens to generate. + openai_client (OpenAIClient): An instance of OpenAIClient to interact with the API. + + Methods: + map_list(): Transforms the entire list based on the transformation rule. + apply_transformation(user_prompt): Applies the transformation to a single item. + """ + + def __init__(self, data: MapListInput): + """ + Constructs all the necessary attributes for the MapListAgent object. + + Args: + data (MapListInput): An instance of MapListInput containing + the list of items, transformation rule, and max_tokens. + """ + self.list_to_map = data.list_to_map + self.transformation = data.transformation + self.max_tokens = data.max_tokens self.openai_client = OpenAIClient() async def map_list(self) -> List[str]: + """ + Transforms the entire list based on the provided items and transformation rule. + + Returns: + List[str]: A list of transformed items. + """ tasks = [] for index, item in enumerate(self.list_to_map): user_prompt = f"Transform '{item}' as per the following rule: {self.transformation}." tasks.append(self.apply_transformation(user_prompt)) - # Run all tasks in parallel results = await asyncio.gather(*tasks) return results async def apply_transformation(self, user_prompt: str) -> str: - system_prompt = f"You are an assistant tasked with transforming list items according to a rule." + """ + Applies the transformation to a single item. + + Args: + user_prompt (str): The prompt describing the transformation rule and item. + + Returns: + str: The transformed item. + """ + system_prompt = "You are an assistant tasked with transforming list items according to a rule." response = await self.openai_client.complete_chat([ {"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt} diff --git a/src/core/openai_api.py b/src/core/openai_api.py index 09b8d30..dd35b45 100644 --- a/src/core/openai_api.py +++ b/src/core/openai_api.py @@ -3,7 +3,24 @@ import os class OpenAIClient: + """ + A client for interacting with the OpenAI API. + + Attributes: + logger (Logger): An instance of Logger for logging API interactions and errors. + client (OpenAI): An instance of the OpenAI API client. + + Methods: + complete_chat(messages, model, max_tokens): Sends a chat completion request to the OpenAI API. 
+ """ + def __init__(self, settings_path=None): + """ + Constructs the OpenAIClient object and initializes the API client. + + Args: + settings_path (str): The path to the settings JSON file containing the API key. + """ if settings_path is None: settings_path = os.path.join(os.path.dirname(__file__), '../../config/settings.json') @@ -12,6 +29,20 @@ def __init__(self, settings_path=None): self.client = OpenAI(api_key=settings["openai_api_key"]) async def complete_chat(self, messages, model="gpt-4o-mini", max_tokens=1500): + """ + Sends a chat completion request to the OpenAI API. + + Args: + messages (list): A list of message dicts for the chat completion. + model (str): The model name to use for the completion. + max_tokens (int): The maximum number of tokens to generate. + + Returns: + str: The generated content from the chat completion. + + Raises: + BadRequestError: If there is an issue with the request to the OpenAI API. + """ try: response = self.client.chat.completions.create( model=model, diff --git a/src/core/parallel_complete_prompt.py b/src/core/parallel_complete_prompt.py index caf9f05..f001b6e 100644 --- a/src/core/parallel_complete_prompt.py +++ b/src/core/parallel_complete_prompt.py @@ -1,17 +1,46 @@ import asyncio from .concurrency import Semaphore - class ParallelCompletePrompt: + """ + A class to handle parallel execution of prompt completion functions with concurrency control. + + Attributes: + complete_prompt_func (Callable): The function to complete the prompt. + parallel_completions (int): The number of prompts to complete in parallel. + should_continue_func (Callable): A function to determine if the operation should continue. + semaphore (Semaphore): A Semaphore to control concurrency. + + Methods: + complete_prompt(*args, **kwargs): Executes the prompt completion function in parallel. + """ + def __init__( self, complete_prompt_func, parallel_completions=1, should_continue_func=None ): + """ + Constructs the ParallelCompletePrompt object and initializes concurrency control. + + Args: + complete_prompt_func (Callable): The function to complete the prompt. + parallel_completions (int): The number of prompts to complete in parallel. + should_continue_func (Callable): A function to determine if the operation should continue. + """ self.complete_prompt_func = complete_prompt_func self.parallel_completions = parallel_completions self.should_continue_func = should_continue_func or (lambda: True) self.semaphore = Semaphore(parallel_completions) async def complete_prompt(self, *args, **kwargs): + """ + Executes the prompt completion function in parallel, respecting concurrency limits. + + Raises: + asyncio.CancelledError: If the operation is cancelled by the should_continue_func. + + Returns: + Any: The result from the prompt completion function. 
+ """ async with self.semaphore: if not self.should_continue_func(): raise asyncio.CancelledError("Operation cancelled.") diff --git a/src/core/project_list_agent.py b/src/core/project_list_agent.py index 0b75828..22a8e36 100644 --- a/src/core/project_list_agent.py +++ b/src/core/project_list_agent.py @@ -1,15 +1,48 @@ -import asyncio +import asyncio # <-- Import asyncio here +from pydantic import BaseModel, Field from typing import List, Dict from .openai_api import OpenAIClient +class ProjectListInput(BaseModel): + list_to_project: List[str] = Field(..., description="The list of items to project") + projection_rule: str = Field(..., description="The rule to apply for projection") + max_tokens: int = Field(1000, description="The maximum number of tokens to generate") + class ProjectListAgent: - def __init__(self, list_to_project: List[str], projection_rule: str, max_tokens: int = 1000): - self.list_to_project = list_to_project - self.projection_rule = projection_rule - self.max_tokens = max_tokens + """ + A class to project items in a list based on a given rule using the OpenAI API. + + Attributes: + list_to_project (List[str]): The list of items to project. + projection_rule (str): The rule to apply for projection. + max_tokens (int): The maximum number of tokens to generate. + openai_client (OpenAIClient): An instance of OpenAIClient to interact with the API. + + Methods: + project_list(): Projects the entire list based on the projection rule. + project_item(): Projects a single item based on the projection rule. + """ + + def __init__(self, data: ProjectListInput): + """ + Constructs all the necessary attributes for the ProjectListAgent object. + + Args: + data (ProjectListInput): An instance of ProjectListInput containing + the list of items, projection rule, and max_tokens. + """ + self.list_to_project = data.list_to_project + self.projection_rule = data.projection_rule + self.max_tokens = data.max_tokens self.openai_client = OpenAIClient() async def project_list(self) -> List[Dict]: + """ + Projects the entire list based on the given projection rule. + + Returns: + List[Dict]: A list of dictionaries with the original items and their projections. + """ tasks = [] for item in self.list_to_project: user_prompt = f"Project the following item based on the rule '{self.projection_rule}': {item}." @@ -19,6 +52,15 @@ async def project_list(self) -> List[Dict]: return results async def project_item(self, user_prompt: str) -> Dict: + """ + Projects a single item based on the given rule. + + Args: + user_prompt (str): The prompt to send to the OpenAI API. + + Returns: + Dict: A dictionary with the original item and its projection. + """ system_prompt = "You are an assistant tasked with projecting items based on a specific rule." response = await self.openai_client.complete_chat([ {"role": "system", "content": system_prompt}, diff --git a/src/core/prompt_generation.py b/src/core/prompt_generation.py index d255e9b..6d252cd 100644 --- a/src/core/prompt_generation.py +++ b/src/core/prompt_generation.py @@ -1,9 +1,35 @@ class PromptGenerator: + """ + A class to manage and generate combined prompts. + + Attributes: + prompts (list): A list to store individual prompts. + + Methods: + add_prompt(prompt): Adds a prompt to the list. + generate_combined_prompt(): Generates a combined prompt from all stored prompts. + """ + def __init__(self): + """ + Constructs the PromptGenerator object and initializes the prompt list. 
+ """ self.prompts = [] def add_prompt(self, prompt): + """ + Adds a prompt to the list of prompts. + + Args: + prompt (str): The prompt to add to the list. + """ self.prompts.append(prompt) def generate_combined_prompt(self): + """ + Generates a combined prompt by joining all prompts with a newline separator. + + Returns: + str: The combined prompt. + """ return "\n".join(self.prompts) diff --git a/src/core/reduce_list_agent.py b/src/core/reduce_list_agent.py index 75591bc..4269051 100644 --- a/src/core/reduce_list_agent.py +++ b/src/core/reduce_list_agent.py @@ -1,15 +1,48 @@ +from pydantic import BaseModel, Field import asyncio from typing import List, Dict from .openai_api import OpenAIClient +class ReduceListInput(BaseModel): + list_to_reduce: List[str] = Field(..., description="The list of items to reduce") + reduction_goal: str = Field(..., description="The goal for reducing the items") + max_tokens: int = Field(1000, description="The maximum number of tokens to generate") + class ReduceListAgent: - def __init__(self, list_to_reduce: List[str], reduction_goal: str, max_tokens: int = 1000): - self.list_to_reduce = list_to_reduce - self.reduction_goal = reduction_goal - self.max_tokens = max_tokens + """ + A class to reduce items in a list based on a given goal using the OpenAI API. + + Attributes: + list_to_reduce (List[str]): The list of items to reduce. + reduction_goal (str): The goal for reducing the items. + max_tokens (int): The maximum number of tokens to generate. + openai_client (OpenAIClient): An instance of OpenAIClient to interact with the API. + + Methods: + reduce_list(): Reduces the entire list of items. + reduce_item(user_prompt): Reduces a single item based on the reduction goal. + """ + + def __init__(self, data: ReduceListInput): + """ + Constructs all the necessary attributes for the ReduceListAgent object. + + Args: + data (ReduceListInput): An instance of ReduceListInput containing + the list of items, reduction goal, and max_tokens. + """ + self.list_to_reduce = data.list_to_reduce + self.reduction_goal = data.reduction_goal + self.max_tokens = data.max_tokens self.openai_client = OpenAIClient() async def reduce_list(self) -> List[Dict]: + """ + Reduces the entire list based on the provided items and reduction goal. + + Returns: + List[Dict]: A list of dictionaries with the reduction results. + """ tasks = [] for item in self.list_to_reduce: user_prompt = f"Reduce the item '{item}' to achieve the goal: {self.reduction_goal}." @@ -19,6 +52,15 @@ async def reduce_list(self) -> List[Dict]: return results async def reduce_item(self, user_prompt: str) -> Dict: + """ + Reduces a single item based on the reduction goal. + + Args: + user_prompt (str): The prompt describing the item and reduction goal. + + Returns: + Dict: A dictionary with the reduced item. + """ system_prompt = "You are an assistant tasked with reducing items to achieve a specific goal." 
response = await self.openai_client.complete_chat([ {"role": "system", "content": system_prompt}, diff --git a/src/core/sort_list_agent.py b/src/core/sort_list_agent.py index 5aee7de..7b33e7f 100644 --- a/src/core/sort_list_agent.py +++ b/src/core/sort_list_agent.py @@ -1,50 +1,93 @@ +from pydantic import BaseModel, Field import asyncio from typing import List from .openai_api import OpenAIClient +class SortListInput(BaseModel): + goal: str = Field(..., description="The goal for sorting the list") + list_to_sort: List[str] = Field(..., description="The list of items to sort") + max_tokens: int = Field(1000, description="The maximum number of tokens to generate") + temperature: float = Field(0.0, description="Sampling temperature for the OpenAI model") + log_explanations: bool = Field(False, description="Whether to log explanations of sorting decisions") + class SortListAgent: - def __init__(self, goal: str, list_to_sort: List[str], max_tokens: int = 1000, temperature: float = 0.0, log_explanations: bool = False): - self.goal = goal - self.list = list_to_sort - self.max_tokens = max_tokens - self.temperature = temperature - self.log_explanations = log_explanations + """ + A class to sort items in a list based on a given goal using the OpenAI API. + + Attributes: + goal (str): The goal for sorting the list. + list (List[str]): The list of items to sort. + max_tokens (int): The maximum number of tokens to generate. + temperature (float): Sampling temperature for the OpenAI model. + log_explanations (bool): Whether to log explanations of sorting decisions. + openai_client (OpenAIClient): An instance of OpenAIClient to interact with the API. + + Methods: + sort(): Sorts the entire list of items. + batch_compare(pairs): Sends multiple comparison pairs to the API in one request to reduce API calls. + merge_sort(items): Recursively sorts the items using merge sort. + merge(left, right): Merges two sorted lists into one. + """ + + def __init__(self, data: SortListInput): + """ + Constructs all the necessary attributes for the SortListAgent object. + + Args: + data (SortListInput): An instance of SortListInput containing + the goal, list of items, max_tokens, temperature, and log_explanations. + """ + self.goal = data.goal + self.list = data.list_to_sort + self.max_tokens = data.max_tokens + self.temperature = data.temperature + self.log_explanations = data.log_explanations self.openai_client = OpenAIClient() async def sort(self): + """ + Sorts the entire list based on the provided items and goal. + + Returns: + List[str]: The sorted list of items. + """ return await self.merge_sort(self.list) async def batch_compare(self, pairs): """ - Send multiple comparison pairs to the API in one request to reduce API calls. + Sends multiple comparison pairs to the API in one request to reduce API calls. + + Args: + pairs (List[Tuple[str, str]]): A list of pairs of items to compare. + + Returns: + List[str]: A list of results for each comparison. """ batch_prompt = "\n".join([f"Compare {a} and {b} and return the items in the correct order as 'item1,item2'." for a, b in pairs]) system_prompt = f"You are tasked with sorting items. Goal: {self.goal}.\nCompare the following pairs and return the correct order." 
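+        # All pairs travel in a single request: a larger prompt in exchange
+        # for fewer API round-trips.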
-        # Log the request we're sending
-        self.openai_client.logger.info(f"Sending batch comparison request with prompt: {batch_prompt}")
-
+        if self.log_explanations:
+            self.openai_client.logger.info(f"Sending batch comparison request with prompt: {batch_prompt}")
+
         response = await self.openai_client.complete_chat([
             {"role": "system", "content": system_prompt},
             {"role": "user", "content": batch_prompt}
         ], max_tokens=self.max_tokens)
-
-        # Log the response we receive
-        self.openai_client.logger.info(f"Received response: {response}")
-        comparisons = response.split("\n")  # Assuming API returns comparisons in batch order
-
-        # Check for empty response and log an error
+        if self.log_explanations:
+            self.openai_client.logger.info(f"Received response: {response}")
+
+        comparisons = response.split("\n")
+
         if not comparisons:
             self.openai_client.logger.error("Empty response received from API.")
-
-        # Parse responses and filter out empty or malformed comparisons
+
         parsed_comparisons = []
         for comparison in comparisons:
-            individual_comparisons = comparison.split(" ")  # Split individual results
+            individual_comparisons = comparison.split(" ")
             for comp in individual_comparisons:
                 comp = comp.strip()
-                if not comp:  # Ignore empty results
+                if not comp:
                     continue
                 try:
                     first, second = comp.split(",")
@@ -54,10 +97,19 @@ async def batch_compare(self, pairs):
                     parsed_comparisons.append("AFTER")
                 except ValueError:
                     self.openai_client.logger.info(f"Ignoring unexpected comparison result: {comp}")
-
+
         return parsed_comparisons
 
     async def merge_sort(self, items):
+        """
+        Recursively sorts the items using merge sort.
+
+        Args:
+            items (List[str]): The list of items to sort.
+
+        Returns:
+            List[str]: The sorted list of items.
+        """
         if len(items) < 2:
             return items
 
@@ -66,6 +118,16 @@ async def merge_sort(self, items):
         return await self.merge(left_half, right_half)
 
     async def merge(self, left, right):
+        """
+        Merges two sorted lists into one.
+
+        Args:
+            left (List[str]): The left half of the list.
+            right (List[str]): The right half of the list.
+
+        Returns:
+            List[str]: The merged and sorted list of items.
+        """
         result = []
         i, j = 0, 0
         comparisons_to_make = []
@@ -75,10 +137,8 @@ async def merge(self, left, right):
             i += 1
             j += 1
 
-        # Batch process comparisons
         comparison_results = await self.batch_compare(comparisons_to_make)
 
-        # Safely ignore last comparison if there are no more results
         if not comparison_results:
             self.openai_client.logger.info("Final comparison complete.")
             result.extend(left[i:])
diff --git a/src/core/summarize_list_agent.py b/src/core/summarize_list_agent.py
index c39cda3..df6c438 100644
--- a/src/core/summarize_list_agent.py
+++ b/src/core/summarize_list_agent.py
@@ -1,14 +1,45 @@
-import asyncio
+import asyncio  # used by asyncio.gather in summarize_list
+from pydantic import BaseModel, Field
 from typing import List, Dict
 from .openai_api import OpenAIClient
 
+class SummarizeListInput(BaseModel):
+    list_to_summarize: List[str] = Field(..., description="The list of items to summarize")
+    max_tokens: int = Field(1000, description="The maximum number of tokens to generate")
+
 class SummarizeListAgent:
-    def __init__(self, list_to_summarize: List[str], max_tokens: int = 1000):
-        self.list_to_summarize = list_to_summarize
-        self.max_tokens = max_tokens
+    """
+    A class to summarize items in a list using the OpenAI API.
+
+    Attributes:
+        list_to_summarize (List[str]): The list of items to summarize.
+        max_tokens (int): The maximum number of tokens to generate.
+ openai_client (OpenAIClient): An instance of OpenAIClient to interact with the API. + + Methods: + summarize_list(): Summarizes the entire list of items. + summarize_item(): Summarizes a single item. + """ + + def __init__(self, data: SummarizeListInput): + """ + Constructs all the necessary attributes for the SummarizeListAgent object. + + Args: + data (SummarizeListInput): An instance of SummarizeListInput containing + the list of items and max_tokens. + """ + self.list_to_summarize = data.list_to_summarize + self.max_tokens = data.max_tokens self.openai_client = OpenAIClient() async def summarize_list(self) -> List[Dict]: + """ + Summarizes the entire list based on the provided items. + + Returns: + List[Dict]: A list of dictionaries with the original items and their summaries. + """ tasks = [] for item in self.list_to_summarize: user_prompt = f"Summarize the following: {item}." @@ -18,6 +49,15 @@ async def summarize_list(self) -> List[Dict]: return results async def summarize_item(self, user_prompt: str) -> Dict: + """ + Summarizes a single item. + + Args: + user_prompt (str): The prompt to send to the OpenAI API. + + Returns: + Dict: A dictionary with the original item and its summary. + """ system_prompt = "You are an assistant tasked with summarizing items." response = await self.openai_client.complete_chat([ {"role": "system", "content": system_prompt}, diff --git a/src/core/token_counter.py b/src/core/token_counter.py index 4eb8cad..9e4d01d 100644 --- a/src/core/token_counter.py +++ b/src/core/token_counter.py @@ -1,11 +1,36 @@ import tiktoken - class TokenCounter: - def __init__(self, model="gpt-3.5-turbo"): # tiktoken does not provide direct gpt4o-mini support atm - self.encoder = tiktoken.get_encoding("cl100k_base") # so we have to improvise + """ + A class to count the number of tokens in a list of messages using the tiktoken library. + + Attributes: + encoder (tiktoken.Encoding): An instance of the tiktoken encoder to encode messages. + + Methods: + count_tokens(messages): Counts the total number of tokens in a list of messages. + """ + + def __init__(self, model="gpt-3.5-turbo"): + """ + Constructs the TokenCounter object and initializes the encoder. + + Args: + model (str): The model name to use for token encoding. + Note: tiktoken does not directly support "gpt-4o-mini" at the moment. + """ + self.encoder = tiktoken.get_encoding("cl100k_base") def count_tokens(self, messages): + """ + Counts the total number of tokens in a list of messages. + + Args: + messages (list): A list of message dicts containing the "content" key. + + Returns: + int: The total number of tokens in all the messages. + """ total_tokens = 0 for message in messages: total_tokens += len(self.encoder.encode(message["content"]))
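Taken together, these changes move every agent from loose keyword arguments to a validated pydantic input model. A minimal end-to-end sketch of the new calling convention, assuming config/settings.json holds a valid OpenAI API key as OpenAIClient expects:

import asyncio
from core.map_list_agent import MapListAgent, MapListInput

async def main():
    # Construction validates the input up front: a missing required field
    # raises pydantic.ValidationError before any API call is made.
    input_data = MapListInput(
        list_to_map=['Apple', 'Banana', 'Carrot'],
        transformation='Convert all items to uppercase',
    )  # max_tokens falls back to its Field default of 1000
    agent = MapListAgent(input_data)
    print(await agent.map_list())

asyncio.run(main())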