diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index caf1cc3..f6ff8f2 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -31,6 +31,9 @@ jobs:
           pip install -r requirements.txt
           pip install black flake8
 
+      - name: Run install.py for database setup
+        run: python install.py  # initializes config/settings.json and the SQLite database
+
       - name: Run flake8
         run: flake8 src/ tests/ | tee flake8-output.log
 
@@ -91,4 +94,3 @@ jobs:
 
           # Clean up temporary files
           rm $ENCODED_LOG_FILE $ENCODED_ERROR_LOG_FILE $ENCODED_TEST_LOG_FILE
-
diff --git a/.gitignore b/.gitignore
index cb4e6e2..b5e79d3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -165,3 +165,6 @@ cython_debug/
 venv-py39/
 venv-py310/
 venv-py311/
+venv-py312/
+
+config/settings.json
diff --git a/INSTALL.md b/INSTALL.md
index 5f42b0b..fb628b1 100644
--- a/INSTALL.md
+++ b/INSTALL.md
@@ -63,9 +63,18 @@ agentm-py/
 │   ├── token_counter.py
 │   ├── concurrency.py
 │   ├── prompt_generation.py
+│   ├── parallel_complete_prompt.py
+│   ├── log_complete_prompt.py
+│   ├── compose_prompt.py
 │   └── __init__.py
 ├── tests/
-│   └── test_example.py
+│   ├── test_openai_api.py
+│   ├── test_token_counter.py
+│   ├── test_prompt_generation.py
+│   ├── test_parallel_complete_prompt.py
+│   ├── test_compose_prompt.py
+│   ├── test_log_complete_prompt.py
+│   └── test_database.py
 ├── src/
 │   └── __init__.py
 ├── docs/
@@ -98,6 +107,9 @@ agentm-py/
 │   ├── token_counter.py
 │   ├── concurrency.py
 │   ├── prompt_generation.py
+│   ├── parallel_complete_prompt.py
+│   ├── log_complete_prompt.py
+│   ├── compose_prompt.py
 │   └── __init__.py
 ├── var/
 │   ├── data/
@@ -105,7 +117,13 @@ agentm-py/
 │   └── logs/
 │       └── error.log
 ├── tests/
-│   └── test_example.py
+│   ├── test_openai_api.py
+│   ├── test_token_counter.py
+│   ├── test_prompt_generation.py
+│   ├── test_parallel_complete_prompt.py
+│   ├── test_compose_prompt.py
+│   ├── test_log_complete_prompt.py
+│   └── test_database.py
 ├── src/
 │   └── __init__.py
 ├── docs/
diff --git a/install.py b/install.py
new file mode 100644
index 0000000..dea3564
--- /dev/null
+++ b/install.py
@@ -0,0 +1,88 @@
+import sqlite3
+import os
+import json
+
+# Function to create config/settings.json
+
+def create_settings():
+    # Prompt user for settings
+    api_key = input('Enter your OpenAI API key: ')
+    tier = input('Enter your OpenAI tier level (e.g., tier-1): ')
+    log_path = input('Enter the log file path [default: ./var/logs/error.log]: ') or './var/logs/error.log'
+    database_path = input('Enter the database path [default: ./var/data/agents.db]: ') or './var/data/agents.db'
+
+    # Save settings to JSON file
+    settings = {
+        'openai_api_key': api_key,
+        'tier': tier,
+        'log_path': log_path,
+        'database_path': database_path
+    }
+    os.makedirs('./config', exist_ok=True)
+    with open('./config/settings.json', 'w') as f:
+        json.dump(settings, f, indent=4)
+    print('Settings saved to config/settings.json')
+
+
+# Function to create the database structure
+
+def create_database(db_path):
+    os.makedirs(os.path.dirname(db_path), exist_ok=True)
+    conn = sqlite3.connect(db_path)
+    c = conn.cursor()
+
+    # Create tables
+    c.execute('''CREATE TABLE IF NOT EXISTS models (
+                 id INTEGER PRIMARY KEY,
+                 model TEXT NOT NULL,
+                 price_per_prompt_token REAL NOT NULL,
+                 price_per_completion_token REAL NOT NULL)''')
+
+    c.execute('''CREATE TABLE IF NOT EXISTS rate_limits (
+                 id INTEGER PRIMARY KEY,
+                 model TEXT NOT NULL,
+                 tier TEXT NOT NULL,
+                 rpm_limit INTEGER NOT NULL,
+                 tpm_limit INTEGER NOT NULL,
+                 rpd_limit INTEGER NOT NULL)''')
+
+    c.execute('''CREATE TABLE IF NOT EXISTS api_usage (
+                 id INTEGER PRIMARY KEY,
+                 timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
+                 session_id TEXT NOT NULL,
+                 model TEXT NOT NULL,
+                 prompt_tokens INTEGER NOT NULL,
+                 completion_tokens INTEGER NOT NULL,
+                 total_tokens INTEGER NOT NULL,
+                 price_per_prompt_token REAL NOT NULL,
+                 price_per_completion_token REAL NOT NULL,
+                 total_cost REAL NOT NULL)''')
+
+    c.execute('''CREATE TABLE IF NOT EXISTS chat_sessions (
+                 id INTEGER PRIMARY KEY,
+                 session_id TEXT NOT NULL,
+                 start_time DATETIME DEFAULT CURRENT_TIMESTAMP,
+                 end_time DATETIME)''')
+
+    c.execute('''CREATE TABLE IF NOT EXISTS chats (
+                 id INTEGER PRIMARY KEY,
+                 session_id TEXT NOT NULL,
+                 chat_id TEXT NOT NULL,
+                 message TEXT NOT NULL,
+                 role TEXT NOT NULL,
+                 timestamp DATETIME DEFAULT CURRENT_TIMESTAMP)''')
+
+    # Seed a default model and tier-1 rate limits (pricing values are placeholders)
+    c.execute("INSERT INTO models (model, price_per_prompt_token, price_per_completion_token) VALUES ('gpt-4o-mini', 0.03, 0.06)")
+    c.execute("INSERT INTO rate_limits (model, tier, rpm_limit, tpm_limit, rpd_limit) VALUES ('gpt-4o-mini', 'tier-1', 60, 50000, 1000)")
+
+    conn.commit()
+    conn.close()
+    print(f"Database created at {db_path}")
+
+
+if __name__ == '__main__':
+    create_settings()
+    with open('./config/settings.json', 'r') as f:
+        settings = json.load(f)
+    create_database(settings['database_path'])
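Note: install.py gathers its settings via input(), so the new CI step above will hit end-of-file on a non-interactive runner unless answers are piped to stdin. A minimal sketch of an unattended alternative that writes config/settings.json directly and reuses create_database() — the helper file name and env-var wiring are assumptions, not part of this patch:

    # ci_bootstrap.py (hypothetical): build settings without prompts, then create the DB
    import json
    import os
    from install import create_database  # safe to import; prompts only run under __main__

    settings = {
        'openai_api_key': os.environ.get('OPENAI_API_KEY', 'dummy-key-for-ci'),
        'tier': 'tier-1',
        'log_path': './var/logs/error.log',
        'database_path': './var/data/agents.db',
    }
    os.makedirs('./config', exist_ok=True)
    with open('./config/settings.json', 'w') as f:
        json.dump(settings, f, indent=4)
    create_database(settings['database_path'])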
diff --git a/pytest.ini b/pytest.ini
new file mode 100644
index 0000000..8c63e35
--- /dev/null
+++ b/pytest.ini
@@ -0,0 +1,3 @@
+[pytest]
+pythonpath = src
+asyncio_default_fixture_loop_scope = function
diff --git a/requirements.txt b/requirements.txt
index bfe3582..d86ccb4 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -14,3 +14,6 @@ urllib3==2.2.2
 virtualenv==20.26.3
 black
 flake8
+tiktoken
+anyio
+trio
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..16ff757
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,8 @@
+from setuptools import setup, find_packages
+
+setup(
+    name='agentm-py',
+    version='0.1',
+    packages=find_packages(where='src'),
+    package_dir={'': 'src'},
+)
\ No newline at end of file
diff --git a/src/__init__.py b/src/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/core/__init__.py b/src/core/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/core/compose_prompt.py b/src/core/compose_prompt.py
new file mode 100644
index 0000000..a67abe9
--- /dev/null
+++ b/src/core/compose_prompt.py
@@ -0,0 +1,4 @@
+import re
+
+def compose_prompt(template: str, variables: dict) -> str:
+    return re.sub(r'{{\s*([^}\s]+)\s*}}', lambda match: str(variables.get(match.group(1), '')), template)
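compose_prompt fills {{placeholder}} tokens from a dict and silently replaces unknown placeholders with an empty string. A minimal usage sketch (the template text is illustrative):

    from core.compose_prompt import compose_prompt

    template = 'Summarize {{topic}} in {{style}} style.'
    print(compose_prompt(template, {'topic': 'rate limits'}))
    # -> 'Summarize rate limits in  style.'  (missing 'style' becomes '')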
diff --git a/src/core/concurrency.py b/src/core/concurrency.py
new file mode 100644
index 0000000..54dc481
--- /dev/null
+++ b/src/core/concurrency.py
@@ -0,0 +1,15 @@
+import asyncio
+
+class Semaphore:
+    def __init__(self, max_concurrent_tasks):
+        self.semaphore = asyncio.Semaphore(max_concurrent_tasks)
+
+    async def __aenter__(self):
+        await self.semaphore.acquire()
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        self.semaphore.release()
+
+    async def call_function(self, func, *args, **kwargs):
+        async with self.semaphore:
+            return await func(*args, **kwargs)
diff --git a/src/core/database.py b/src/core/database.py
new file mode 100644
index 0000000..2538f71
--- /dev/null
+++ b/src/core/database.py
@@ -0,0 +1,40 @@
+import sqlite3
+from datetime import datetime
+
+class Database:
+    def __init__(self, db_path):
+        self.db_path = db_path
+
+    def connect(self):
+        return sqlite3.connect(self.db_path)
+
+    def check_rate_limits(self, model):
+        conn = self.connect()
+        c = conn.cursor()
+
+        # Compare tokens used in the last minute against the model's TPM limit
+        c.execute("SELECT SUM(total_tokens) FROM api_usage WHERE model = ? AND timestamp >= datetime('now', '-1 minute')", (model,))
+        tokens_last_minute = c.fetchone()[0] or 0
+
+        c.execute("SELECT tpm_limit FROM rate_limits WHERE model = ?", (model,))
+        tpm_limit = c.fetchone()[0]
+
+        conn.close()
+        return tokens_last_minute < tpm_limit
+
+    def log_api_usage(self, session_id, model, prompt_tokens, completion_tokens, total_tokens):
+        conn = self.connect()
+        c = conn.cursor()
+
+        # Fetch per-token prices for the model
+        c.execute("SELECT price_per_prompt_token, price_per_completion_token FROM models WHERE model = ?", (model,))
+        prices = c.fetchone()
+        prompt_price = prices[0]
+        completion_price = prices[1]
+        total_cost = (prompt_tokens * prompt_price) + (completion_tokens * completion_price)
+
+        c.execute("INSERT INTO api_usage (session_id, model, prompt_tokens, completion_tokens, total_tokens, price_per_prompt_token, price_per_completion_token, total_cost) VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
+                  (session_id, model, prompt_tokens, completion_tokens, total_tokens, prompt_price, completion_price, total_cost))
+
+        conn.commit()
+        conn.close()
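The Semaphore wrapper doubles as an async context manager and a call gate: call_function acquires the underlying asyncio.Semaphore, so routing coroutines through it bounds concurrency. A minimal sketch with a stand-in coroutine (fetch is illustrative):

    import asyncio
    from core.concurrency import Semaphore

    async def fetch(i):
        await asyncio.sleep(0.1)  # stand-in for an API call
        return i

    async def main():
        sem = Semaphore(max_concurrent_tasks=2)
        # At most two fetch() calls run at any moment
        results = await asyncio.gather(*(sem.call_function(fetch, i) for i in range(5)))
        print(results)  # [0, 1, 2, 3, 4]

    asyncio.run(main())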
diff --git a/src/core/log_complete_prompt.py b/src/core/log_complete_prompt.py
new file mode 100644
index 0000000..ef8884b
--- /dev/null
+++ b/src/core/log_complete_prompt.py
@@ -0,0 +1,16 @@
+from core.logging import Logger
+
+class LogCompletePrompt:
+    def __init__(self, complete_prompt_func):
+        self.complete_prompt_func = complete_prompt_func
+        self.logger = Logger()
+
+    async def complete_prompt(self, *args, **kwargs):
+        result = await self.complete_prompt_func(*args, **kwargs)
+
+        if result['completed']:
+            self.logger.info('Prompt completed successfully.')
+        else:
+            self.logger.error('Prompt completion failed.')
+
+        return result
diff --git a/src/core/logging.py b/src/core/logging.py
new file mode 100644
index 0000000..225ca2c
--- /dev/null
+++ b/src/core/logging.py
@@ -0,0 +1,29 @@
+import logging
+import json
+import os
+
+class Logger:
+    def __init__(self, settings_path='../config/settings.json'):
+        self.settings = self.load_settings(settings_path)
+        self.log_path = self.settings['log_path']
+        os.makedirs(os.path.dirname(self.log_path), exist_ok=True)
+        logging.basicConfig(
+            filename=self.log_path,
+            level=logging.INFO,
+            format='%(asctime)s - %(levelname)s - %(message)s'
+        )
+
+    def load_settings(self, settings_path):
+        try:
+            with open(settings_path, 'r') as f:
+                return json.load(f)
+        except FileNotFoundError:
+            raise Exception(f'Settings file not found at {settings_path}')
+        except json.JSONDecodeError as e:
+            raise Exception(f'Invalid JSON in settings: {e}')
+
+    def info(self, message):
+        logging.info(message)
+
+    def error(self, message):
+        logging.error(message)
diff --git a/src/core/openai_api.py b/src/core/openai_api.py
new file mode 100644
index 0000000..c4c33ea
--- /dev/null
+++ b/src/core/openai_api.py
@@ -0,0 +1,47 @@
+import openai
+import json
+import sqlite3
+from datetime import datetime
+from .token_counter import TokenCounter
+from .database import Database
+
+class OpenAIClient:
+    def __init__(self, settings_path='../config/settings.json'):
+        settings = self.load_settings(settings_path)
+        self.api_key = settings['openai_api_key']
+        openai.api_key = self.api_key
+        self.db = Database(settings['database_path'])
+        self.token_counter = TokenCounter()
+
+    def load_settings(self, settings_path):
+        try:
+            with open(settings_path, 'r') as f:
+                return json.load(f)
+        except FileNotFoundError:
+            raise Exception(f'Settings file not found at {settings_path}')
+        except json.JSONDecodeError as e:
+            raise Exception(f'Invalid JSON in settings: {e}')
+
+    def complete_chat(self, messages, model='gpt-4o-mini', max_tokens=1500):
+        # Check rate limits
+        if not self.db.check_rate_limits(model):
+            raise Exception(f"Rate limit exceeded for model {model}")
+
+        prompt_tokens = self.token_counter.count_tokens(messages)
+
+        try:
+            response = openai.ChatCompletion.create(
+                model=model,
+                messages=messages,
+                max_tokens=max_tokens
+            )
+
+            completion_tokens = self.token_counter.count_tokens([{'role': 'assistant', 'content': response.choices[0].message['content']}])
+            total_tokens = prompt_tokens + completion_tokens
+
+            # Log token usage and cost in the database (session id is currently hardcoded)
+            self.db.log_api_usage('session-1', model, prompt_tokens, completion_tokens, total_tokens)
+
+            return response.choices[0].message['content']
+        except openai.error.OpenAIError as e:
+            raise Exception(f'Error with OpenAI API: {str(e)}')
\ No newline at end of file
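OpenAIClient.complete_chat checks the TPM budget, counts prompt tokens, calls the chat endpoint, and logs usage plus cost to SQLite. A minimal usage sketch, assuming config/settings.json exists at the given path and the pre-1.0 openai SDK is installed (openai.ChatCompletion and openai.error were removed in openai>=1.0):

    from core.openai_api import OpenAIClient

    client = OpenAIClient(settings_path='./config/settings.json')
    reply = client.complete_chat(
        messages=[{'role': 'user', 'content': 'Say hello in one word.'}],
        model='gpt-4o-mini',
        max_tokens=10,
    )
    print(reply)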
diff --git a/src/core/parallel_complete_prompt.py b/src/core/parallel_complete_prompt.py
new file mode 100644
index 0000000..260ddc0
--- /dev/null
+++ b/src/core/parallel_complete_prompt.py
@@ -0,0 +1,15 @@
+import asyncio
+from .concurrency import Semaphore
+
+class ParallelCompletePrompt:
+    def __init__(self, complete_prompt_func, parallel_completions=1, should_continue_func=None):
+        self.complete_prompt_func = complete_prompt_func
+        self.parallel_completions = parallel_completions
+        self.should_continue_func = should_continue_func or (lambda: True)
+        self.semaphore = Semaphore(parallel_completions)
+
+    async def complete_prompt(self, *args, **kwargs):
+        async with self.semaphore:
+            if not self.should_continue_func():
+                raise asyncio.CancelledError("Operation cancelled.")
+            return await self.complete_prompt_func(*args, **kwargs)
diff --git a/src/core/prompt_generation.py b/src/core/prompt_generation.py
new file mode 100644
index 0000000..d255e9b
--- /dev/null
+++ b/src/core/prompt_generation.py
@@ -0,0 +1,9 @@
+class PromptGenerator:
+    def __init__(self):
+        self.prompts = []
+
+    def add_prompt(self, prompt):
+        self.prompts.append(prompt)
+
+    def generate_combined_prompt(self):
+        return "\n".join(self.prompts)
diff --git a/src/core/token_counter.py b/src/core/token_counter.py
new file mode 100644
index 0000000..20b3961
--- /dev/null
+++ b/src/core/token_counter.py
@@ -0,0 +1,12 @@
+import tiktoken
+
+class TokenCounter:
+    def __init__(self, model='gpt-3.5-turbo'):
+        # NOTE: the model arg is currently unused; cl100k_base matches gpt-3.5/gpt-4, while gpt-4o models use o200k_base
+        self.encoder = tiktoken.get_encoding('cl100k_base')
+
+    def count_tokens(self, messages):
+        total_tokens = 0
+        for message in messages:
+            total_tokens += len(self.encoder.encode(message['content']))
+        return total_tokens
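ParallelCompletePrompt routes every call through the shared Semaphore, so firing many complete_prompt calls via asyncio.gather keeps at most parallel_completions requests in flight, and should_continue_func offers a cancellation hook. A minimal sketch with a mock completion function:

    import asyncio
    from core.parallel_complete_prompt import ParallelCompletePrompt

    async def mock_complete_prompt(prompt):
        await asyncio.sleep(0.1)  # stand-in for a model call
        return {'completed': True, 'value': f'echo: {prompt}'}

    async def main():
        runner = ParallelCompletePrompt(mock_complete_prompt, parallel_completions=2)
        results = await asyncio.gather(*(runner.complete_prompt(f'p{i}') for i in range(4)))
        print([r['value'] for r in results])

    asyncio.run(main())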
diff --git a/tests/test_compose_prompt.py b/tests/test_compose_prompt.py
new file mode 100644
index 0000000..2beec87
--- /dev/null
+++ b/tests/test_compose_prompt.py
@@ -0,0 +1,10 @@
+import pytest
+from core.compose_prompt import compose_prompt
+
+def test_compose_prompt():
+    template = "Hello, {{name}}!"
+    variables = {'name': 'John'}
+
+    result = compose_prompt(template, variables)
+
+    assert result == "Hello, John!"
diff --git a/tests/test_database.py b/tests/test_database.py
new file mode 100644
index 0000000..cc1fc4b
--- /dev/null
+++ b/tests/test_database.py
@@ -0,0 +1,50 @@
+import pytest
+from unittest import mock
+from core.database import Database
+
+
+def test_check_rate_limits():
+    # Mock the database connection and cursor
+    mock_conn = mock.MagicMock()
+    mock_cursor = mock_conn.cursor.return_value
+
+    # Mock the cursor's return values for token usage and limit
+    mock_cursor.fetchone.side_effect = [(100,), (50000,)]  # Return tuples for token usage and limit
+
+    # Patch sqlite3.connect to return the mock connection
+    with mock.patch('sqlite3.connect', return_value=mock_conn):
+        db = Database(db_path=':memory:')  # Use an in-memory database for testing
+
+        # Call the method and assert it returns True (since usage < limit)
+        assert db.check_rate_limits('gpt-4o-mini') is True
+
+
+def test_log_and_delete_api_usage():
+    # Mock the database connection and cursor
+    mock_conn = mock.MagicMock()
+    mock_cursor = mock_conn.cursor.return_value
+
+    # Patch sqlite3.connect to return the mock connection
+    with mock.patch('sqlite3.connect', return_value=mock_conn):
+        db = Database(db_path=':memory:')
+
+        # Log the API usage
+        db.log_api_usage('session-1', 'gpt-4o-mini', 100, 200, 300)
+
+        # Assert the INSERT query was executed
+        mock_cursor.execute.assert_any_call(
+            "INSERT INTO api_usage (session_id, model, prompt_tokens, completion_tokens, total_tokens, price_per_prompt_token, price_per_completion_token, total_cost) VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
+            ('session-1', 'gpt-4o-mini', 100, 200, 300, mock.ANY, mock.ANY, mock.ANY)
+        )
+
+        # Mock the SELECT query to return the inserted values
+        mock_cursor.fetchone.side_effect = [('session-1', 'gpt-4o-mini', 100, 200, 300)]
+
+        # Simulate fetching the logged API usage
+        db.connect().cursor().execute("SELECT * FROM api_usage WHERE session_id = ?", ('session-1',))
+
+        # Now simulate deleting the entry
+        db.connect().cursor().execute("DELETE FROM api_usage WHERE session_id = ?", ('session-1',))
+
+        # Assert the DELETE query was executed
+        mock_cursor.execute.assert_any_call("DELETE FROM api_usage WHERE session_id = ?", ('session-1',))
diff --git a/tests/test_log_complete_prompt.py b/tests/test_log_complete_prompt.py
new file mode 100644
index 0000000..3fea884
--- /dev/null
+++ b/tests/test_log_complete_prompt.py
@@ -0,0 +1,25 @@
+import pytest
+import shutil
+import os
+from core.log_complete_prompt import LogCompletePrompt
+
+@pytest.mark.asyncio
+async def test_logging():
+    # Ensure the config folder exists in the test environment
+    if not os.path.exists('../config'):
+        os.makedirs('../config')
+
+    # Copy the settings.json file if it doesn't exist in the test environment
+    if not os.path.exists('../config/settings.json'):
+        shutil.copyfile('./config/settings.json', '../config/settings.json')
+
+    # Define a mock completion function
+    async def mock_complete_prompt(*args, **kwargs):
+        return {'completed': True, 'value': 'Success'}
+
+    # Initialize LogCompletePrompt
+    log_prompt = LogCompletePrompt(mock_complete_prompt)
+    result = await log_prompt.complete_prompt()
+
+    # Assert the completion result
+    assert result['completed']
diff --git a/tests/test_parallel_complete_prompt.py b/tests/test_parallel_complete_prompt.py
new file mode 100644
index 0000000..32e9d03
--- /dev/null
+++ b/tests/test_parallel_complete_prompt.py
@@ -0,0 +1,12 @@
+import pytest
+from core.parallel_complete_prompt import ParallelCompletePrompt
+
+@pytest.mark.anyio
+async def test_parallel_completion():
+    async def mock_complete_prompt(*args, **kwargs):
+        return {'completed': True, 'value': 'Success'}
+
+    parallel_prompt = ParallelCompletePrompt(mock_complete_prompt)
+    result = await parallel_prompt.complete_prompt()
+
+    assert result['completed']
diff --git a/tests/test_prompt_generation.py b/tests/test_prompt_generation.py
new file mode 100644
index 0000000..233a5d7
--- /dev/null
+++ b/tests/test_prompt_generation.py
@@ -0,0 +1,12 @@
+import pytest
+from core.prompt_generation import PromptGenerator
+
+
+def test_prompt_generation():
+    generator = PromptGenerator()
+    generator.add_prompt("This is prompt 1")
+    generator.add_prompt("This is prompt 2")
+
+    combined = generator.generate_combined_prompt()
+
+    assert combined == "This is prompt 1\nThis is prompt 2"
diff --git a/tests/test_token_counter.py b/tests/test_token_counter.py
new file mode 100644
index 0000000..abe112d
--- /dev/null
+++ b/tests/test_token_counter.py
@@ -0,0 +1,11 @@
+import pytest
+from core.token_counter import TokenCounter
+
+
+def test_token_counting():
+    counter = TokenCounter()
+
+    messages = [{'role': 'user', 'content': 'Hello!'}]
+    token_count = counter.count_tokens(messages)
+
+    assert token_count > 0
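The async tests mix two markers: test_log_complete_prompt.py uses @pytest.mark.asyncio (pytest-asyncio, also implied by the asyncio_default_fixture_loop_scope option in pytest.ini), while test_parallel_complete_prompt.py uses @pytest.mark.anyio. Since trio is added to requirements.txt, the anyio-marked test can be run against both backends by overriding anyio's anyio_backend fixture; a sketch of a hypothetical tests/conftest.py:

    import pytest

    # anyio's pytest plugin defaults to the asyncio backend; this fixture runs
    # every @pytest.mark.anyio test against asyncio and trio in turn.
    @pytest.fixture(params=['asyncio', 'trio'])
    def anyio_backend(request):
        return request.param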