From b6d52716c88ac435a9d3b83b474fc75165c19485 Mon Sep 17 00:00:00 2001 From: Tyson Gern Date: Sun, 9 Jun 2024 17:09:19 -0600 Subject: [PATCH] Add web UI --- analyze.py | 3 +- starter/ai/open_ai_client.py | 31 ++++- starter/app.py | 20 +++ starter/database_support/database_template.py | 2 +- starter/documents/documents_gateway.py | 13 ++ starter/index_page.py | 26 ++++ starter/query/__init__.py | 0 starter/query/query_service.py | 31 +++++ starter/search/chunks_search_service.py | 35 ++++++ starter/search/embeddings_gateway.py | 3 +- starter/search/vector_support.py | 5 + starter/static/images/expand-black.svg | 3 + starter/static/images/expand-white.svg | 3 + starter/static/images/favicon.ico | Bin 0 -> 9662 bytes starter/static/images/favicon.svg | 13 ++ starter/static/images/ic-logo-white.svg | 3 + starter/static/images/icons.svg | 5 + starter/static/images/loader-white.svg | 37 ++++++ starter/static/images/search.svg | 3 + starter/static/style/application.css | 7 ++ starter/static/style/buttons.css | 48 ++++++++ starter/static/style/forms.css | 70 +++++++++++ starter/static/style/layout.css | 114 ++++++++++++++++++ starter/static/style/reset.css | 50 ++++++++ starter/static/style/response.css | 8 ++ starter/static/style/theme.css | 52 ++++++++ starter/static/style/typography.css | 60 +++++++++ starter/templates/index.html | 23 ++++ starter/templates/layout.html | 48 ++++++++ starter/templates/response.html | 35 ++++++ tests/ai/test_open_ai_client.py | 29 +++-- tests/chat_support.py | 16 +++ tests/documents/test_documents_gateway.py | 13 +- tests/embeddings_support.py | 10 +- tests/query/__init__.py | 0 tests/query/test_query_service.py | 54 +++++++++ tests/search/test_chunks_search_service.py | 51 ++++++++ tests/search/test_embeddings_analyzer.py | 11 +- tests/search/test_embeddings_gateway.py | 5 +- 39 files changed, 910 insertions(+), 30 deletions(-) create mode 100644 starter/index_page.py create mode 100644 starter/query/__init__.py create mode 100644 starter/query/query_service.py create mode 100644 starter/search/chunks_search_service.py create mode 100644 starter/search/vector_support.py create mode 100644 starter/static/images/expand-black.svg create mode 100644 starter/static/images/expand-white.svg create mode 100644 starter/static/images/favicon.ico create mode 100644 starter/static/images/favicon.svg create mode 100644 starter/static/images/ic-logo-white.svg create mode 100644 starter/static/images/icons.svg create mode 100644 starter/static/images/loader-white.svg create mode 100644 starter/static/images/search.svg create mode 100644 starter/static/style/application.css create mode 100644 starter/static/style/buttons.css create mode 100644 starter/static/style/forms.css create mode 100644 starter/static/style/layout.css create mode 100644 starter/static/style/reset.css create mode 100644 starter/static/style/response.css create mode 100644 starter/static/style/theme.css create mode 100644 starter/static/style/typography.css create mode 100644 starter/templates/index.html create mode 100644 starter/templates/layout.html create mode 100644 starter/templates/response.html create mode 100644 tests/chat_support.py create mode 100644 tests/query/__init__.py create mode 100644 tests/query/test_query_service.py create mode 100644 tests/search/test_chunks_search_service.py diff --git a/analyze.py b/analyze.py index 4871e6a..5d1a301 100644 --- a/analyze.py +++ b/analyze.py @@ -17,7 +17,8 @@ ai_client = OpenAIClient( base_url="https://api.openai.com/v1/", api_key=env.open_ai_key, - model="text-embedding-3-small", + embeddings_model="text-embedding-3-small", + chat_model="gpt-4o" ) analyzer = EmbeddingsAnalyzer(embeddings_gateway, chunks_gateway, ai_client) diff --git a/starter/ai/open_ai_client.py b/starter/ai/open_ai_client.py index fb7d83c..9749a09 100644 --- a/starter/ai/open_ai_client.py +++ b/starter/ai/open_ai_client.py @@ -1,13 +1,21 @@ +from dataclasses import dataclass from typing import List import requests +@dataclass +class ChatMessage: + role: str + content: str + + class OpenAIClient: - def __init__(self, base_url: str, api_key: str, model: str): + def __init__(self, base_url: str, api_key: str, embeddings_model: str, chat_model: str): self.base_url = base_url self.api_key = api_key - self.model = model + self.embeddings_model = embeddings_model + self.chat_model = chat_model def fetch_embedding(self, text) -> List[float]: result = requests.post( @@ -17,10 +25,27 @@ def fetch_embedding(self, text) -> List[float]: "Content-Type": "application/json", }, json={ - "model": self.model, + "model": self.embeddings_model, "input": text, "encoding_format": "float", }, ) return result.json()["data"][0]["embedding"] + + def fetch_chat_completion(self, messages: List[ChatMessage]) -> str: + result = requests.post( + f"{self.base_url}/chat/completions", + headers={ + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + }, + json={ + "model": self.chat_model, + "messages": [ + {"role": message.role, "content": message.content} + for message in messages + ]}, + ) + + return result.json()["choices"][0]["message"]["content"] diff --git a/starter/app.py b/starter/app.py index 876738a..d77e9d3 100644 --- a/starter/app.py +++ b/starter/app.py @@ -3,9 +3,16 @@ import sqlalchemy from flask import Flask +from starter.ai.open_ai_client import OpenAIClient from starter.database_support.database_template import DatabaseTemplate +from starter.documents.chunks_gateway import ChunksGateway +from starter.documents.documents_gateway import DocumentsGateway from starter.environment import Environment from starter.health_api import health_api +from starter.index_page import index_page +from starter.query.query_service import QueryService +from starter.search.chunks_search_service import ChunksSearchService +from starter.search.embeddings_gateway import EmbeddingsGateway logger = logging.getLogger(__name__) @@ -17,6 +24,19 @@ def create_app(env: Environment = Environment.from_env()) -> Flask: db = sqlalchemy.create_engine(env.database_url, pool_size=4) db_template = DatabaseTemplate(db) + documents_gateway = DocumentsGateway(db_template) + chunks_gateway = ChunksGateway(db_template) + embeddings_gateway = EmbeddingsGateway(db_template) + ai_client = OpenAIClient( + base_url="https://api.openai.com/v1/", + api_key=env.open_ai_key, + embeddings_model="text-embedding-3-small", + chat_model="gpt-4o" + ) + chunks_search_service = ChunksSearchService(embeddings_gateway, chunks_gateway, documents_gateway, ai_client) + + query_service = QueryService(chunks_search_service, ai_client) + app.register_blueprint(index_page(query_service)) app.register_blueprint(health_api(db_template)) return app diff --git a/starter/database_support/database_template.py b/starter/database_support/database_template.py index 9916ed9..0335d93 100644 --- a/starter/database_support/database_template.py +++ b/starter/database_support/database_template.py @@ -1,4 +1,4 @@ -from typing import Optional, Any, Callable, TypeVar, Generic +from typing import Optional, Any, Callable, TypeVar import sqlalchemy from sqlalchemy import Engine, Connection, CursorResult diff --git a/starter/documents/documents_gateway.py b/starter/documents/documents_gateway.py index 0237416..4a38ca2 100644 --- a/starter/documents/documents_gateway.py +++ b/starter/documents/documents_gateway.py @@ -37,3 +37,16 @@ def exists(self, source: str, connection: Optional[Connection] = None) -> bool: ) return map_one_result(result, lambda row: row["count"] > 0) + + def find(self, id: UUID, connection: Optional[Connection] = None) -> DocumentRecord: + result = self.template.query( + "select id, source, content from documents where id = :id", + connection, + id=id, + ) + + return map_one_result(result, lambda row: DocumentRecord( + id=row["id"], + source=row["source"], + content=row["content"], + )) diff --git a/starter/index_page.py b/starter/index_page.py new file mode 100644 index 0000000..c02ddf3 --- /dev/null +++ b/starter/index_page.py @@ -0,0 +1,26 @@ +from flask import Blueprint, render_template, request +from flask.typing import ResponseReturnValue + +from starter.query.query_service import QueryService + + +def index_page(query_service: QueryService) -> Blueprint: + page = Blueprint('index_page', __name__) + + @page.get('/') + def index() -> ResponseReturnValue: + return render_template('index.html') + + @page.post('/') + def query() -> ResponseReturnValue: + user_query = request.form.get('query') + result = query_service.fetch_response(user_query) + + return render_template( + 'response.html', + query=user_query, + source=result.source, + response=result.response, + ) + + return page diff --git a/starter/query/__init__.py b/starter/query/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/starter/query/query_service.py b/starter/query/query_service.py new file mode 100644 index 0000000..1a0356c --- /dev/null +++ b/starter/query/query_service.py @@ -0,0 +1,31 @@ +from dataclasses import dataclass + +from starter.ai.open_ai_client import OpenAIClient, ChatMessage +from starter.search.chunks_search_service import ChunksSearchService + + +@dataclass +class QueryResult: + source: str + response: str + + +class QueryService: + def __init__(self, chunks_search_service: ChunksSearchService, ai_client: OpenAIClient): + self.chunks_search_service = chunks_search_service + self.ai_client = ai_client + + def fetch_response(self, query: str) -> QueryResult: + chunk = self.chunks_search_service.search_for_relevant_chunk(query) + response = self.ai_client.fetch_chat_completion([ + ChatMessage(role="system", content="You are a reporter for a major world newspaper."), + ChatMessage(role="system", content="Write your response as if you were writing a short, high-quality news" + "article for your paper. Limit your response to one paragraph."), + ChatMessage(role="system", content=f"Use the following article for context: {chunk.content}"), + ChatMessage(role="user", content=query), + ]) + + return QueryResult( + source=chunk.source, + response=response + ) diff --git a/starter/search/chunks_search_service.py b/starter/search/chunks_search_service.py new file mode 100644 index 0000000..e9f4ff2 --- /dev/null +++ b/starter/search/chunks_search_service.py @@ -0,0 +1,35 @@ +from dataclasses import dataclass + +from starter.ai.open_ai_client import OpenAIClient +from starter.documents.chunks_gateway import ChunksGateway +from starter.documents.documents_gateway import DocumentsGateway +from starter.search.embeddings_gateway import EmbeddingsGateway + + +@dataclass +class ChunkSearchResult: + content: str + source: str + + +class ChunksSearchService: + def __init__(self, + embeddings_gateway: EmbeddingsGateway, + chunks_gateway: ChunksGateway, + documents_gateway: DocumentsGateway, + open_ai_client: OpenAIClient): + self.embeddings_gateway = embeddings_gateway + self.chunks_gateway = chunks_gateway + self.documents_gateway = documents_gateway + self.open_ai_client = open_ai_client + + def search_for_relevant_chunk(self, query: str) -> ChunkSearchResult: + vector = self.open_ai_client.fetch_embedding(query) + chunk_id = self.embeddings_gateway.find_similar_chunk_id(vector) + chunk = self.chunks_gateway.find(chunk_id) + document = self.documents_gateway.find(chunk.document_id) + + return ChunkSearchResult( + content=chunk.content, + source=document.source, + ) diff --git a/starter/search/embeddings_gateway.py b/starter/search/embeddings_gateway.py index 7f44052..504d099 100644 --- a/starter/search/embeddings_gateway.py +++ b/starter/search/embeddings_gateway.py @@ -5,6 +5,7 @@ from starter.database_support.database_template import DatabaseTemplate from starter.database_support.result_mapping import map_results, map_one_result +from starter.search.vector_support import vector_to_string class EmbeddingsGateway: @@ -33,7 +34,7 @@ def find_similar_chunk_id(self, vector: List[float], connection: Optional[Connec result = self.template.query( """select e.chunk_id from embeddings e order by e.embedding <=> :vector limit 1""", connection, - vector=vector, + vector=vector_to_string(vector), ) return map_one_result(result, lambda row: row["chunk_id"]) diff --git a/starter/search/vector_support.py b/starter/search/vector_support.py new file mode 100644 index 0000000..a2590bc --- /dev/null +++ b/starter/search/vector_support.py @@ -0,0 +1,5 @@ +from typing import List + + +def vector_to_string(vector: List[float]) -> str: + return "[" + ",".join([str(v) for v in vector]) + "]" diff --git a/starter/static/images/expand-black.svg b/starter/static/images/expand-black.svg new file mode 100644 index 0000000..12c3496 --- /dev/null +++ b/starter/static/images/expand-black.svg @@ -0,0 +1,3 @@ + + + diff --git a/starter/static/images/expand-white.svg b/starter/static/images/expand-white.svg new file mode 100644 index 0000000..4735db2 --- /dev/null +++ b/starter/static/images/expand-white.svg @@ -0,0 +1,3 @@ + + + diff --git a/starter/static/images/favicon.ico b/starter/static/images/favicon.ico new file mode 100644 index 0000000000000000000000000000000000000000..db93a49f45962339a183847bbec78e8fb1af12f2 GIT binary patch literal 9662 zcmeI2!HN?>5I}qJD8Yl^!CX8j13ODQ{(U5^>Xf|D9mXbXRwBL zl!#qy!q2~L@aNbh_sUY9GY6iV^9${P|!oc++mTKi2E@ueDn38zSe8M&on8-+#-Pn*R2+X|-BQ zwgrteo6S#YeF<~5TK%EvOWIT_m7m>i_r2CP$?+}x_%Er?HRD^kg8LbW2EqZBX48J92U=7!Sb=_UH zFU0V(muh`TyD%TJj<827|AODo>-Sjxv~>_$ALD=P8VLO#%bBLfbC%YZ62CluX?189 zu7Ad(AB`|h8n@KP@LP1?9}b7FWjsslkgG&5t+cR+|A|41|3!~jjL}PDr#^-sJxdJW zNA^7H^HR9((rk$3r*Dfdt!-tmPve&QSbp>^&rTM8nJ;17S7za-U&}hu*Z@0B`^JqeeG%XlCz#$LqU|6Rhej@6Kr^Di83Id?&M!?{z+qdz5=*PnO~oT3`| L^MY&dXYu + + + diff --git a/starter/static/images/ic-logo-white.svg b/starter/static/images/ic-logo-white.svg new file mode 100644 index 0000000..8046e18 --- /dev/null +++ b/starter/static/images/ic-logo-white.svg @@ -0,0 +1,3 @@ + + + diff --git a/starter/static/images/icons.svg b/starter/static/images/icons.svg new file mode 100644 index 0000000..02bfdfb --- /dev/null +++ b/starter/static/images/icons.svg @@ -0,0 +1,5 @@ + + + diff --git a/starter/static/images/loader-white.svg b/starter/static/images/loader-white.svg new file mode 100644 index 0000000..ff59eca --- /dev/null +++ b/starter/static/images/loader-white.svg @@ -0,0 +1,37 @@ + + + + + + diff --git a/starter/static/images/search.svg b/starter/static/images/search.svg new file mode 100644 index 0000000..6361650 --- /dev/null +++ b/starter/static/images/search.svg @@ -0,0 +1,3 @@ + + + diff --git a/starter/static/style/application.css b/starter/static/style/application.css new file mode 100644 index 0000000..d215a12 --- /dev/null +++ b/starter/static/style/application.css @@ -0,0 +1,7 @@ +@import "reset.css"; +@import "theme.css"; +@import "typography.css"; +@import "layout.css"; +@import "forms.css"; +@import "buttons.css"; +@import "response.css"; diff --git a/starter/static/style/buttons.css b/starter/static/style/buttons.css new file mode 100644 index 0000000..12a355c --- /dev/null +++ b/starter/static/style/buttons.css @@ -0,0 +1,48 @@ +button, .button { + --background: var(--button-background-color); + --text-color: var(--button-text-color); + --border: var(--button-background-color); + + display: flex; + align-items: center; + gap: .5rem; + + height: 2.75rem; + width: fit-content; + padding: 0 1rem; + border: 1px solid var(--border); + background-color: var(--background); + + cursor: pointer; + color: var(--text-color); + font-size: 1rem; + font-weight: 500; + text-transform: uppercase; +} + +button svg, .button svg { + --icon-color: var(--text-color); + + width: 1.5rem; + height: 1.5rem; + margin-left: -.25rem; + flex-shrink: 0; +} + +a.button { + color: var(--text-color); + text-decoration: none; +} + +button.loader { + padding-left: 3rem; + background-image: url("../images/search.svg"); + background-repeat: no-repeat; + background-position-x: .75rem; + background-position-y: center; + background-size: 1.5rem 1.5rem; + + &[disabled] { + background-image: url("../images/loader-white.svg"); + } +} diff --git a/starter/static/style/forms.css b/starter/static/style/forms.css new file mode 100644 index 0000000..b7294e1 --- /dev/null +++ b/starter/static/style/forms.css @@ -0,0 +1,70 @@ +form { + display: flex; + flex-direction: column; + gap: 1rem; +} + +fieldset { + display: flex; + gap: 1rem; + flex-wrap: wrap; +} + +fieldset button, fieldset .button { + align-self: flex-end; +} + +label { + flex: 1; + width: 14rem; + + display: flex; + flex-direction: column; + gap: .5rem; + + cursor: pointer; +} + +label { + font-weight: 600; + text-transform: uppercase; + font-size: .9rem; +} + +label.checkbox, label.radio, label.toggle { + flex-direction: row; + align-items: center; + font-size: 1rem; + text-transform: none; + gap: .75rem; + user-select: none; +} + +input, select { + font-size: 1rem; + text-transform: none; + font-weight: 400; + + padding: .75rem; + border: 1px solid var(--border-color); + color: var(--text-color); + background-color: var(--background-color); + + outline-color: var(--outline-color); +} + +select { + padding-right: 2rem; + -webkit-appearance: none; + -moz-appearance: none; + background-image: var(--expand-image); + background-repeat: no-repeat; + background-position: right .5rem center; +} + +input[type=checkbox], input[type=radio] { + margin: 0; + height: 1.5rem; + width: 1.5rem; + border-color: var(--border-color); +} diff --git a/starter/static/style/layout.css b/starter/static/style/layout.css new file mode 100644 index 0000000..03dbf7c --- /dev/null +++ b/starter/static/style/layout.css @@ -0,0 +1,114 @@ +body { + min-height: 100vh; + position: relative; + + display: grid; + grid-template-columns: 1fr; + grid-template-rows: var(--header-height) 1fr var(--footer-height); + grid-template-areas: 'header' + 'main' + 'footer'; +} + +body > header { + color: var(--header-text-color); + background-color: var(--header-background-color); + grid-area: header; + + display: flex; + justify-content: space-between; + align-items: center; + padding: 0 var(--body-padding); +} + +body > header ul { + display: flex; + align-items: center; + gap: 1rem; +} + +body > header ul:not(:last-child) { + margin: 0; +} + +body > header li { + display: flex; + align-items: center; + margin: 0; + line-height: 1; +} + +body > header svg { + width: 1.5rem; + height: 1.5rem; + flex-shrink: 0; + --icon-color: var(--header-text-color); +} + +body > header svg.logo { + width: 4rem; + height: 3rem; +} + +body > header h1 { + font-weight: 700; + font-size: 1.2rem; + margin: 0; + color: var(--header-text-color); +} + +body > header .button, body > header button { + --background: var(--header-background-color); + --border: var(--header-text-color); +} + +main { + grid-area: main; + min-width: 13rem; +} + +section { + max-width: 50rem; + margin-inline: auto; + padding: var(--body-padding); +} + +section:not(:last-child) { + border-bottom: 1px solid var(--weak-border-color); +} + +footer { + grid-area: footer; + padding: var(--body-padding); + + display: flex; + flex-direction: column; + justify-content: space-between; + + color: var(--inverted-text-color); + background-color: var(--inverted-background-color); + background-image: var(--inverted-ic-logo); + background-repeat: no-repeat; + background-position: right var(--body-padding) center; +} + +footer a, footer a:visited { + color: var(--inverted-text-color); +} + +footer li { + margin-bottom: .5rem; +} + +@media screen and (max-width: 800px) { + body { + grid-template-columns: 1fr; + grid-template-areas: 'header' + 'main' + 'footer'; + } + + body > header .heading { + display: none; + } +} diff --git a/starter/static/style/reset.css b/starter/static/style/reset.css new file mode 100644 index 0000000..9fbef23 --- /dev/null +++ b/starter/static/style/reset.css @@ -0,0 +1,50 @@ +html, body, div, span, applet, object, iframe, +h1, h2, h3, h4, h5, h6, p, blockquote, pre, +a, abbr, acronym, address, big, cite, code, +del, dfn, em, img, ins, kbd, q, s, samp, +small, strike, strong, sub, sup, tt, var, +b, u, i, center, +dl, dt, dd, ol, ul, li, +fieldset, form, label, legend, +table, caption, tbody, tfoot, thead, tr, th, td, +article, aside, canvas, details, embed, +figure, figcaption, footer, header, hgroup, +menu, nav, output, ruby, section, summary, +time, mark, audio, video { + margin: 0; + padding: 0; + border: 0; + font-size: 100%; + font: inherit; + vertical-align: baseline; +} + +/* HTML5 display-role reset for older browsers */ + +article, aside, details, figcaption, figure, +footer, header, hgroup, menu, nav, section { + display: block; +} + +body { + line-height: 1; +} + +ol, ul { + list-style: none; +} + +blockquote, q { + quotes: none; +} + +blockquote:before, blockquote:after, +q:before, q:after { + content: ''; + content: none; +} + +table { + border-collapse: collapse; + border-spacing: 0; +} diff --git a/starter/static/style/response.css b/starter/static/style/response.css new file mode 100644 index 0000000..bfcbd93 --- /dev/null +++ b/starter/static/style/response.css @@ -0,0 +1,8 @@ +.source { + text-wrap: nowrap; + overflow-x: hidden; + text-overflow: ellipsis; + opacity: .7; + font-size: .9rem; + font-style: italic; +} diff --git a/starter/static/style/theme.css b/starter/static/style/theme.css new file mode 100644 index 0000000..3e4603c --- /dev/null +++ b/starter/static/style/theme.css @@ -0,0 +1,52 @@ +:root { + --white: #FBFFFE; + --light-gray: #EDEDED; + --dark-gray: #4A4545; + --black: #080705; + + --primary: #004F71; + --alert: #AA0061; + + --text-color: var(--black); + --link-color: var(--primary); + --outline-color: var(--primary); + --icon-color: var(--black); + --background-color: var(--white); + --hover-background-color: var(--light-gray); + + --inverted-background-color: var(--black); + --inverted-text-color: var(--white); + + --weak-border-color: var(--light-gray); + --border-color: var(--dark-gray); + + --header-background-color: var(--primary); + --header-text-color: var(--white); + + --box-shadow: 0 2px 2px 0 rgba(0, 0, 0, .14), 0 3px 1px -2px rgba(0, 0, 0, .2), 0 1px 5px 0 rgba(0, 0, 0, .12); + + --dropdown-menu-background: var(--white); + + --button-background-color: var(--primary); + --button-border-color: var(--primary); + --button-text-color: var(--white); + + --header-height: 5rem; + --body-padding: 2rem; + --footer-height: 8rem; + --nav-width: 16rem; + --inverted-ic-logo: url("../images/ic-logo-white.svg"); + --expand-image: url("../images/expand-black.svg"); + --white-expand-image: url("../images/expand-white.svg"); +} + +html { + accent-color: var(--primary); +} + +@media screen and (max-width: 800px) { + :root { + --header-height: 4rem; + --body-padding: 1rem; + } +} diff --git a/starter/static/style/typography.css b/starter/static/style/typography.css new file mode 100644 index 0000000..66b46c7 --- /dev/null +++ b/starter/static/style/typography.css @@ -0,0 +1,60 @@ +html { + font-size: 16px; + font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", "Noto Sans", Helvetica, Arial, sans-serif, "Apple Color Emoji", "Segoe UI Emoji"; + color: var(--text-color); + background-color: var(--background-color); +} + +*, :after, :before { + box-sizing: border-box +} + +h1 { + font-size: 2rem; +} + +h2 { + font-size: 1.5rem; +} + +h3 { + font-size: 1.25rem; +} + +h1, h2, h3 { + font-weight: 600; + margin-bottom: 1.5rem; +} + +h1:not(:first-child), h2:not(:first-child), h3:not(:first-child) { + margin-top: 2rem; +} + +p { + line-height: 1.5; +} + +a, a:visited { + color: var(--link-color); +} + +p:not(:last-child), ul:not(:last-child) , ol:not(:last-child) { + margin-bottom: 1rem; +} + +ol.numbered, ul.bulleted { + padding-left: 2rem; + + & li { + line-height: 1.5; + margin-bottom: .5rem; + } +} + +ul.bulleted { + list-style: disc +} + +ol.numbered { + list-style: decimal; +} diff --git a/starter/templates/index.html b/starter/templates/index.html new file mode 100644 index 0000000..35aac95 --- /dev/null +++ b/starter/templates/index.html @@ -0,0 +1,23 @@ +{% extends "layout.html" %} + +{% block body %} +
+

What would you like to know?

+ +

Ask a question about current events in tech.

+ +
+
+ + +
+
+ +
+{% endblock %} diff --git a/starter/templates/layout.html b/starter/templates/layout.html new file mode 100644 index 0000000..d1bc391 --- /dev/null +++ b/starter/templates/layout.html @@ -0,0 +1,48 @@ + + + + + + + + + + + + + + + AI Starter + + +
+ +
    +
  • + +
  • +
  • +

    AI Starter

    +
  • +
+
+
+ + +
+ {% block body %}{% endblock %} +
+ + diff --git a/starter/templates/response.html b/starter/templates/response.html new file mode 100644 index 0000000..9c9a3e7 --- /dev/null +++ b/starter/templates/response.html @@ -0,0 +1,35 @@ +{% extends "layout.html" %} + +{% block body %} +
+

What else would you like to know?

+ +

Ask a question about current events in tech.

+ +
+
+ + +
+
+ +
+ +
+

Question

+

{{ query }}

+
+ +
+

Answer

+

{{ source }}

+ +

{{ response }}

+
+{% endblock %} diff --git a/tests/ai/test_open_ai_client.py b/tests/ai/test_open_ai_client.py index a4ec0d1..8d97cf8 100644 --- a/tests/ai/test_open_ai_client.py +++ b/tests/ai/test_open_ai_client.py @@ -2,22 +2,33 @@ import responses -from starter.ai.open_ai_client import OpenAIClient +from starter.ai.open_ai_client import OpenAIClient, ChatMessage +from tests.chat_support import chat_response from tests.embeddings_support import embedding_response, embedding_vector class TestOpenAIClient(unittest.TestCase): + def setUp(self): + self.client = OpenAIClient( + base_url="https://openai.example.com", + api_key="some-key", + embeddings_model="text-embedding-3-small", + chat_model="gpt-4o", + ) + @responses.activate def test_fetch_embedding(self): - responses.add( - responses.POST, - "https://openai.example.com/embeddings", - embedding_response(2), - ) + responses.add(responses.POST, "https://openai.example.com/embeddings", embedding_response(2)) - client = OpenAIClient(base_url="https://openai.example.com", api_key="some-key", model="text-embedding-3-small") + self.assertEqual(embedding_vector(2), self.client.fetch_embedding("some query")) + + @responses.activate + def fetch_chat_completion(self): + responses.add(responses.POST, "https://openai.example.com/chat/completions", chat_response) self.assertEqual( - embedding_vector(2), - client.fetch_embedding("some query"), + "Sounds good to me", + self.client.fetch_chat_completion([ + ChatMessage(role="user", content="Sound good to you?") + ]), ) diff --git a/tests/chat_support.py b/tests/chat_support.py new file mode 100644 index 0000000..8011e58 --- /dev/null +++ b/tests/chat_support.py @@ -0,0 +1,16 @@ +chat_response = """ +{ + "id": "chatcmpl-123", + "object": "chat.completion", + "created": 1677652288, + "model": "gpt-4o", + "system_fingerprint": "fp_44709d6fcb", + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": "Sounds good to me" + } + }] +} +""" diff --git a/tests/documents/test_documents_gateway.py b/tests/documents/test_documents_gateway.py index 7c31ce6..1d54f68 100644 --- a/tests/documents/test_documents_gateway.py +++ b/tests/documents/test_documents_gateway.py @@ -1,6 +1,6 @@ from unittest import TestCase -from starter.documents.documents_gateway import DocumentsGateway +from starter.documents.documents_gateway import DocumentsGateway, DocumentRecord from tests.db_test_support import TestDatabaseTemplate @@ -16,7 +16,6 @@ def test_create(self): id = self.gateway.create("https://example.com", "some content") result = self.db.query_to_dict("select id, source, content from documents") - self.assertEqual([{ "id": id, "source": "https://example.com", @@ -28,3 +27,13 @@ def test_exists(self): self.assertTrue(self.gateway.exists("https://example.com")) self.assertFalse(self.gateway.exists("https://not-there.example.com")) + + def test_find(self): + id = self.gateway.create("https://example.com", "some content") + + record = self.gateway.find(id) + + self.assertEqual( + DocumentRecord(id, "https://example.com", "some content"), + record, + ) diff --git a/tests/embeddings_support.py b/tests/embeddings_support.py index afb0675..aa82daf 100644 --- a/tests/embeddings_support.py +++ b/tests/embeddings_support.py @@ -1,16 +1,18 @@ from typing import List +from starter.search.vector_support import vector_to_string + def embedding_vector(one_index: int) -> List[float]: + """ + + :rtype: object + """ vector = [0] * 1536 vector[one_index] = 1 return vector -def vector_to_string(vector: List[float]) -> str: - return "[" + ",".join([str(v) for v in vector]) + "]" - - def embedding_response(one_index: int): return f""" {"{"} diff --git a/tests/query/__init__.py b/tests/query/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/query/test_query_service.py b/tests/query/test_query_service.py new file mode 100644 index 0000000..e1ba63d --- /dev/null +++ b/tests/query/test_query_service.py @@ -0,0 +1,54 @@ +import unittest + +import responses + +from starter.ai.open_ai_client import OpenAIClient +from starter.documents.chunks_gateway import ChunksGateway +from starter.documents.documents_gateway import DocumentsGateway +from starter.query.query_service import QueryService, QueryResult +from starter.search.chunks_search_service import ChunksSearchService +from starter.search.embeddings_gateway import EmbeddingsGateway +from tests.chat_support import chat_response +from tests.db_test_support import TestDatabaseTemplate +from tests.embeddings_support import embedding_response, embedding_vector + + +class TestQueryService(unittest.TestCase): + def setUp(self): + super().setUp() + self.db = TestDatabaseTemplate() + self.db.clear() + + self.documents_gateway = DocumentsGateway(self.db) + self.chunks_gateway = ChunksGateway(self.db) + self.embeddings_gateway = EmbeddingsGateway(self.db) + ai_client = OpenAIClient(base_url="https://openai.example.com", api_key="some-key", + embeddings_model="text-embedding-3-small", chat_model="gpt-4o") + + chunks_service = ChunksSearchService( + self.embeddings_gateway, + self.chunks_gateway, + self.documents_gateway, + ai_client + ) + + self.service = QueryService(chunks_service, ai_client) + + @responses.activate + def test_fetch_response(self): + responses.add(responses.POST, "https://openai.example.com/embeddings", embedding_response(2)) + responses.add(responses.POST, "https://openai.example.com/chat/completions", chat_response) + + document_id_1 = self.documents_gateway.create("https://example.com/1", "some_content_1") + document_id_2 = self.documents_gateway.create("https://example.com/2", "some_content_2") + chunk_id_1 = self.chunks_gateway.create(document_id_1, "some_content_1") + chunk_id_2 = self.chunks_gateway.create(document_id_2, "some_content_2") + self.embeddings_gateway.create(chunk_id_1, embedding_vector(1)) + self.embeddings_gateway.create(chunk_id_2, embedding_vector(2)) + + result = self.service.fetch_response("Does that sound good") + + self.assertEqual( + QueryResult(source="https://example.com/2", response="Sounds good to me"), + result, + ) diff --git a/tests/search/test_chunks_search_service.py b/tests/search/test_chunks_search_service.py new file mode 100644 index 0000000..085147e --- /dev/null +++ b/tests/search/test_chunks_search_service.py @@ -0,0 +1,51 @@ +import unittest + +import responses + +from starter.ai.open_ai_client import OpenAIClient +from starter.documents.chunks_gateway import ChunksGateway +from starter.documents.documents_gateway import DocumentsGateway +from starter.search.chunks_search_service import ChunksSearchService, ChunkSearchResult +from starter.search.embeddings_gateway import EmbeddingsGateway +from tests.db_test_support import TestDatabaseTemplate +from tests.embeddings_support import embedding_vector, embedding_response + + +class TestChunksSearchService(unittest.TestCase): + def setUp(self): + super().setUp() + self.db = TestDatabaseTemplate() + self.db.clear() + + self.documents_gateway = DocumentsGateway(self.db) + self.chunks_gateway = ChunksGateway(self.db) + self.embeddings_gateway = EmbeddingsGateway(self.db) + ai_client = OpenAIClient(base_url="https://openai.example.com", api_key="some-key", + embeddings_model="text-embedding-3-small", chat_model="gpt-4o") + + self.service = ChunksSearchService(self.embeddings_gateway, self.chunks_gateway, self.documents_gateway, ai_client) + + @responses.activate + def test_search_for_relevant_chunk(self): + responses.add( + responses.POST, + "https://openai.example.com/embeddings", + embedding_response(2), + ) + + document_id_1 = self.documents_gateway.create("https://example.com/1", "some_content_1") + document_id_2 = self.documents_gateway.create("https://example.com/2", "some_content_2") + chunk_id_1 = self.chunks_gateway.create(document_id_1, "some_content_1") + chunk_id_2 = self.chunks_gateway.create(document_id_2, "some_content_2") + self.embeddings_gateway.create(chunk_id_1, embedding_vector(1)) + self.embeddings_gateway.create(chunk_id_2, embedding_vector(2)) + + result = self.service.search_for_relevant_chunk("some query") + + self.assertEqual( + ChunkSearchResult( + content="some_content_2", + source="https://example.com/2", + ), + result, + ) diff --git a/tests/search/test_embeddings_analyzer.py b/tests/search/test_embeddings_analyzer.py index a8802ed..ada3df8 100644 --- a/tests/search/test_embeddings_analyzer.py +++ b/tests/search/test_embeddings_analyzer.py @@ -7,8 +7,9 @@ from starter.documents.documents_gateway import DocumentsGateway from starter.search.embeddings_analyzer import EmbeddingsAnalyzer from starter.search.embeddings_gateway import EmbeddingsGateway +from starter.search.vector_support import vector_to_string from tests.db_test_support import TestDatabaseTemplate -from tests.embeddings_support import embedding_response, vector_to_string, embedding_vector +from tests.embeddings_support import embedding_response, embedding_vector class TestEmbeddingsAnalyzer(unittest.TestCase): @@ -21,17 +22,13 @@ def setUp(self): self.chunks_gateway = ChunksGateway(self.db) embeddings_gateway = EmbeddingsGateway(self.db) ai_client = OpenAIClient(base_url="https://openai.example.com", api_key="some-key", - model="text-embedding-3-small") + embeddings_model="text-embedding-3-small", chat_model="gpt-4o") self.analyzer = EmbeddingsAnalyzer(embeddings_gateway, self.chunks_gateway, ai_client) @responses.activate def test_analyze(self): - responses.add( - responses.POST, - "https://openai.example.com/embeddings", - embedding_response(2), - ) + responses.add(responses.POST, "https://openai.example.com/embeddings", embedding_response(2)) document_id = self.documents_gateway.create("https://example.com", "some_content") chunk_id_1 = self.chunks_gateway.create(document_id, "some_content_1") chunk_id_2 = self.chunks_gateway.create(document_id, "some_content_1") diff --git a/tests/search/test_embeddings_gateway.py b/tests/search/test_embeddings_gateway.py index 9b0925a..09d71c8 100644 --- a/tests/search/test_embeddings_gateway.py +++ b/tests/search/test_embeddings_gateway.py @@ -3,8 +3,9 @@ from starter.documents.chunks_gateway import ChunksGateway from starter.documents.documents_gateway import DocumentsGateway from starter.search.embeddings_gateway import EmbeddingsGateway +from starter.search.vector_support import vector_to_string from tests.db_test_support import TestDatabaseTemplate -from tests.embeddings_support import embedding_vector, vector_to_string +from tests.embeddings_support import embedding_vector class TestEmbeddingsGateway(unittest.TestCase): @@ -41,7 +42,7 @@ def test_unprocessed_chunk_ids(self): self.assertEqual([chunk_id_2], ids) - def find_similar_chunk_id(self): + def test_find_similar_chunk_id(self): document_id = self.documents_gateway.create("https://example.com", "some_content") chunk_id_1 = self.chunks_gateway.create(document_id, "some_content_1") chunk_id_2 = self.chunks_gateway.create(document_id, "some_content_1")