From 89b4d97ebbf9181d984adc9d5b3b29f00d3727b7 Mon Sep 17 00:00:00 2001 From: Tyson Gern Date: Tue, 11 Jun 2024 09:33:37 -0500 Subject: [PATCH] Externalize OpenAI base URL --- .devcontainer/devcontainer.json | 3 ++- .env.example | 1 + README.md | 6 +++--- analyze.py | 2 +- starter/ai/open_ai_client.py | 2 ++ starter/app.py | 2 +- starter/environment.py | 2 ++ tests/ai/test_open_ai_client.py | 3 +++ 8 files changed, 15 insertions(+), 6 deletions(-) diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index 4881281..81bde1c 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -16,6 +16,7 @@ "USE_FLASK_DEBUG_MODE": "true", "FEEDS": "https://feed.infoq.com/development/,https://blog.jetbrains.com/feed/,https://feed.infoq.com/Devops/,https://feed.infoq.com/architecture-design/", "ROOT_LOG_LEVEL": "INFO", - "STARTER_LOG_LEVEL": "DEBUG" + "STARTER_LOG_LEVEL": "DEBUG", + "OPEN_AI_BASE_URL": "https://api.openai.com/v1/" } } diff --git a/.env.example b/.env.example index dfc7fa0..eac293a 100644 --- a/.env.example +++ b/.env.example @@ -1 +1,2 @@ export OPEN_AI_KEY=fill_me_in +export OPEN_AI_BASE_URL=fill_me_in diff --git a/README.md b/README.md index 41faecf..5f65c8f 100644 --- a/README.md +++ b/README.md @@ -88,7 +88,7 @@ and uses pgvector to store the embeddings in PostgreSQL. The web application collects the user's query and creates an embedding with the OpenAI Embeddings API. It then searches the PostgreSQL for similar embeddings (using pgvector) and provides the corresponding chunk of text as -context for a query to the [Azure AI Chat Completion API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#completions). +context for a query to the [Azure AI Chat Completion API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions). ## Local development @@ -113,7 +113,7 @@ context for a query to the [Azure AI Chat Completion API](https://learn.microsof [localhost:5001](http://localhost:5001). ```shell - python collector.py - python analyzer.py + python collect.py + python analyze.py python -m starter ``` diff --git a/analyze.py b/analyze.py index 7f3bd93..9300127 100644 --- a/analyze.py +++ b/analyze.py @@ -19,7 +19,7 @@ chunks_gateway = ChunksGateway(db_template) embeddings_gateway = EmbeddingsGateway(db_template) ai_client = OpenAIClient( - base_url="https://api.openai.com/v1/", + base_url=env.open_ai_base_url, api_key=env.open_ai_key, embeddings_model="text-embedding-3-small", chat_model="gpt-4o" diff --git a/starter/ai/open_ai_client.py b/starter/ai/open_ai_client.py index 8ec06ff..c1b4f6f 100644 --- a/starter/ai/open_ai_client.py +++ b/starter/ai/open_ai_client.py @@ -36,6 +36,7 @@ def fetch_embedding(self, text) -> Result[List[float]]: }, ) if not response.ok: + logger.error(f"Received {response.status_code} response from {self.base_url}: {response.text}") return Failure("Failed to fetch embedding") return Success(response.json()["data"][0]["embedding"]) @@ -55,6 +56,7 @@ def fetch_chat_completion(self, messages: List[ChatMessage]) -> Result[str]: ]}, ) if not response.ok: + logger.error(f"Received {response.status_code} response from {self.base_url}: {response.text}") return Failure("Failed to fetch completion") return Success(response.json()["choices"][0]["message"]["content"]) diff --git a/starter/app.py b/starter/app.py index d77e9d3..ec40629 100644 --- a/starter/app.py +++ b/starter/app.py @@ -28,7 +28,7 @@ def create_app(env: Environment = Environment.from_env()) -> Flask: chunks_gateway = ChunksGateway(db_template) embeddings_gateway = EmbeddingsGateway(db_template) ai_client = OpenAIClient( - base_url="https://api.openai.com/v1/", + base_url=env.open_ai_base_url, api_key=env.open_ai_key, embeddings_model="text-embedding-3-small", chat_model="gpt-4o" diff --git a/starter/environment.py b/starter/environment.py index 283fcfb..aecea97 100644 --- a/starter/environment.py +++ b/starter/environment.py @@ -10,6 +10,7 @@ class Environment: use_flask_debug_mode: bool feeds: List[str] open_ai_key: str + open_ai_base_url: str root_log_level: str starter_log_level: str @@ -21,6 +22,7 @@ def from_env(cls) -> 'Environment': use_flask_debug_mode=os.environ.get('USE_FLASK_DEBUG_MODE', 'false') == 'true', feeds=cls.__require_env('FEEDS').strip().split(','), open_ai_key=cls.__require_env('OPEN_AI_KEY'), + open_ai_base_url=cls.__require_env('OPEN_AI_BASE_URL'), root_log_level=os.environ.get('ROOT_LOG_LEVEL', 'INFO'), starter_log_level=os.environ.get('STARTER_LOG_LEVEL', 'INFO'), ) diff --git a/tests/ai/test_open_ai_client.py b/tests/ai/test_open_ai_client.py index bbbae01..6182a37 100644 --- a/tests/ai/test_open_ai_client.py +++ b/tests/ai/test_open_ai_client.py @@ -5,6 +5,7 @@ from starter.ai.open_ai_client import OpenAIClient, ChatMessage from tests.chat_support import chat_response from tests.embeddings_support import embedding_response, embedding_vector +from tests.logging_support import disable_logging class TestOpenAIClient(unittest.TestCase): @@ -22,6 +23,7 @@ def test_fetch_embedding(self): self.assertEqual(embedding_vector(2), self.client.fetch_embedding("some query").value) + @disable_logging @responses.activate def test_fetch_embedding_failure(self): responses.add(responses.POST, "https://openai.example.com/embeddings", "bad news", status=400) @@ -39,6 +41,7 @@ def test_fetch_chat_completion(self): ]).value, ) + @disable_logging @responses.activate def test_fetch_chat_completion_failure(self): responses.add(responses.POST, "https://openai.example.com/chat/completions", "bad news", status=400)