Skip to content

Commit

Permalink
feat: move all queries to cached properties
Browse files Browse the repository at this point in the history
  • Loading branch information
kolaente committed Dec 16, 2024
1 parent d49c352 commit cdad4bc
Showing 1 changed file with 99 additions and 57 deletions.
156 changes: 99 additions & 57 deletions projects/pgai/pgai/vectorizer/vectorizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@ class Config:
processing: Processing settings such as batch size and concurrency.
chunking: The chunking strategy.
formatting: Formatting strategy to apply to the chunks.
embedding_batch_schema: The schema where the embedding batches are stored.
embedding_batch_table: The table where the embedding batches are stored.
embedding_batch_chunks_table: The table where the embedding batch chunks are stored.
"""

version: str
Expand All @@ -83,6 +86,9 @@ class Config:
LangChainCharacterTextSplitter | LangChainRecursiveCharacterTextSplitter
) = Field(..., discriminator="implementation")
formatting: PythonTemplate | ChunkValue = Field(..., discriminator="implementation")
embedding_batch_schema: str | None
embedding_batch_table: str | None
embedding_batch_chunks_table: str | None


@dataclass
Expand Down Expand Up @@ -321,6 +327,82 @@ def insert_errors_query(self) -> sql.Composed:
self.errors_table_ident,
)

@cached_property
def fetch_batches_to_process_query(self) -> sql.Composed:
return sql.SQL(
"SELECT openai_batch_id, output_file_id FROM {}.{} WHERE status not in('failed', 'processed', 'prepared')"
).format(
self.vectorizer.config.embedding_batch_schema,
self.vectorizer.config.embedding_batch_table,
)

@cached_property
def update_batch_embedding_query(self) -> sql.Composed:
return sql.SQL(
"UPDATE {}.{} SET status = %s, completed_at = %s, failed_at = %s, output_file_id = %s, errors = %s WHERE openai_batch_id = %s"
).format(
self.vectorizer.config.embedding_batch_schema,
self.vectorizer.config.embedding_batch_table,
)

@cached_property
def update_batch_embedding_status_query(self) -> sql.Composed:
return sql.SQL(
"UPDATE {}.{} SET status = %s WHERE openai_batch_id = %s"
).format(
self.vectorizer.config.embedding_batch_schema,
self.vectorizer.config.embedding_batch_table,
)

@cached_property
def fetch_chunks_for_batch_id_query(self) -> sql.Composed:
return sql.SQL(
"SELECT id, text FROM {}.{} WHERE embedding_batch_id = %s",
).format(
self.vectorizer.config.embedding_batch_schema,
self.vectorizer.config.embedding_batch_chunks_table,
)

@cached_property
def insert_batch_embedding_query(self) -> sql.Composed:
return sql.SQL("""
INSERT INTO {}.{} (
openai_batch_id,
input_file_id,
output_file_id,
status,
errors,
expires_at
) VALUES (
%s,
%s,
%s,
%s,
%s,
%s
)
""").format(
self.vectorizer.config.embedding_batch_schema,
self.vectorizer.config.embedding_batch_table,
)

@cached_property
def insert_batch_embedding_chunks_query(self) -> sql.Composed:
return sql.SQL("""
INSERT INTO {}.{} (
id,
embedding_batch_id,
text
) VALUES (
%s,
%s,
%s
)
""").format(
self.vectorizer.config.embedding_batch_schema,
self.vectorizer.config.embedding_batch_chunks_table,
)

def _pks_placeholders_tuples(self, items_count: int) -> sql.Composed:
"""Generates a comma separated list of tuples with placeholders for the
primary key fields of the source table.
Expand Down Expand Up @@ -511,47 +593,21 @@ async def _do_openai_batch(self, conn: AsyncConnection) -> int:

created_batch, documents = await self._generate_embedding_batch(items)

await conn.execute("""
INSERT INTO ai.embedding_batches (
openai_batch_id,
input_file_id,
output_file_id,
status,
errors,
expires_at
) VALUES (
%(openai_batch_id)s,
%(input_file_id)s,
%(output_file_id)s,
%(status)s,
%(errors)s,
%(expires_at)s
)
""", {
'openai_batch_id': created_batch.id,
'input_file_id': created_batch.input_file_id,
'output_file_id': created_batch.output_file_id,
'status': created_batch.status,
'errors': created_batch.errors,
'expires_at': datetime.fromtimestamp(created_batch.expires_at, timezone.utc),
})
await conn.execute(self.queries.insert_batch_embedding_query, (
created_batch.id,
created_batch.input_file_id,
created_batch.output_file_id,
created_batch.status,
created_batch.errors,
datetime.fromtimestamp(created_batch.expires_at, timezone.utc),
))

for doc in documents:
await conn.execute("""
INSERT INTO ai.embedding_batch_chunks (
id,
embedding_batch_id,
text
) VALUES (
%(id)s,
%(embedding_batch_id)s,
%(text)s
)
""", {
'id': doc['unique_full_chunk_id'],
'embedding_batch_id': created_batch.id,
'text': doc['chunk']
})
await conn.execute(self.queries.insert_batch_embedding_chunks_query, (
doc['unique_full_chunk_id'],
created_batch.id,
doc['chunk']
))

# TODO how to delete submitted entries from the queue?

Expand All @@ -575,20 +631,11 @@ async def _check_and_process_openai_batches(self, conn: AsyncConnection):
conn.cursor() as cursor,
):
client = openai.OpenAI() # TODO how can I get the client? There has to be one created already that I can use?
await cursor.execute("SELECT openai_batch_id, output_file_id FROM ai.embedding_batches WHERE status not in('failed', 'processed', 'prepared')")
await cursor.execute(self.queries.fetch_batches_to_process_query)
for batch_row in await cursor.fetchall():
batch = client.batches.retrieve(batch_row[0])

await conn.execute("""
UPDATE ai.embedding_batches
SET
status = %s,
completed_at = %s,
failed_at = %s,
output_file_id = %s,
errors = %s
WHERE embedding_batches.openai_batch_id = %s
""", (
await conn.execute(self.queries.update_batch_embedding_query, (
batch.status,
datetime.fromtimestamp(batch.completed_at, timezone.utc) if batch.completed_at else None,
datetime.fromtimestamp(batch.failed_at, timezone.utc) if batch.failed_at else None,
Expand All @@ -602,11 +649,7 @@ async def _check_and_process_openai_batches(self, conn: AsyncConnection):
if batch.status == 'completed':
await self._write_embeddings_from_batch(conn, batch, client)

await cursor.execute("""
UPDATE ai.embedding_batches
SET status = %s
WHERE openai_batch_id = %s
""", (
await cursor.execute(self.queries.update_batch_embedding_status_query, (
'processed',
batch_row[0],
))
Expand Down Expand Up @@ -789,10 +832,9 @@ async def _write_embeddings_from_batch(
all_items = []
all_records: list[EmbeddingRecord] = []

# Fetch all chunks from ai.embedding_batch_chunks where the embedding_batch_id is batch.id
async with conn.cursor() as cursor:
await cursor.execute(
"SELECT id, text FROM ai.embedding_batch_chunks WHERE embedding_batch_id = %s",
self.queries.fetch_chunks_for_batch_id_query
(batch.id,)
)
embedding_batch_chunks = {row[0]: row[1] for row in await cursor.fetchall()}
Expand Down

0 comments on commit cdad4bc

Please sign in to comment.