Skip to content

Commit

Permalink
Add embeddings gateway
Browse files Browse the repository at this point in the history
  • Loading branch information
tygern committed Jun 9, 2024
1 parent bda85df commit cda00f5
Show file tree
Hide file tree
Showing 6 changed files with 103 additions and 1 deletion.
2 changes: 1 addition & 1 deletion databases/versions/b843cebc9fc7_initial_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def upgrade() -> None:
(
id uuid primary key default gen_random_uuid(),
chunk_id uuid references chunks (id),
embedding vector(3072) not null,
embedding vector(1536) not null,
created_at timestamp with time zone not null default now()
);
grant all privileges on table embeddings to ai_starter;
Expand Down
Empty file added starter/search/__init__.py
Empty file.
39 changes: 39 additions & 0 deletions starter/search/embeddings_gateway.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from typing import List, Optional
from uuid import UUID

from sqlalchemy import Connection

from starter.database_support.database_template import DatabaseTemplate
from starter.database_support.result_mapping import map_results, map_one_result


class EmbeddingsGateway:
    """Data-access gateway for rows of the ``embeddings`` table.

    Every method sends its SQL through the injected ``DatabaseTemplate`` and
    forwards the optional ``connection`` argument to it unchanged.
    """

    def __init__(self, template: DatabaseTemplate):
        self.template = template

    def create(self, chunk_id: UUID, vector: List[float], connection: Optional[Connection] = None) -> UUID:
        """Insert an embedding for ``chunk_id`` and return the id of the new row.

        Args:
            chunk_id: id of the chunk the embedding belongs to.
            vector: embedding values to store in the ``embedding`` column.
            connection: optional connection handed to the template as-is.
        """
        result = self.template.query(
            "insert into embeddings (chunk_id, embedding) values (:chunk_id, :vector) returning id",
            connection,
            chunk_id=chunk_id,
            vector=vector,
        )

        return map_one_result(result, lambda row: row["id"])

    def unprocessed_chunk_ids(self, connection: Optional[Connection] = None) -> List[UUID]:
        """Return the ids of all chunks that do not have an embedding row yet."""
        # Anti-join: the left join yields a null ``e.id`` exactly for chunks
        # without a matching embedding. The table is referenced unqualified,
        # like every other query in this gateway.
        result = self.template.query("""
            select chunks.id from chunks
            left join embeddings e on chunks.id = e.chunk_id
            where e.id is null""", connection)

        return map_results(result, lambda row: row["id"])

    def find_similar_chunk_id(self, vector: List[float], connection: Optional[Connection] = None) -> UUID:
        """Return the chunk id of the stored embedding closest to ``vector``.

        Rows are ordered by the ``<=>`` distance between the stored embedding
        and ``vector``; only the first row is selected.
        """
        # NOTE(review): relies on the driver binding a Python list in a form
        # the ``<=>`` operator accepts -- confirm against the real database.
        result = self.template.query(
            """select e.chunk_id from embeddings e order by e.embedding <=> :vector limit 1""",
            connection,
            vector=vector,
        )

        return map_one_result(result, lambda row: row["chunk_id"])
11 changes: 11 additions & 0 deletions tests/embeddings_support.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from typing import List


def embedding_vector(one_index: int, dimensions: int = 1536) -> List[float]:
    """Build a one-hot test vector: all zeros except a 1 at ``one_index``.

    Args:
        one_index: position that is set to 1.
        dimensions: length of the vector. The default of 1536 matches the
            ``vector(1536)`` column declared for ``embeddings.embedding``.

    Raises:
        IndexError: if ``one_index`` is outside the vector.
    """
    # Integer elements are kept on purpose: callers render the vector with
    # ``str`` and compare that text, so 0/1 must not become 0.0/1.0.
    vector = [0] * dimensions
    vector[one_index] = 1
    return vector


def vector_to_string(vector: List[float]) -> str:
    """Render ``vector`` as its elements' ``str`` forms joined by commas."""
    rendered = map(str, vector)
    return ",".join(rendered)
Empty file added tests/search/__init__.py
Empty file.
52 changes: 52 additions & 0 deletions tests/search/test_embeddings_gateway.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import unittest

from starter.documents.chunks_gateway import ChunksGateway
from starter.documents.documents_gateway import DocumentsGateway
from starter.search.embeddings_gateway import EmbeddingsGateway
from tests.db_test_support import TestDatabaseTemplate
from tests.embeddings_support import embedding_vector, vector_to_string


class TestEmbeddingsGateway(unittest.TestCase):
    """Database-backed tests for ``EmbeddingsGateway``."""

    def setUp(self) -> None:
        super().setUp()
        # Each test starts from a cleared database.
        self.db = TestDatabaseTemplate()
        self.db.clear()

        self.documents_gateway = DocumentsGateway(self.db)
        self.chunks_gateway = ChunksGateway(self.db)
        self.gateway = EmbeddingsGateway(self.db)

    def test_create(self):
        """``create`` stores the embedding and returns the id of the new row."""
        document_id = self.documents_gateway.create("https://example.com", "some_content")
        chunk_id = self.chunks_gateway.create(document_id, "some_content")
        vector = embedding_vector(0)

        # Named ``embedding_id`` so the builtin ``id`` is not shadowed.
        embedding_id = self.gateway.create(chunk_id, vector)

        result = self.db.query_to_dict("select id, chunk_id, embedding from embeddings")
        # The embedding is compared against its bracketed, comma-separated
        # text form.
        self.assertEqual([{
            "id": embedding_id,
            "chunk_id": chunk_id,
            "embedding": '[' + vector_to_string(vector) + ']',
        }], result)

    def test_unprocessed_chunk_ids(self):
        """Only the chunk without an embedding is reported as unprocessed."""
        document_id = self.documents_gateway.create("https://example.com", "some_content")
        chunk_id_1 = self.chunks_gateway.create(document_id, "some_content_1")
        chunk_id_2 = self.chunks_gateway.create(document_id, "some_content_1")
        self.gateway.create(chunk_id_1, embedding_vector(0))

        ids = self.gateway.unprocessed_chunk_ids()

        self.assertEqual([chunk_id_2], ids)

    # The ``test_`` prefix is required: unittest only collects methods whose
    # names start with ``test``, so without it this case would never run.
    def test_find_similar_chunk_id(self):
        """Searching with a stored vector returns the chunk it belongs to."""
        document_id = self.documents_gateway.create("https://example.com", "some_content")
        chunk_id_1 = self.chunks_gateway.create(document_id, "some_content_1")
        chunk_id_2 = self.chunks_gateway.create(document_id, "some_content_1")
        self.gateway.create(chunk_id_1, embedding_vector(1))
        self.gateway.create(chunk_id_2, embedding_vector(2))

        self.assertEqual(chunk_id_1, self.gateway.find_similar_chunk_id(embedding_vector(1)))
        self.assertEqual(chunk_id_2, self.gateway.find_similar_chunk_id(embedding_vector(2)))

0 comments on commit cda00f5

Please sign in to comment.