diff --git a/projects/pgai/pgai/vectorizer/vectorizer.py b/projects/pgai/pgai/vectorizer/vectorizer.py
index 7e28de3a..7756860e 100644
--- a/projects/pgai/pgai/vectorizer/vectorizer.py
+++ b/projects/pgai/pgai/vectorizer/vectorizer.py
@@ -74,6 +74,9 @@ class Config:
         processing: Processing settings such as batch size and concurrency.
         chunking: The chunking strategy.
         formatting: Formatting strategy to apply to the chunks.
+        embedding_batch_schema: The schema where the embedding batches are stored.
+        embedding_batch_table: The table where the embedding batches are stored.
+        embedding_batch_chunks_table: The table where the embedding batch chunks are stored.
     """
 
     version: str
@@ -83,6 +86,9 @@ class Config:
         LangChainCharacterTextSplitter | LangChainRecursiveCharacterTextSplitter
     ) = Field(..., discriminator="implementation")
     formatting: PythonTemplate | ChunkValue = Field(..., discriminator="implementation")
+    embedding_batch_schema: str | None
+    embedding_batch_table: str | None
+    embedding_batch_chunks_table: str | None
 
 
 @dataclass
@@ -321,6 +327,82 @@ def insert_errors_query(self) -> sql.Composed:
             self.errors_table_ident,
         )
 
+    @cached_property
+    def fetch_batches_to_process_query(self) -> sql.Composed:  # batches whose status is not failed/processed/prepared
+        return sql.SQL(
+            "SELECT openai_batch_id, output_file_id FROM {}.{} WHERE status not in('failed', 'processed', 'prepared')"
+        ).format(  # names must be sql.Identifier; bare str arguments are not rendered as identifiers
+            sql.Identifier(self.vectorizer.config.embedding_batch_schema),
+            sql.Identifier(self.vectorizer.config.embedding_batch_table),
+        )
+
+    @cached_property
+    def update_batch_embedding_query(self) -> sql.Composed:  # refresh status/timestamps/output file/errors of one batch
+        return sql.SQL(
+            "UPDATE {}.{} SET status = %s, completed_at = %s, failed_at = %s, output_file_id = %s, errors = %s WHERE openai_batch_id = %s"
+        ).format(
+            sql.Identifier(self.vectorizer.config.embedding_batch_schema),
+            sql.Identifier(self.vectorizer.config.embedding_batch_table),
+        )
+
+    @cached_property
+    def update_batch_embedding_status_query(self) -> sql.Composed:  # set only the status of one batch
+        return sql.SQL(
+            "UPDATE {}.{} SET status = %s WHERE openai_batch_id = %s"
+        ).format(
+            sql.Identifier(self.vectorizer.config.embedding_batch_schema),
+            sql.Identifier(self.vectorizer.config.embedding_batch_table),
+        )
+
+    @cached_property
+    def fetch_chunks_for_batch_id_query(self) -> sql.Composed:  # (id, text) of every chunk stored for one batch
+        return sql.SQL(
+            "SELECT id, text FROM {}.{} WHERE embedding_batch_id = %s"
+        ).format(
+            sql.Identifier(self.vectorizer.config.embedding_batch_schema),
+            sql.Identifier(self.vectorizer.config.embedding_batch_chunks_table),
+        )
+
+    @cached_property
+    def insert_batch_embedding_query(self) -> sql.Composed:  # one row per submitted batch; six positional parameters
+        return sql.SQL("""
+            INSERT INTO {}.{} (
+                openai_batch_id,
+                input_file_id,
+                output_file_id,
+                status,
+                errors,
+                expires_at
+            ) VALUES (
+                %s,
+                %s,
+                %s,
+                %s,
+                %s,
+                %s
+            )
+        """).format(
+            sql.Identifier(self.vectorizer.config.embedding_batch_schema),
+            sql.Identifier(self.vectorizer.config.embedding_batch_table),
+        )
+
+    @cached_property
+    def insert_batch_embedding_chunks_query(self) -> sql.Composed:  # one row per chunk; three positional parameters
+        return sql.SQL("""
+            INSERT INTO {}.{} (
+                id,
+                embedding_batch_id,
+                text
+            ) VALUES (
+                %s,
+                %s,
+                %s
+            )
+        """).format(
+            sql.Identifier(self.vectorizer.config.embedding_batch_schema),
+            sql.Identifier(self.vectorizer.config.embedding_batch_chunks_table),
+        )
+
     def _pks_placeholders_tuples(self, items_count: int) -> sql.Composed:
         """Generates a comma separated list of tuples with placeholders for the primary key
         fields of the source table.
@@ -510,47 +592,21 @@ async def _do_openai_batch(self, conn: AsyncConnection) -> int:
 
         created_batch, documents = await self._generate_embedding_batch(items)
 
-        await conn.execute("""
-            INSERT INTO ai.embedding_batches (
-                openai_batch_id,
-                input_file_id,
-                output_file_id,
-                status,
-                errors,
-                expires_at
-            ) VALUES (
-                %(openai_batch_id)s,
-                %(input_file_id)s,
-                %(output_file_id)s,
-                %(status)s,
-                %(errors)s,
-                %(expires_at)s
-            )
-        """, {
-            'openai_batch_id': created_batch.id,
-            'input_file_id': created_batch.input_file_id,
-            'output_file_id': created_batch.output_file_id,
-            'status': created_batch.status,
-            'errors': created_batch.errors,
-            'expires_at': datetime.fromtimestamp(created_batch.expires_at, timezone.utc),
-        })
+        await conn.execute(self.queries.insert_batch_embedding_query, (
+            created_batch.id,
+            created_batch.input_file_id,
+            created_batch.output_file_id,
+            created_batch.status,
+            created_batch.errors,
+            datetime.fromtimestamp(created_batch.expires_at, timezone.utc),
+        ))
 
         for doc in documents:
-            await conn.execute("""
-                INSERT INTO ai.embedding_batch_chunks (
-                    id,
-                    embedding_batch_id,
-                    text
-                ) VALUES (
-                    %(id)s,
-                    %(embedding_batch_id)s,
-                    %(text)s
-                )
-            """, {
-                'id': doc['unique_full_chunk_id'],
-                'embedding_batch_id': created_batch.id,
-                'text': doc['chunk']
-            })
+            await conn.execute(self.queries.insert_batch_embedding_chunks_query, (
+                doc['unique_full_chunk_id'],
+                created_batch.id,
+                doc['chunk']
+            ))
 
         # TODO how to delete submitted entries from the queue?
 
@@ -574,20 +630,11 @@ async def _check_and_process_openai_batches(self, conn: AsyncConnection):
             conn.cursor() as cursor,
         ):
             client = openai.OpenAI() # TODO how can I get the client? There has to be one created already that I can use?
-            await cursor.execute("SELECT openai_batch_id, output_file_id FROM ai.embedding_batches WHERE status not in('failed', 'processed', 'prepared')")
+            await cursor.execute(self.queries.fetch_batches_to_process_query)
             for batch_row in await cursor.fetchall():
                 batch = client.batches.retrieve(batch_row[0])
 
-                await conn.execute("""
-                    UPDATE ai.embedding_batches
-                    SET
-                        status = %s,
-                        completed_at = %s,
-                        failed_at = %s,
-                        output_file_id = %s,
-                        errors = %s
-                    WHERE embedding_batches.openai_batch_id = %s
-                """, (
+                await conn.execute(self.queries.update_batch_embedding_query, (
                     batch.status,
                     datetime.fromtimestamp(batch.completed_at, timezone.utc) if batch.completed_at else None,
                     datetime.fromtimestamp(batch.failed_at, timezone.utc) if batch.failed_at else None,
@@ -601,11 +648,7 @@ async def _check_and_process_openai_batches(self, conn: AsyncConnection):
 
                 if batch.status == 'completed':
                     await self._write_embeddings_from_batch(conn, batch, client)
-                    await cursor.execute("""
-                        UPDATE ai.embedding_batches
-                        SET status = %s
-                        WHERE openai_batch_id = %s
-                    """, (
+                    await cursor.execute(self.queries.update_batch_embedding_status_query, (
                         'processed',
                         batch_row[0],
                     ))
@@ -788,10 +831,9 @@ async def _write_embeddings_from_batch(
         all_items = []
         all_records: list[EmbeddingRecord] = []
 
-        # Fetch all chunks from ai.embedding_batch_chunks where the embedding_batch_id is batch.id
         async with conn.cursor() as cursor:
             await cursor.execute(
-                "SELECT id, text FROM ai.embedding_batch_chunks WHERE embedding_batch_id = %s",
+                self.queries.fetch_chunks_for_batch_id_query,  # comma required: otherwise the next line calls the query object
                 (batch.id,)
             )
             embedding_batch_chunks = {row[0]: row[1] for row in await cursor.fetchall()}