From fff2086977802f5114b22b9605379181ad66c0ca Mon Sep 17 00:00:00 2001 From: kolaente Date: Thu, 5 Dec 2024 14:51:51 +0100 Subject: [PATCH 01/23] feat: create embedding batches using OpenAI's batch api --- projects/pgai/pgai/vectorizer/embeddings.py | 97 +++++++++++++++++ projects/pgai/pgai/vectorizer/vectorizer.py | 113 +++++++++++++++++++- 2 files changed, 209 insertions(+), 1 deletion(-) diff --git a/projects/pgai/pgai/vectorizer/embeddings.py b/projects/pgai/pgai/vectorizer/embeddings.py index c7ba5aeeb..e2adb328b 100644 --- a/projects/pgai/pgai/vectorizer/embeddings.py +++ b/projects/pgai/pgai/vectorizer/embeddings.py @@ -2,6 +2,8 @@ import os import re import time +import json +import tempfile from abc import ABC, abstractmethod from collections.abc import Awaitable, Callable, Mapping, Sequence from dataclasses import dataclass @@ -12,6 +14,7 @@ Literal, TypeAlias, TypeVar, + Optional, ) import ollama @@ -36,6 +39,52 @@ logger = structlog.get_logger() +@dataclass +class OpenAIBatch: + """ + Represents detailed information about a batch process. + + Attributes: + id (str): The unique identifier of the batch. + endpoint (str): The OpenAI API endpoint used by the batch. + errors (object): Any errors associated with the batch. + input_file_id (str): The ID of the input file for the batch. + completion_window (str): The time frame within which the batch should be processed. + status (str): The current status of the batch. + output_file_id (str): The ID of the file containing successful request outputs. + error_file_id (str): The ID of the file containing error request outputs. + created_at (int): The Unix timestamp for when the batch was created. + in_progress_at (Optional[int]): The Unix timestamp for when the batch started processing. + expires_at (int): The Unix timestamp for when the batch will expire. + finalizing_at (Optional[int]): The Unix timestamp for when the batch started finalizing. + completed_at (Optional[int]): The Unix timestamp for when the batch was completed. + failed_at (Optional[int]): The Unix timestamp for when the batch failed. + expired_at (Optional[int]): The Unix timestamp for when the batch expired. + cancelling_at (Optional[int]): The Unix timestamp for when the batch started cancelling. + cancelled_at (Optional[int]): The Unix timestamp for when the batch was cancelled. + metadata (Mapping[str, str]): A map of metadata associated with the batch. + """ + + id: str + endpoint: str + errors: object + input_file_id: str + completion_window: str + status: str + output_file_id: str + error_file_id: str + created_at: int + expires_at: int + metadata: Mapping[str, str] + in_progress_at: Optional[int] = None + finalizing_at: Optional[int] = None + completed_at: Optional[int] = None + failed_at: Optional[int] = None + expired_at: Optional[int] = None + cancelling_at: Optional[int] = None + cancelled_at: Optional[int] = None + + @dataclass class ChunkEmbeddingError: """ @@ -405,6 +454,54 @@ async def embed( model_token_length, encoded_documents ) + async def create_and_submit_embedding_batch( + self, documents: list[dict[str, Any]] + ) -> OpenAIBatch: + """ + Creates a batch of embeddings using OpenAI's embeddings API as outlined in + https://platform.openai.com/docs/guides/batch/batch-api?lang=python + + Args: + documents (list[str]): A list of document chunks to be embedded. + + Returns: + + """ + + temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.jsonl', mode='rw') + + for document in documents: + entry = { + 'custom_id': document['id'] + '_' + document['chunk_seq'], + 'method': 'POST', + 'url': '/v1/embeddings', + 'body': { + 'model': 'text-embedding-3-large', # TODO how can I use the configured embeddings model? + 'input': document['chunk'], + }, + } + temp_file.write(json.dumps(entry) + '\n') + + temp_file.close() + + client = openai.OpenAI() # TODO there has to be a client already which I could use instead? + + batch_input_file = client.files.create( + file=open(temp_file.name, "rb"), + purpose="batch", + ) + + batch = client.batches.create( + input_file_id=batch_input_file.id, + endpoint="/v1/chat/completions", + completion_window="24h", + metadata={ + "description": "nightly eval job" + } + ) + + return OpenAIBatch(**batch.to_dict()) + async def _filter_by_length_and_embed( self, model_token_length: int, encoded_documents: list[list[int]] ) -> Sequence[EmbeddingVector | ChunkEmbeddingError]: diff --git a/projects/pgai/pgai/vectorizer/vectorizer.py b/projects/pgai/pgai/vectorizer/vectorizer.py index d7ea0e5e6..a2ea260e1 100644 --- a/projects/pgai/pgai/vectorizer/vectorizer.py +++ b/projects/pgai/pgai/vectorizer/vectorizer.py @@ -22,7 +22,7 @@ LangChainCharacterTextSplitter, LangChainRecursiveCharacterTextSplitter, ) -from .embeddings import ChunkEmbeddingError, Ollama, OpenAI, VoyageAI +from .embeddings import ChunkEmbeddingError, Ollama, OpenAI, VoyageAI, OpenAIBatch from .formatting import ChunkValue, PythonTemplate from .processing import ProcessingDefault @@ -465,6 +465,8 @@ async def run(self) -> int: await self.vectorizer.config.embedding.setup() while True: if not self._continue_processing(loops, res): + # TODO how can we run this only after hitting the rate limit of the normal openai batch embedding api? + self._do_openai_batch(conn) return res items_processed = await self._do_batch(conn) if items_processed == 0: @@ -472,6 +474,96 @@ async def run(self) -> int: res += items_processed loops += 1 + @tracer.wrap() + async def _do_openai_batch(self, conn: AsyncConnection) -> int: + """ + Creates embeddings using openai's batch processing api. This allows to process + very large amounts of data faster than with the embeddings api, because the + batch api has vastly higher rate limits. + + Args: + conn (AsyncConnection): The asynchronous database connection. + """ + + # TODO do nothing when openai is not configured + + try: + async with conn.transaction(): + items = await self._fetch_work(conn) + + await logger.adebug(f"Items pulled from queue for openai batch embedding: {len(items)}") + + # Filter out items that were deleted from the source table. + # We use the first primary key column, since they can only + # be null if the LEFT JOIN didn't find a match. + items = [ + i + for i in items + if i[self.vectorizer.source_pk[0].attname] is not None + ] + + if len(items) == 0: + return 0 + + created_batch = await self._generate_embedding_batch(items) + + # TODO this does not feel like the way to go, is there a way to do these kind of migrations properly? + await conn.execute(""" + CREATE TABLE IF NOT EXISTS embedding_batches ( + id BIGSERIAL PRIMARY KEY, + openai_batch_id VARCHAR(255), + input_file_id VARCHAR(255) NOT NULL, + output_file_id VARCHAR(255), + status VARCHAR(255) NOT NULL, + errors JSONB, + created_at TIMESTAMP(0) NOT NULL DEFAULT NOW(), + expires_at TIMESTAMP(0), + completed_at TIMESTAMP(0), + failed_at TIMESTAMP(0) + ); + CREATE INDEX IF NOT EXISTS embedding_batches_status_index + ON embedding_batches (status); + """) + + await conn.execute(""" + INSERT INTO embedding_batches ( + openai_batch_id, + input_file_id, + output_file_id, + status, + errors, + expires_at, + completed_at + ) VALUES ( + %(openai_batch_id)s, + %(input_file_id)s, + %(output_file_id)s, + %(status)s, + %(errors)s, + %(expires_at)s + ) + """, { + 'openai_batch_id': created_batch.id, + 'input_file_id': created_batch.input_file_id, + 'output_file_id': created_batch.output_file_id, + 'status': created_batch.status, + 'errors': created_batch.errors, + 'expires_at': created_batch.expires_at, + }) + + return len(items) + except Exception as e: + async with conn.transaction(): + await self._insert_vectorizer_error( + conn, + ( + self.vectorizer.id, + VECTORIZER_FAILED, + Jsonb({"error_reason": str(e)}), + ), + ) + raise e + @tracer.wrap() async def _do_batch(self, conn: AsyncConnection) -> int: """ @@ -731,6 +823,25 @@ async def _generate_embeddings( records.append(record + [np.array(embedding)]) return records, errors + async def _generate_embedding_batch( + self, items: list[SourceRow] + ) -> OpenAIBatch: + documents: list[dict[str, Any]] = [] + for item in items: + chunks = self.vectorizer.config.chunking.into_chunks(item) + for chunk_id, chunk in enumerate(chunks, 0): + formatted = self.vectorizer.config.formatting.format(chunk, item) + documents.append({ + 'id': item['id'], + 'chunk_id': chunk_id, + 'chunk': formatted, + }) + + try: + return await self.vectorizer.config.embedding.create_and_submit_embedding_batch(documents) + except Exception as e: + raise EmbeddingProviderError() from e + def _vectorizer_error_record( self, record: EmbeddingRecord, chunk_error: ChunkEmbeddingError ) -> VectorizerErrorRecord: From 9335e66db5a9d32e28b72ee179b40555b0f19f1c Mon Sep 17 00:00:00 2001 From: kolaente Date: Thu, 5 Dec 2024 17:37:17 +0100 Subject: [PATCH 02/23] feat: process batch embeddings submitted to openai --- projects/pgai/pgai/vectorizer/embeddings.py | 2 +- projects/pgai/pgai/vectorizer/vectorizer.py | 104 +++++++++++++++++++- 2 files changed, 103 insertions(+), 3 deletions(-) diff --git a/projects/pgai/pgai/vectorizer/embeddings.py b/projects/pgai/pgai/vectorizer/embeddings.py index e2adb328b..20354434b 100644 --- a/projects/pgai/pgai/vectorizer/embeddings.py +++ b/projects/pgai/pgai/vectorizer/embeddings.py @@ -472,7 +472,7 @@ async def create_and_submit_embedding_batch( for document in documents: entry = { - 'custom_id': document['id'] + '_' + document['chunk_seq'], + 'custom_id': document['pk'] + '_' + document['id'] + '_' + document['chunk_seq'], 'method': 'POST', 'url': '/v1/embeddings', 'body': { diff --git a/projects/pgai/pgai/vectorizer/vectorizer.py b/projects/pgai/pgai/vectorizer/vectorizer.py index a2ea260e1..418c30c92 100644 --- a/projects/pgai/pgai/vectorizer/vectorizer.py +++ b/projects/pgai/pgai/vectorizer/vectorizer.py @@ -1,11 +1,12 @@ import asyncio +import json import os import threading import time from collections.abc import Callable from functools import cached_property from itertools import repeat -from typing import Any, TypeAlias +from typing import Any, TypeAlias, Dict import numpy as np import psycopg @@ -17,6 +18,7 @@ from psycopg.types.json import Jsonb from pydantic.dataclasses import dataclass from pydantic.fields import Field +from openai import OpenAI from .chunking import ( LangChainCharacterTextSplitter, @@ -465,8 +467,9 @@ async def run(self) -> int: await self.vectorizer.config.embedding.setup() while True: if not self._continue_processing(loops, res): + await self._check_and_process_openai_batches(conn) # TODO how can we run this only after hitting the rate limit of the normal openai batch embedding api? - self._do_openai_batch(conn) + await self._do_openai_batch(conn) return res items_processed = await self._do_batch(conn) if items_processed == 0: @@ -564,6 +567,45 @@ async def _do_openai_batch(self, conn: AsyncConnection) -> int: ) raise e + @tracer.wrap() + async def _check_and_process_openai_batches(self, conn: AsyncConnection): + async with ( + conn.transaction(), + conn.cursor() as cursor, + ): + client = OpenAI() # TODO how can I get the client? There has to be one created already that I can use? + await conn.execute("SELECT openai_batch_id, output_file_id FROM embedding_batches WHERE status not in('failed', 'processed', 'prepared')") + for batch_row in cursor.fetchall(): + batch = client.batches.retrieve(batch_row['openai_batch_id']) + + await conn.execute(""" + UPDATE embedding_batches + SET status = %s, completed_at = %s, failed_at = %s, + output_file_id = %s, errors = %s + WHERE id = %s + """, ( + batch['status'], + batch.get('completed_at'), + batch.get('failed_at'), + batch.get('output_file_id'), + Jsonb(batch.get('errors')), + batch_row['openai_batch_id'], + )) + + # batch has been processed successfully in openai, that means we can + # collect the results and store them in the database. + if batch['status'] == "completed": + await self._embed_and_write_from_batch(conn, batch, client) + + await cursor.execute(""" + UPDATE embedding_batches + SET status = %s + WHERE id = %s + """, ( + "processed", + batch_row['openai_batch_id'], + )) + @tracer.wrap() async def _do_batch(self, conn: AsyncConnection) -> int: """ @@ -715,6 +757,63 @@ async def _embed_and_write(self, conn: AsyncConnection, items: list[SourceRow]): return len(records) + @tracer.wrap() + async def _embed_and_write_from_batch( + self, + conn: AsyncConnection, + batch: Dict[str, Any], + client: OpenAI, + ): + """ + Embeds the items and writes them to the database. + + - Deletes existing embeddings for the items. + - Generates the documents to be embedded, chunks them, and formats the chunks. + - Sends the documents to the embedding provider and writes embeddings + to the database. + - Logs any non-fatal errors encountered during embedding. + + Args: + conn (AsyncConnection): The database connection. + batch: The batch as retrieved from OpenAI's api. + client: The OpenAI client to use. + + Returns: + int: The number of records written to the database. + """ + batch_file = client.files.content(batch['output_file_id']).text + + batch_data = batch_file.text.strip().split('\n') + num_records = 0 + document_chunks: Dict[int, Dict[int, str]] = {} # outer key is the document id, inner key is the chunk id, content is the chunk + all_items = [] + records: list[EmbeddingRecord] = [] + for line in batch_data: + json_line = json.loads(line) + if "custom_id" in json_line and "response" in json_line: + + custom_id = json_line['custom_id'] + pk, document_id, chunk_seq = custom_id.split('_') + embedding_data = json_line['response']['body']['data'] + + item = {pk: document_id} + all_items.append(item) + + if document_id not in document_chunks: + document_chunks[document_id] = {} + chunks = self.vectorizer.config.chunking.into_chunks(item) + + for chunk_id, chunk in enumerate(chunks, 0): + formatted = self.vectorizer.config.formatting.format(chunk, item) + document_chunks[document_id][chunk_id] = formatted + + records.append(pk + [chunk_seq, document_chunks[document_id][chunk_seq]] + [np.array(embedding_data)]) + + await self._delete_embeddings(conn, all_items) + await self._copy_embeddings(conn, records) + + return num_records + async def _delete_embeddings(self, conn: AsyncConnection, items: list[SourceRow]): """ Deletes the embeddings for the given items from the target table. @@ -832,6 +931,7 @@ async def _generate_embedding_batch( for chunk_id, chunk in enumerate(chunks, 0): formatted = self.vectorizer.config.formatting.format(chunk, item) documents.append({ + 'pk': self._get_item_pk_values(item), 'id': item['id'], 'chunk_id': chunk_id, 'chunk': formatted, From 70d7b7da2b7131d1965eb3a336c32f51e00b7da9 Mon Sep 17 00:00:00 2001 From: kolaente Date: Thu, 5 Dec 2024 20:18:52 +0100 Subject: [PATCH 03/23] fix: only open temp file for writing --- projects/pgai/pgai/vectorizer/embeddings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/projects/pgai/pgai/vectorizer/embeddings.py b/projects/pgai/pgai/vectorizer/embeddings.py index 20354434b..9d5dc7058 100644 --- a/projects/pgai/pgai/vectorizer/embeddings.py +++ b/projects/pgai/pgai/vectorizer/embeddings.py @@ -468,7 +468,7 @@ async def create_and_submit_embedding_batch( """ - temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.jsonl', mode='rw') + temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.jsonl', mode='w') for document in documents: entry = { From a10b0fd0ea8a3712e32dbd1c749b0261e13de4bd Mon Sep 17 00:00:00 2001 From: kolaente Date: Thu, 5 Dec 2024 20:19:48 +0100 Subject: [PATCH 04/23] chore: move table creation to separate function --- projects/pgai/pgai/vectorizer/vectorizer.py | 48 +++++++++++++-------- 1 file changed, 29 insertions(+), 19 deletions(-) diff --git a/projects/pgai/pgai/vectorizer/vectorizer.py b/projects/pgai/pgai/vectorizer/vectorizer.py index 418c30c92..6b34cd66d 100644 --- a/projects/pgai/pgai/vectorizer/vectorizer.py +++ b/projects/pgai/pgai/vectorizer/vectorizer.py @@ -477,6 +477,34 @@ async def run(self) -> int: res += items_processed loops += 1 + async def _create_batch_table(self, conn: AsyncConnection): + # TODO this does not feel like the way to go, is there a way to do these kind of migrations properly? + await conn.execute(""" +CREATE TABLE IF NOT EXISTS ai.embedding_batches +( + openai_batch_id VARCHAR(255) PRIMARY KEY, + input_file_id VARCHAR(255) NOT NULL, + output_file_id VARCHAR(255), + status VARCHAR(255) NOT NULL, + errors JSONB, + created_at TIMESTAMP(0) NOT NULL DEFAULT NOW(), + expires_at TIMESTAMP(0), + completed_at TIMESTAMP(0), + failed_at TIMESTAMP(0) +); + """) + await conn.execute(""" + CREATE INDEX IF NOT EXISTS embedding_batches_status_index ON ai.embedding_batches (status); + """) + return await conn.execute(""" +CREATE TABLE IF NOT EXISTS ai.embedding_batch_chunks +( + id VARCHAR(255) PRIMARY KEY, + embedding_batch_id VARCHAR(255) REFERENCES ai.embedding_batches (openai_batch_id), + text TEXT +); + """) + @tracer.wrap() async def _do_openai_batch(self, conn: AsyncConnection) -> int: """ @@ -508,25 +536,7 @@ async def _do_openai_batch(self, conn: AsyncConnection) -> int: if len(items) == 0: return 0 - created_batch = await self._generate_embedding_batch(items) - - # TODO this does not feel like the way to go, is there a way to do these kind of migrations properly? - await conn.execute(""" - CREATE TABLE IF NOT EXISTS embedding_batches ( - id BIGSERIAL PRIMARY KEY, - openai_batch_id VARCHAR(255), - input_file_id VARCHAR(255) NOT NULL, - output_file_id VARCHAR(255), - status VARCHAR(255) NOT NULL, - errors JSONB, - created_at TIMESTAMP(0) NOT NULL DEFAULT NOW(), - expires_at TIMESTAMP(0), - completed_at TIMESTAMP(0), - failed_at TIMESTAMP(0) - ); - CREATE INDEX IF NOT EXISTS embedding_batches_status_index - ON embedding_batches (status); - """) + created_batch, documents = await self._generate_embedding_batch(items) await conn.execute(""" INSERT INTO embedding_batches ( From b1be455530a7b640346dee3e0a2cfd01ace66ba7 Mon Sep 17 00:00:00 2001 From: kolaente Date: Thu, 5 Dec 2024 20:21:35 +0100 Subject: [PATCH 05/23] chore: use OpenAI's batch type --- projects/pgai/pgai/vectorizer/embeddings.py | 56 ++------------------- projects/pgai/pgai/vectorizer/vectorizer.py | 10 ++-- 2 files changed, 9 insertions(+), 57 deletions(-) diff --git a/projects/pgai/pgai/vectorizer/embeddings.py b/projects/pgai/pgai/vectorizer/embeddings.py index 9d5dc7058..87ef3ef62 100644 --- a/projects/pgai/pgai/vectorizer/embeddings.py +++ b/projects/pgai/pgai/vectorizer/embeddings.py @@ -14,7 +14,6 @@ Literal, TypeAlias, TypeVar, - Optional, ) import ollama @@ -39,52 +38,6 @@ logger = structlog.get_logger() -@dataclass -class OpenAIBatch: - """ - Represents detailed information about a batch process. - - Attributes: - id (str): The unique identifier of the batch. - endpoint (str): The OpenAI API endpoint used by the batch. - errors (object): Any errors associated with the batch. - input_file_id (str): The ID of the input file for the batch. - completion_window (str): The time frame within which the batch should be processed. - status (str): The current status of the batch. - output_file_id (str): The ID of the file containing successful request outputs. - error_file_id (str): The ID of the file containing error request outputs. - created_at (int): The Unix timestamp for when the batch was created. - in_progress_at (Optional[int]): The Unix timestamp for when the batch started processing. - expires_at (int): The Unix timestamp for when the batch will expire. - finalizing_at (Optional[int]): The Unix timestamp for when the batch started finalizing. - completed_at (Optional[int]): The Unix timestamp for when the batch was completed. - failed_at (Optional[int]): The Unix timestamp for when the batch failed. - expired_at (Optional[int]): The Unix timestamp for when the batch expired. - cancelling_at (Optional[int]): The Unix timestamp for when the batch started cancelling. - cancelled_at (Optional[int]): The Unix timestamp for when the batch was cancelled. - metadata (Mapping[str, str]): A map of metadata associated with the batch. - """ - - id: str - endpoint: str - errors: object - input_file_id: str - completion_window: str - status: str - output_file_id: str - error_file_id: str - created_at: int - expires_at: int - metadata: Mapping[str, str] - in_progress_at: Optional[int] = None - finalizing_at: Optional[int] = None - completed_at: Optional[int] = None - failed_at: Optional[int] = None - expired_at: Optional[int] = None - cancelling_at: Optional[int] = None - cancelled_at: Optional[int] = None - - @dataclass class ChunkEmbeddingError: """ @@ -455,8 +408,9 @@ async def embed( ) async def create_and_submit_embedding_batch( - self, documents: list[dict[str, Any]] - ) -> OpenAIBatch: + self, + documents: list[dict[str, Any]], + ) -> openai.types.Batch: """ Creates a batch of embeddings using OpenAI's embeddings API as outlined in https://platform.openai.com/docs/guides/batch/batch-api?lang=python @@ -491,7 +445,7 @@ async def create_and_submit_embedding_batch( purpose="batch", ) - batch = client.batches.create( + return client.batches.create( input_file_id=batch_input_file.id, endpoint="/v1/chat/completions", completion_window="24h", @@ -500,8 +454,6 @@ async def create_and_submit_embedding_batch( } ) - return OpenAIBatch(**batch.to_dict()) - async def _filter_by_length_and_embed( self, model_token_length: int, encoded_documents: list[list[int]] ) -> Sequence[EmbeddingVector | ChunkEmbeddingError]: diff --git a/projects/pgai/pgai/vectorizer/vectorizer.py b/projects/pgai/pgai/vectorizer/vectorizer.py index 6b34cd66d..025e911eb 100644 --- a/projects/pgai/pgai/vectorizer/vectorizer.py +++ b/projects/pgai/pgai/vectorizer/vectorizer.py @@ -18,13 +18,13 @@ from psycopg.types.json import Jsonb from pydantic.dataclasses import dataclass from pydantic.fields import Field -from openai import OpenAI +import openai from .chunking import ( LangChainCharacterTextSplitter, LangChainRecursiveCharacterTextSplitter, ) -from .embeddings import ChunkEmbeddingError, Ollama, OpenAI, VoyageAI, OpenAIBatch +from .embeddings import ChunkEmbeddingError, Ollama, OpenAI, VoyageAI from .formatting import ChunkValue, PythonTemplate from .processing import ProcessingDefault @@ -771,7 +771,7 @@ async def _embed_and_write(self, conn: AsyncConnection, items: list[SourceRow]): async def _embed_and_write_from_batch( self, conn: AsyncConnection, - batch: Dict[str, Any], + batch: openai.types.Batch, client: OpenAI, ): """ @@ -791,7 +791,7 @@ async def _embed_and_write_from_batch( Returns: int: The number of records written to the database. """ - batch_file = client.files.content(batch['output_file_id']).text + batch_file = client.files.content(batch.output_file_id) batch_data = batch_file.text.strip().split('\n') num_records = 0 @@ -934,7 +934,7 @@ async def _generate_embeddings( async def _generate_embedding_batch( self, items: list[SourceRow] - ) -> OpenAIBatch: + ) -> openai.types.Batch: documents: list[dict[str, Any]] = [] for item in items: chunks = self.vectorizer.config.chunking.into_chunks(item) From b0e6a08a25ee28430d7cbdb7d4753d68b191238c Mon Sep 17 00:00:00 2001 From: kolaente Date: Thu, 5 Dec 2024 20:23:16 +0100 Subject: [PATCH 06/23] feat: generate full chunk id earlier --- projects/pgai/pgai/vectorizer/embeddings.py | 2 +- projects/pgai/pgai/vectorizer/vectorizer.py | 10 +++++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/projects/pgai/pgai/vectorizer/embeddings.py b/projects/pgai/pgai/vectorizer/embeddings.py index 87ef3ef62..db6e3635d 100644 --- a/projects/pgai/pgai/vectorizer/embeddings.py +++ b/projects/pgai/pgai/vectorizer/embeddings.py @@ -426,7 +426,7 @@ async def create_and_submit_embedding_batch( for document in documents: entry = { - 'custom_id': document['pk'] + '_' + document['id'] + '_' + document['chunk_seq'], + 'custom_id': document['unique_full_chunk_id'], 'method': 'POST', 'url': '/v1/embeddings', 'body': { diff --git a/projects/pgai/pgai/vectorizer/vectorizer.py b/projects/pgai/pgai/vectorizer/vectorizer.py index 025e911eb..f77f30826 100644 --- a/projects/pgai/pgai/vectorizer/vectorizer.py +++ b/projects/pgai/pgai/vectorizer/vectorizer.py @@ -937,13 +937,17 @@ async def _generate_embedding_batch( ) -> openai.types.Batch: documents: list[dict[str, Any]] = [] for item in items: + pk = self._get_item_pk_values(item) chunks = self.vectorizer.config.chunking.into_chunks(item) for chunk_id, chunk in enumerate(chunks, 0): formatted = self.vectorizer.config.formatting.format(chunk, item) + unique_full_chunk_id = [ + ','.join(self.queries.pk_attnames), + ','.join(map(str, pk)), + str(chunk_id), + ] documents.append({ - 'pk': self._get_item_pk_values(item), - 'id': item['id'], - 'chunk_id': chunk_id, + 'unique_full_chunk_id': ':::'.join(unique_full_chunk_id), 'chunk': formatted, }) From 9c3e01745235e4aaea0f25bf5013862f1309f055 Mon Sep 17 00:00:00 2001 From: kolaente Date: Thu, 5 Dec 2024 20:23:45 +0100 Subject: [PATCH 07/23] fix: correctly use embeddings endpoint --- projects/pgai/pgai/vectorizer/embeddings.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/projects/pgai/pgai/vectorizer/embeddings.py b/projects/pgai/pgai/vectorizer/embeddings.py index db6e3635d..a00a1424d 100644 --- a/projects/pgai/pgai/vectorizer/embeddings.py +++ b/projects/pgai/pgai/vectorizer/embeddings.py @@ -430,7 +430,7 @@ async def create_and_submit_embedding_batch( 'method': 'POST', 'url': '/v1/embeddings', 'body': { - 'model': 'text-embedding-3-large', # TODO how can I use the configured embeddings model? + 'model': 'text-embedding-3-small', # TODO how can I use the configured embeddings model? 'input': document['chunk'], }, } @@ -447,11 +447,8 @@ async def create_and_submit_embedding_batch( return client.batches.create( input_file_id=batch_input_file.id, - endpoint="/v1/chat/completions", - completion_window="24h", - metadata={ - "description": "nightly eval job" - } + endpoint='/v1/embeddings', + completion_window='24h', ) async def _filter_by_length_and_embed( From 866e04129357ff23f1788469e3443796e6250d2e Mon Sep 17 00:00:00 2001 From: kolaente Date: Thu, 5 Dec 2024 20:24:21 +0100 Subject: [PATCH 08/23] fix: properly convert time --- projects/pgai/pgai/vectorizer/vectorizer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/projects/pgai/pgai/vectorizer/vectorizer.py b/projects/pgai/pgai/vectorizer/vectorizer.py index f77f30826..558b8049f 100644 --- a/projects/pgai/pgai/vectorizer/vectorizer.py +++ b/projects/pgai/pgai/vectorizer/vectorizer.py @@ -4,6 +4,7 @@ import threading import time from collections.abc import Callable +from datetime import datetime, timezone from functools import cached_property from itertools import repeat from typing import Any, TypeAlias, Dict @@ -539,14 +540,13 @@ async def _do_openai_batch(self, conn: AsyncConnection) -> int: created_batch, documents = await self._generate_embedding_batch(items) await conn.execute(""" - INSERT INTO embedding_batches ( + INSERT INTO ai.embedding_batches ( openai_batch_id, input_file_id, output_file_id, status, errors, - expires_at, - completed_at + expires_at ) VALUES ( %(openai_batch_id)s, %(input_file_id)s, @@ -561,7 +561,7 @@ async def _do_openai_batch(self, conn: AsyncConnection) -> int: 'output_file_id': created_batch.output_file_id, 'status': created_batch.status, 'errors': created_batch.errors, - 'expires_at': created_batch.expires_at, + 'expires_at': datetime.fromtimestamp(created_batch.expires_at, timezone.utc), }) return len(items) From 2e57a35a1a7210d7d8188b2419d1f8070db9e651 Mon Sep 17 00:00:00 2001 From: kolaente Date: Thu, 5 Dec 2024 20:24:44 +0100 Subject: [PATCH 09/23] feat: insert all chunks into the db after batch creation --- projects/pgai/pgai/vectorizer/vectorizer.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/projects/pgai/pgai/vectorizer/vectorizer.py b/projects/pgai/pgai/vectorizer/vectorizer.py index 558b8049f..a548253e1 100644 --- a/projects/pgai/pgai/vectorizer/vectorizer.py +++ b/projects/pgai/pgai/vectorizer/vectorizer.py @@ -564,6 +564,25 @@ async def _do_openai_batch(self, conn: AsyncConnection) -> int: 'expires_at': datetime.fromtimestamp(created_batch.expires_at, timezone.utc), }) + for doc in documents: + await conn.execute(""" + INSERT INTO ai.embedding_batch_chunks ( + id, + embedding_batch_id, + text + ) VALUES ( + %(id)s, + %(embedding_batch_id)s, + %(text)s + ) + """, { + 'id': doc['unique_full_chunk_id'], + 'embedding_batch_id': created_batch.id, + 'text': doc['chunk'] + }) + + # TODO how to delete submitted entries from the queue? + return len(items) except Exception as e: async with conn.transaction(): From 5ef0491965ee069b04edd3b6730a62898350413e Mon Sep 17 00:00:00 2001 From: kolaente Date: Thu, 5 Dec 2024 20:25:50 +0100 Subject: [PATCH 10/23] fix: correctly process batches --- projects/pgai/pgai/vectorizer/vectorizer.py | 84 +++++++++++---------- 1 file changed, 46 insertions(+), 38 deletions(-) diff --git a/projects/pgai/pgai/vectorizer/vectorizer.py b/projects/pgai/pgai/vectorizer/vectorizer.py index a548253e1..a7a21d44e 100644 --- a/projects/pgai/pgai/vectorizer/vectorizer.py +++ b/projects/pgai/pgai/vectorizer/vectorizer.py @@ -602,37 +602,41 @@ async def _check_and_process_openai_batches(self, conn: AsyncConnection): conn.transaction(), conn.cursor() as cursor, ): - client = OpenAI() # TODO how can I get the client? There has to be one created already that I can use? - await conn.execute("SELECT openai_batch_id, output_file_id FROM embedding_batches WHERE status not in('failed', 'processed', 'prepared')") - for batch_row in cursor.fetchall(): - batch = client.batches.retrieve(batch_row['openai_batch_id']) + client = openai.OpenAI() # TODO how can I get the client? There has to be one created already that I can use? + await cursor.execute("SELECT openai_batch_id, output_file_id FROM ai.embedding_batches WHERE status not in('failed', 'processed', 'prepared')") + for batch_row in await cursor.fetchall(): + batch = client.batches.retrieve(batch_row[0]) await conn.execute(""" - UPDATE embedding_batches - SET status = %s, completed_at = %s, failed_at = %s, - output_file_id = %s, errors = %s - WHERE id = %s + UPDATE ai.embedding_batches + SET + status = %s, + completed_at = %s, + failed_at = %s, + output_file_id = %s, + errors = %s + WHERE embedding_batches.openai_batch_id = %s """, ( - batch['status'], - batch.get('completed_at'), - batch.get('failed_at'), - batch.get('output_file_id'), - Jsonb(batch.get('errors')), - batch_row['openai_batch_id'], + batch.status, + datetime.fromtimestamp(batch.completed_at, timezone.utc) if batch.completed_at else None, + datetime.fromtimestamp(batch.failed_at, timezone.utc) if batch.failed_at else None, + batch.output_file_id, + Jsonb(batch.errors), + batch_row[0], )) # batch has been processed successfully in openai, that means we can # collect the results and store them in the database. - if batch['status'] == "completed": + if batch.status == 'completed': await self._embed_and_write_from_batch(conn, batch, client) await cursor.execute(""" - UPDATE embedding_batches + UPDATE ai.embedding_batches SET status = %s - WHERE id = %s + WHERE openai_batch_id = %s """, ( - "processed", - batch_row['openai_batch_id'], + 'processed', + batch_row[0], )) @tracer.wrap() @@ -814,32 +818,36 @@ async def _embed_and_write_from_batch( batch_data = batch_file.text.strip().split('\n') num_records = 0 - document_chunks: Dict[int, Dict[int, str]] = {} # outer key is the document id, inner key is the chunk id, content is the chunk all_items = [] - records: list[EmbeddingRecord] = [] - for line in batch_data: - json_line = json.loads(line) - if "custom_id" in json_line and "response" in json_line: - - custom_id = json_line['custom_id'] - pk, document_id, chunk_seq = custom_id.split('_') - embedding_data = json_line['response']['body']['data'] + all_records: list[EmbeddingRecord] = [] - item = {pk: document_id} - all_items.append(item) + # Fetch all chunks from ai.embedding_batch_chunks where the embedding_batch_id is batch.id + async with conn.cursor() as cursor: + await cursor.execute( + "SELECT id, text FROM ai.embedding_batch_chunks WHERE embedding_batch_id = %s", + (batch.id,) + ) + embedding_batch_chunks = {row[0]: row[1] for row in await cursor.fetchall()} + + for line in batch_data: + json_line = json.loads(line) + if "custom_id" in json_line and "response" in json_line: - if document_id not in document_chunks: - document_chunks[document_id] = {} - chunks = self.vectorizer.config.chunking.into_chunks(item) + custom_id = json_line['custom_id'] + pk_names, document_id, chunk_seq = custom_id.split(':::') + embedding_data = json_line['response']['body']['data'][0]['embedding'] - for chunk_id, chunk in enumerate(chunks, 0): - formatted = self.vectorizer.config.formatting.format(chunk, item) - document_chunks[document_id][chunk_id] = formatted + resolved_id = document_id.split(',') + resolved_pk = pk_names.split(',') + item = {pk: id_value for pk, id_value in zip(resolved_pk, resolved_id)} + item[self.vectorizer.config.chunking.chunk_column] = embedding_batch_chunks[custom_id] - records.append(pk + [chunk_seq, document_chunks[document_id][chunk_seq]] + [np.array(embedding_data)]) + all_items.append(item) + all_records.append([resolved_id + [chunk_seq, embedding_batch_chunks[custom_id]] + [np.array(embedding_data)]]) await self._delete_embeddings(conn, all_items) - await self._copy_embeddings(conn, records) + for records in all_records: + await self._copy_embeddings(conn, records) return num_records From fa1010886795fcf07c58c97a12973ce352f45b51 Mon Sep 17 00:00:00 2001 From: kolaente Date: Thu, 5 Dec 2024 20:27:18 +0100 Subject: [PATCH 11/23] fix: return documents --- projects/pgai/pgai/vectorizer/vectorizer.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/projects/pgai/pgai/vectorizer/vectorizer.py b/projects/pgai/pgai/vectorizer/vectorizer.py index a7a21d44e..5c19f06e8 100644 --- a/projects/pgai/pgai/vectorizer/vectorizer.py +++ b/projects/pgai/pgai/vectorizer/vectorizer.py @@ -961,7 +961,7 @@ async def _generate_embeddings( async def _generate_embedding_batch( self, items: list[SourceRow] - ) -> openai.types.Batch: + ) -> tuple[openai.types.Batch, list[dict[str, Any]]]: documents: list[dict[str, Any]] = [] for item in items: pk = self._get_item_pk_values(item) @@ -979,7 +979,8 @@ async def _generate_embedding_batch( }) try: - return await self.vectorizer.config.embedding.create_and_submit_embedding_batch(documents) + batch = await self.vectorizer.config.embedding.create_and_submit_embedding_batch(documents) + return batch, documents except Exception as e: raise EmbeddingProviderError() from e From 783a62f01121eca102014077ea9cee6019a495f5 Mon Sep 17 00:00:00 2001 From: kolaente Date: Mon, 9 Dec 2024 12:54:25 +0100 Subject: [PATCH 12/23] fix: use configured embeddings model --- projects/pgai/pgai/vectorizer/embeddings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/projects/pgai/pgai/vectorizer/embeddings.py b/projects/pgai/pgai/vectorizer/embeddings.py index a00a1424d..ed002779f 100644 --- a/projects/pgai/pgai/vectorizer/embeddings.py +++ b/projects/pgai/pgai/vectorizer/embeddings.py @@ -430,7 +430,7 @@ async def create_and_submit_embedding_batch( 'method': 'POST', 'url': '/v1/embeddings', 'body': { - 'model': 'text-embedding-3-small', # TODO how can I use the configured embeddings model? + 'model': self.model, 'input': document['chunk'], }, } From 1be82acff8204616aafcda6a9fc628f83b51a5bc Mon Sep 17 00:00:00 2001 From: kolaente Date: Mon, 9 Dec 2024 12:59:51 +0100 Subject: [PATCH 13/23] chore: rename write embeddings function --- projects/pgai/pgai/vectorizer/vectorizer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/projects/pgai/pgai/vectorizer/vectorizer.py b/projects/pgai/pgai/vectorizer/vectorizer.py index 5c19f06e8..6255ae106 100644 --- a/projects/pgai/pgai/vectorizer/vectorizer.py +++ b/projects/pgai/pgai/vectorizer/vectorizer.py @@ -628,7 +628,7 @@ async def _check_and_process_openai_batches(self, conn: AsyncConnection): # batch has been processed successfully in openai, that means we can # collect the results and store them in the database. if batch.status == 'completed': - await self._embed_and_write_from_batch(conn, batch, client) + await self._write_embeddings_from_batch(conn, batch, client) await cursor.execute(""" UPDATE ai.embedding_batches @@ -791,7 +791,7 @@ async def _embed_and_write(self, conn: AsyncConnection, items: list[SourceRow]): return len(records) @tracer.wrap() - async def _embed_and_write_from_batch( + async def _write_embeddings_from_batch( self, conn: AsyncConnection, batch: openai.types.Batch, From f7f6d13befc6f3455a1bed6293f707e1ed6b04fa Mon Sep 17 00:00:00 2001 From: kolaente Date: Mon, 9 Dec 2024 13:01:44 +0100 Subject: [PATCH 14/23] chore: adjust function comment --- projects/pgai/pgai/vectorizer/vectorizer.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/projects/pgai/pgai/vectorizer/vectorizer.py b/projects/pgai/pgai/vectorizer/vectorizer.py index 6255ae106..0ba2fea0b 100644 --- a/projects/pgai/pgai/vectorizer/vectorizer.py +++ b/projects/pgai/pgai/vectorizer/vectorizer.py @@ -798,21 +798,17 @@ async def _write_embeddings_from_batch( client: OpenAI, ): """ - Embeds the items and writes them to the database. + Writes embeddings from an OpenAI batch embedding to the database. - Deletes existing embeddings for the items. - - Generates the documents to be embedded, chunks them, and formats the chunks. - - Sends the documents to the embedding provider and writes embeddings - to the database. + - Loads created embeddings from the batch. + - Writes created embeddings to the database. - Logs any non-fatal errors encountered during embedding. Args: conn (AsyncConnection): The database connection. batch: The batch as retrieved from OpenAI's api. client: The OpenAI client to use. - - Returns: - int: The number of records written to the database. """ batch_file = client.files.content(batch.output_file_id) From d49c3524772958d21ea31faf2d047fa3e306cb04 Mon Sep 17 00:00:00 2001 From: kolaente Date: Mon, 9 Dec 2024 13:10:40 +0100 Subject: [PATCH 15/23] feat: create batch embedding tables in extension --- .../sql/idempotent/013-vectorizer-api.sql | 34 ++++++- .../sql/idempotent/016-openai-batch-api.sql | 96 +++++++++++++++++++ projects/pgai/pgai/vectorizer/vectorizer.py | 28 ------ 3 files changed, 128 insertions(+), 30 deletions(-) create mode 100644 projects/extension/sql/idempotent/016-openai-batch-api.sql diff --git a/projects/extension/sql/idempotent/013-vectorizer-api.sql b/projects/extension/sql/idempotent/013-vectorizer-api.sql index b1fa34dd4..498fd159e 100644 --- a/projects/extension/sql/idempotent/013-vectorizer-api.sql +++ b/projects/extension/sql/idempotent/013-vectorizer-api.sql @@ -1,5 +1,3 @@ - - ------------------------------------------------------------------------------- -- execute_vectorizer create or replace function ai.execute_vectorizer(vectorizer_id pg_catalog.int4) returns void @@ -31,6 +29,9 @@ create or replace function ai.create_vectorizer , queue_table pg_catalog.name default null , grant_to pg_catalog.name[] default ai.grant_to() , enqueue_existing pg_catalog.bool default true +, embedding_batch_schema pg_catalog.name default null +, embedding_batch_table pg_catalog.name default null +, embedding_batch_chunks_table pg_catalog.name default null ) returns pg_catalog.int4 as $func$ declare @@ -44,6 +45,7 @@ declare _vectorizer_id pg_catalog.int4; _sql pg_catalog.text; _job_id pg_catalog.int8; + _implementation pg_catalog.text; begin -- make sure all the roles listed in grant_to exist if grant_to is not null then @@ -225,6 +227,31 @@ begin scheduling = pg_catalog.jsonb_insert(scheduling, array['job_id'], pg_catalog.to_jsonb(_job_id)); end if; + embedding_batch_schema = coalesce(embedding_batch_schema, 'ai'); + embedding_batch_table = coalesce(embedding_batch_table, pg_catalog.concat('_vectorizer_embedding_batches_', _vectorizer_id)); + embedding_batch_chunks_table = coalesce(embedding_batch_chunks_table, pg_catalog.concat('_vectorizer_embedding_batch_chunks_', _vectorizer_id)); + + -- create batch embedding tables + select (embedding operator (pg_catalog.->> 'implementation'))::text into _implementation; + if _implementation = 'openai' then + -- make sure embedding batch table name is available + if pg_catalog.to_regclass(pg_catalog.format('%I.%I', embedding_batch_schema, embedding_batch_table)) is not null then + raise exception 'an object named %.% already exists. specify an alternate embedding_batch_table explicitly', queue_schema, queue_table; + end if; + + -- make sure embedding batch chunks table name is available + if pg_catalog.to_regclass(pg_catalog.format('%I.%I', embedding_batch_schema, embedding_batch_chunks_table)) is not null then + raise exception 'an object named %.% already exists. specify an alternate embedding_batch_chunks_table explicitly', queue_schema, queue_table; + end if; + + perform ai._vectorizer_create_embedding_batches_table + (embedding_batch_schema + , embedding_batch_table + , embedding_batch_chunks_table + , grant_to + ); + end if; + insert into ai.vectorizer ( id , source_schema @@ -259,6 +286,9 @@ begin , 'formatting', formatting , 'scheduling', scheduling , 'processing', processing + , 'embedding_batch_schema', embedding_batch_schema + , 'embedding_batch_table', embedding_batch_table + , 'embedding_batch_chunks_table', embedding_batch_chunks_table ) ); diff --git a/projects/extension/sql/idempotent/016-openai-batch-api.sql b/projects/extension/sql/idempotent/016-openai-batch-api.sql new file mode 100644 index 000000000..03f5a0211 --- /dev/null +++ b/projects/extension/sql/idempotent/016-openai-batch-api.sql @@ -0,0 +1,96 @@ +------------------------------------------------------------------------------- +-- _vectorizer_create_queue_table +create or replace function ai._vectorizer_create_embedding_batches_table +( embedding_batch_schema name +, embedding_batch_table name +, embedding_batch_chunks_table name +, grant_to name[] +) returns void as +$func$ +declare + _sql text; +begin + -- create the batches table + select pg_catalog.format + ( $sql$create table %I.%I( + openai_batch_id VARCHAR(255) PRIMARY KEY, + input_file_id VARCHAR(255) NOT NULL, + output_file_id VARCHAR(255), + status VARCHAR(255) NOT NULL, + errors JSONB, + created_at TIMESTAMP(0) NOT NULL DEFAULT NOW(), + expires_at TIMESTAMP(0), + completed_at TIMESTAMP(0), + failed_at TIMESTAMP(0) +))$sql$ + , embedding_batch_schema + , embedding_batch_table + ) into strict _sql + ; + execute _sql; + + -- create the index + select pg_catalog.format + ( $sql$create index on %I.%I (status)$sql$ + , embedding_batch_schema, embedding_batch_table + ) into strict _sql + ; + execute _sql; + + -- create the batch chunks table + select pg_catalog.format + ( $sql$create table %I.%I( + id VARCHAR(255) PRIMARY KEY, + embedding_batch_id VARCHAR(255) REFERENCES %I.%I (openai_batch_id), + text TEXT +))$sql$ + , embedding_batch_schema + , embedding_batch_chunks_table + , embedding_batch_schema + , embedding_batch_table + ) into strict _sql + ; + execute _sql; + + if grant_to is not null then + -- grant usage on queue schema to grant_to roles + select pg_catalog.format + ( $sql$grant usage on schema %I to %s$sql$ + , embedding_batch_schema + , ( + select pg_catalog.string_agg(pg_catalog.quote_ident(x), ', ') + from pg_catalog.unnest(grant_to) x + ) + ) into strict _sql; + execute _sql; + + -- grant select, update, delete on batches table to grant_to roles + select pg_catalog.format + ( $sql$grant select, insert, update, delete on %I.%I to %s$sql$ + , embedding_batch_schema + , embedding_batch_table + , ( + select pg_catalog.string_agg(pg_catalog.quote_ident(x), ', ') + from pg_catalog.unnest(grant_to) x + ) + ) into strict _sql; + execute _sql; + + -- grant select, update, delete on batch chunks table to grant_to roles + select pg_catalog.format + ( $sql$grant select, insert, update, delete on %I.%I to %s$sql$ + , embedding_batch_schema + , embedding_batch_chunks_table + , ( + select pg_catalog.string_agg(pg_catalog.quote_ident(x), ', ') + from pg_catalog.unnest(grant_to) x + ) + ) into strict _sql; + execute _sql; + end if; +end; +$func$ + language plpgsql volatile security invoker + set search_path to pg_catalog, pg_temp +; + diff --git a/projects/pgai/pgai/vectorizer/vectorizer.py b/projects/pgai/pgai/vectorizer/vectorizer.py index 0ba2fea0b..2935186af 100644 --- a/projects/pgai/pgai/vectorizer/vectorizer.py +++ b/projects/pgai/pgai/vectorizer/vectorizer.py @@ -478,34 +478,6 @@ async def run(self) -> int: res += items_processed loops += 1 - async def _create_batch_table(self, conn: AsyncConnection): - # TODO this does not feel like the way to go, is there a way to do these kind of migrations properly? - await conn.execute(""" -CREATE TABLE IF NOT EXISTS ai.embedding_batches -( - openai_batch_id VARCHAR(255) PRIMARY KEY, - input_file_id VARCHAR(255) NOT NULL, - output_file_id VARCHAR(255), - status VARCHAR(255) NOT NULL, - errors JSONB, - created_at TIMESTAMP(0) NOT NULL DEFAULT NOW(), - expires_at TIMESTAMP(0), - completed_at TIMESTAMP(0), - failed_at TIMESTAMP(0) -); - """) - await conn.execute(""" - CREATE INDEX IF NOT EXISTS embedding_batches_status_index ON ai.embedding_batches (status); - """) - return await conn.execute(""" -CREATE TABLE IF NOT EXISTS ai.embedding_batch_chunks -( - id VARCHAR(255) PRIMARY KEY, - embedding_batch_id VARCHAR(255) REFERENCES ai.embedding_batches (openai_batch_id), - text TEXT -); - """) - @tracer.wrap() async def _do_openai_batch(self, conn: AsyncConnection) -> int: """ From cdad4bca1171de077cc9ed2a76ed3dadb5bdbf74 Mon Sep 17 00:00:00 2001 From: kolaente Date: Mon, 9 Dec 2024 14:36:49 +0100 Subject: [PATCH 16/23] feat: move all queries to cached properties --- projects/pgai/pgai/vectorizer/vectorizer.py | 156 +++++++++++++------- 1 file changed, 99 insertions(+), 57 deletions(-) diff --git a/projects/pgai/pgai/vectorizer/vectorizer.py b/projects/pgai/pgai/vectorizer/vectorizer.py index 2935186af..1844bb4ca 100644 --- a/projects/pgai/pgai/vectorizer/vectorizer.py +++ b/projects/pgai/pgai/vectorizer/vectorizer.py @@ -74,6 +74,9 @@ class Config: processing: Processing settings such as batch size and concurrency. chunking: The chunking strategy. formatting: Formatting strategy to apply to the chunks. + embedding_batch_schema: The schema where the embedding batches are stored. + embedding_batch_table: The table where the embedding batches are stored. + embedding_batch_chunks_table: The table where the embedding batch chunks are stored. """ version: str @@ -83,6 +86,9 @@ class Config: LangChainCharacterTextSplitter | LangChainRecursiveCharacterTextSplitter ) = Field(..., discriminator="implementation") formatting: PythonTemplate | ChunkValue = Field(..., discriminator="implementation") + embedding_batch_schema: str | None + embedding_batch_table: str | None + embedding_batch_chunks_table: str | None @dataclass @@ -321,6 +327,82 @@ def insert_errors_query(self) -> sql.Composed: self.errors_table_ident, ) + @cached_property + def fetch_batches_to_process_query(self) -> sql.Composed: + return sql.SQL( + "SELECT openai_batch_id, output_file_id FROM {}.{} WHERE status not in('failed', 'processed', 'prepared')" + ).format( + self.vectorizer.config.embedding_batch_schema, + self.vectorizer.config.embedding_batch_table, + ) + + @cached_property + def update_batch_embedding_query(self) -> sql.Composed: + return sql.SQL( + "UPDATE {}.{} SET status = %s, completed_at = %s, failed_at = %s, output_file_id = %s, errors = %s WHERE openai_batch_id = %s" + ).format( + self.vectorizer.config.embedding_batch_schema, + self.vectorizer.config.embedding_batch_table, + ) + + @cached_property + def update_batch_embedding_status_query(self) -> sql.Composed: + return sql.SQL( + "UPDATE {}.{} SET status = %s WHERE openai_batch_id = %s" + ).format( + self.vectorizer.config.embedding_batch_schema, + self.vectorizer.config.embedding_batch_table, + ) + + @cached_property + def fetch_chunks_for_batch_id_query(self) -> sql.Composed: + return sql.SQL( + "SELECT id, text FROM {}.{} WHERE embedding_batch_id = %s", + ).format( + self.vectorizer.config.embedding_batch_schema, + self.vectorizer.config.embedding_batch_chunks_table, + ) + + @cached_property + def insert_batch_embedding_query(self) -> sql.Composed: + return sql.SQL(""" + INSERT INTO {}.{} ( + openai_batch_id, + input_file_id, + output_file_id, + status, + errors, + expires_at + ) VALUES ( + %s, + %s, + %s, + %s, + %s, + %s + ) + """).format( + self.vectorizer.config.embedding_batch_schema, + self.vectorizer.config.embedding_batch_table, + ) + + @cached_property + def insert_batch_embedding_chunks_query(self) -> sql.Composed: + return sql.SQL(""" + INSERT INTO {}.{} ( + id, + embedding_batch_id, + text + ) VALUES ( + %s, + %s, + %s + ) + """).format( + self.vectorizer.config.embedding_batch_schema, + self.vectorizer.config.embedding_batch_chunks_table, + ) + def _pks_placeholders_tuples(self, items_count: int) -> sql.Composed: """Generates a comma separated list of tuples with placeholders for the primary key fields of the source table. @@ -511,47 +593,21 @@ async def _do_openai_batch(self, conn: AsyncConnection) -> int: created_batch, documents = await self._generate_embedding_batch(items) - await conn.execute(""" - INSERT INTO ai.embedding_batches ( - openai_batch_id, - input_file_id, - output_file_id, - status, - errors, - expires_at - ) VALUES ( - %(openai_batch_id)s, - %(input_file_id)s, - %(output_file_id)s, - %(status)s, - %(errors)s, - %(expires_at)s - ) - """, { - 'openai_batch_id': created_batch.id, - 'input_file_id': created_batch.input_file_id, - 'output_file_id': created_batch.output_file_id, - 'status': created_batch.status, - 'errors': created_batch.errors, - 'expires_at': datetime.fromtimestamp(created_batch.expires_at, timezone.utc), - }) + await conn.execute(self.queries.insert_batch_embedding_query, ( + created_batch.id, + created_batch.input_file_id, + created_batch.output_file_id, + created_batch.status, + created_batch.errors, + datetime.fromtimestamp(created_batch.expires_at, timezone.utc), + )) for doc in documents: - await conn.execute(""" - INSERT INTO ai.embedding_batch_chunks ( - id, - embedding_batch_id, - text - ) VALUES ( - %(id)s, - %(embedding_batch_id)s, - %(text)s - ) - """, { - 'id': doc['unique_full_chunk_id'], - 'embedding_batch_id': created_batch.id, - 'text': doc['chunk'] - }) + await conn.execute(self.queries.insert_batch_embedding_chunks_query, ( + doc['unique_full_chunk_id'], + created_batch.id, + doc['chunk'] + )) # TODO how to delete submitted entries from the queue? @@ -575,20 +631,11 @@ async def _check_and_process_openai_batches(self, conn: AsyncConnection): conn.cursor() as cursor, ): client = openai.OpenAI() # TODO how can I get the client? There has to be one created already that I can use? - await cursor.execute("SELECT openai_batch_id, output_file_id FROM ai.embedding_batches WHERE status not in('failed', 'processed', 'prepared')") + await cursor.execute(self.queries.fetch_batches_to_process_query) for batch_row in await cursor.fetchall(): batch = client.batches.retrieve(batch_row[0]) - await conn.execute(""" - UPDATE ai.embedding_batches - SET - status = %s, - completed_at = %s, - failed_at = %s, - output_file_id = %s, - errors = %s - WHERE embedding_batches.openai_batch_id = %s - """, ( + await conn.execute(self.queries.update_batch_embedding_query, ( batch.status, datetime.fromtimestamp(batch.completed_at, timezone.utc) if batch.completed_at else None, datetime.fromtimestamp(batch.failed_at, timezone.utc) if batch.failed_at else None, @@ -602,11 +649,7 @@ async def _check_and_process_openai_batches(self, conn: AsyncConnection): if batch.status == 'completed': await self._write_embeddings_from_batch(conn, batch, client) - await cursor.execute(""" - UPDATE ai.embedding_batches - SET status = %s - WHERE openai_batch_id = %s - """, ( + await cursor.execute(self.queries.update_batch_embedding_status_query, ( 'processed', batch_row[0], )) @@ -789,10 +832,9 @@ async def _write_embeddings_from_batch( all_items = [] all_records: list[EmbeddingRecord] = [] - # Fetch all chunks from ai.embedding_batch_chunks where the embedding_batch_id is batch.id async with conn.cursor() as cursor: await cursor.execute( - "SELECT id, text FROM ai.embedding_batch_chunks WHERE embedding_batch_id = %s", + self.queries.fetch_chunks_for_batch_id_query (batch.id,) ) embedding_batch_chunks = {row[0]: row[1] for row in await cursor.fetchall()} From ba6e17951233a09d3973831636f58978c40c6373 Mon Sep 17 00:00:00 2001 From: kolaente Date: Wed, 18 Dec 2024 18:45:24 +0100 Subject: [PATCH 17/23] fix: move batch embedding changes to openai embedder --- .../pgai/pgai/vectorizer/embedders/openai.py | 46 +++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/projects/pgai/pgai/vectorizer/embedders/openai.py b/projects/pgai/pgai/vectorizer/embedders/openai.py index 7c4e9aa5c..077b877a7 100644 --- a/projects/pgai/pgai/vectorizer/embedders/openai.py +++ b/projects/pgai/pgai/vectorizer/embedders/openai.py @@ -1,4 +1,6 @@ +import json import re +import tempfile from collections.abc import Sequence from functools import cached_property from typing import Any, Literal @@ -129,6 +131,50 @@ async def embed( model_token_length, encoded_documents ) + async def create_and_submit_embedding_batch( + self, + documents: list[dict[str, Any]], + ) -> openai.types.Batch: + """ + Creates a batch of embeddings using OpenAI's embeddings API as outlined in + https://platform.openai.com/docs/guides/batch/batch-api?lang=python + + Args: + documents (list[str]): A list of document chunks to be embedded. + + Returns: + + """ + + temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.jsonl', mode='w') + + for document in documents: + entry = { + 'custom_id': document['unique_full_chunk_id'], + 'method': 'POST', + 'url': '/v1/embeddings', + 'body': { + 'model': self.model, + 'input': document['chunk'], + }, + } + temp_file.write(json.dumps(entry) + '\n') + + temp_file.close() + + client = openai.OpenAI() # TODO there has to be a client already which I could use instead? + + batch_input_file = client.files.create( + file=open(temp_file.name, "rb"), + purpose="batch", + ) + + return client.batches.create( + input_file_id=batch_input_file.id, + endpoint='/v1/embeddings', + completion_window='24h', + ) + async def _filter_by_length_and_embed( self, model_token_length: int, encoded_documents: list[list[int]] ) -> Sequence[EmbeddingVector | ChunkEmbeddingError]: From 9b4f3c358596d83fe402bb0dafed469afe64c6c1 Mon Sep 17 00:00:00 2001 From: kolaente Date: Wed, 18 Dec 2024 19:05:55 +0100 Subject: [PATCH 18/23] fix: lint issues --- .../pgai/pgai/vectorizer/embedders/openai.py | 42 +++++----- projects/pgai/pgai/vectorizer/embeddings.py | 2 - projects/pgai/pgai/vectorizer/vectorizer.py | 84 +++++++++++-------- 3 files changed, 72 insertions(+), 56 deletions(-) diff --git a/projects/pgai/pgai/vectorizer/embedders/openai.py b/projects/pgai/pgai/vectorizer/embedders/openai.py index 077b877a7..abb05c41a 100644 --- a/projects/pgai/pgai/vectorizer/embedders/openai.py +++ b/projects/pgai/pgai/vectorizer/embedders/openai.py @@ -146,33 +146,33 @@ async def create_and_submit_embedding_batch( """ - temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.jsonl', mode='w') - - for document in documents: - entry = { - 'custom_id': document['unique_full_chunk_id'], - 'method': 'POST', - 'url': '/v1/embeddings', - 'body': { - 'model': self.model, - 'input': document['chunk'], - }, - } - temp_file.write(json.dumps(entry) + '\n') - - temp_file.close() + with tempfile.NamedTemporaryFile(delete=False, suffix=".jsonl", mode="w") as temp_file: + for document in documents: + entry = { + "custom_id": document["unique_full_chunk_id"], + "method": "POST", + "url": "/v1/embeddings", + "body": { + "model": self.model, + "input": document["chunk"], + }, + } + temp_file.write(json.dumps(entry) + "\n") + + temp_file.close() client = openai.OpenAI() # TODO there has to be a client already which I could use instead? - batch_input_file = client.files.create( - file=open(temp_file.name, "rb"), - purpose="batch", - ) + with open(temp_file.name, "rb") as file: + batch_input_file = client.files.create( + file=file, + purpose="batch", + ) return client.batches.create( input_file_id=batch_input_file.id, - endpoint='/v1/embeddings', - completion_window='24h', + endpoint="/v1/embeddings", + completion_window="24h", ) async def _filter_by_length_and_embed( diff --git a/projects/pgai/pgai/vectorizer/embeddings.py b/projects/pgai/pgai/vectorizer/embeddings.py index 3f537d32f..a40f91f02 100644 --- a/projects/pgai/pgai/vectorizer/embeddings.py +++ b/projects/pgai/pgai/vectorizer/embeddings.py @@ -1,7 +1,5 @@ import math import time -import json -import tempfile from abc import ABC, abstractmethod from collections.abc import Awaitable, Callable, Sequence from dataclasses import dataclass diff --git a/projects/pgai/pgai/vectorizer/vectorizer.py b/projects/pgai/pgai/vectorizer/vectorizer.py index 0df4ae798..6f6aa36a3 100644 --- a/projects/pgai/pgai/vectorizer/vectorizer.py +++ b/projects/pgai/pgai/vectorizer/vectorizer.py @@ -7,9 +7,10 @@ from datetime import datetime, timezone from functools import cached_property from itertools import repeat -from typing import Any, TypeAlias, Dict +from typing import Any, TypeAlias import numpy as np +import openai import psycopg import structlog from ddtrace import tracer @@ -19,7 +20,6 @@ from psycopg.types.json import Jsonb from pydantic.dataclasses import dataclass from pydantic.fields import Field -import openai from .chunking import ( LangChainCharacterTextSplitter, @@ -330,18 +330,25 @@ def insert_errors_query(self) -> sql.Composed: @cached_property def fetch_batches_to_process_query(self) -> sql.Composed: - return sql.SQL( - "SELECT openai_batch_id, output_file_id FROM {}.{} WHERE status not in('failed', 'processed', 'prepared')" - ).format( + return sql.SQL(""" + SELECT openai_batch_id, output_file_id FROM {}.{} + WHERE status not in('failed', 'processed', 'prepared') + """).format( self.vectorizer.config.embedding_batch_schema, self.vectorizer.config.embedding_batch_table, ) @cached_property def update_batch_embedding_query(self) -> sql.Composed: - return sql.SQL( - "UPDATE {}.{} SET status = %s, completed_at = %s, failed_at = %s, output_file_id = %s, errors = %s WHERE openai_batch_id = %s" - ).format( + return sql.SQL(""" + UPDATE {}.{} SET + status = %s + completed_at = %s, + failed_at = %s, + output_file_id = %s, + errors = %s + WHERE openai_batch_id = %s + """).format( self.vectorizer.config.embedding_batch_schema, self.vectorizer.config.embedding_batch_table, ) @@ -604,11 +611,13 @@ async def _do_openai_batch(self, conn: AsyncConnection) -> int: )) for doc in documents: - await conn.execute(self.queries.insert_batch_embedding_chunks_query, ( - doc['unique_full_chunk_id'], - created_batch.id, - doc['chunk'] - )) + await conn.execute( + self.queries.insert_batch_embedding_chunks_query, + ( + doc["unique_full_chunk_id"], + created_batch.id, + doc["chunk"] + )) # TODO how to delete submitted entries from the queue? @@ -631,15 +640,18 @@ async def _check_and_process_openai_batches(self, conn: AsyncConnection): conn.transaction(), conn.cursor() as cursor, ): - client = openai.OpenAI() # TODO how can I get the client? There has to be one created already that I can use? + # TODO how can I get the client? There has to be one created already that I can use? + client = openai.OpenAI() await cursor.execute(self.queries.fetch_batches_to_process_query) for batch_row in await cursor.fetchall(): batch = client.batches.retrieve(batch_row[0]) await conn.execute(self.queries.update_batch_embedding_query, ( batch.status, - datetime.fromtimestamp(batch.completed_at, timezone.utc) if batch.completed_at else None, - datetime.fromtimestamp(batch.failed_at, timezone.utc) if batch.failed_at else None, + datetime.fromtimestamp(batch.completed_at, timezone.utc) + if batch.completed_at else None, + datetime.fromtimestamp(batch.failed_at, timezone.utc) + if batch.failed_at else None, batch.output_file_id, Jsonb(batch.errors), batch_row[0], @@ -647,13 +659,15 @@ async def _check_and_process_openai_batches(self, conn: AsyncConnection): # batch has been processed successfully in openai, that means we can # collect the results and store them in the database. - if batch.status == 'completed': + if batch.status == "completed": await self._write_embeddings_from_batch(conn, batch, client) - await cursor.execute(self.queries.update_batch_embedding_status_query, ( - 'processed', - batch_row[0], - )) + await cursor.execute( + self.queries.update_batch_embedding_status_query, + ( + "processed", + batch_row[0], + )) @tracer.wrap() async def _do_batch(self, conn: AsyncConnection) -> int: @@ -828,7 +842,7 @@ async def _write_embeddings_from_batch( """ batch_file = client.files.content(batch.output_file_id) - batch_data = batch_file.text.strip().split('\n') + batch_data = batch_file.text.strip().split("\n") num_records = 0 all_items = [] all_records: list[EmbeddingRecord] = [] @@ -844,17 +858,21 @@ async def _write_embeddings_from_batch( json_line = json.loads(line) if "custom_id" in json_line and "response" in json_line: - custom_id = json_line['custom_id'] - pk_names, document_id, chunk_seq = custom_id.split(':::') - embedding_data = json_line['response']['body']['data'][0]['embedding'] + custom_id = json_line["custom_id"] + pk_names, document_id, chunk_seq = custom_id.split(":::") + embedding_data = json_line["response"]["body"]["data"][0]["embedding"] - resolved_id = document_id.split(',') - resolved_pk = pk_names.split(',') - item = {pk: id_value for pk, id_value in zip(resolved_pk, resolved_id)} + resolved_id = document_id.split(",") + resolved_pk = pk_names.split(",") + item = {pk: id_value + for pk, id_value in zip(resolved_pk, resolved_id, strict=False)} item[self.vectorizer.config.chunking.chunk_column] = embedding_batch_chunks[custom_id] all_items.append(item) - all_records.append([resolved_id + [chunk_seq, embedding_batch_chunks[custom_id]] + [np.array(embedding_data)]]) + all_records.append([ + resolved_id + + [chunk_seq, embedding_batch_chunks[custom_id]] + + [np.array(embedding_data)]]) await self._delete_embeddings(conn, all_items) for records in all_records: @@ -980,13 +998,13 @@ async def _generate_embedding_batch( for chunk_id, chunk in enumerate(chunks, 0): formatted = self.vectorizer.config.formatting.format(chunk, item) unique_full_chunk_id = [ - ','.join(self.queries.pk_attnames), - ','.join(map(str, pk)), + ",".join(self.queries.pk_attnames), + ",".join(map(str, pk)), str(chunk_id), ] documents.append({ - 'unique_full_chunk_id': ':::'.join(unique_full_chunk_id), - 'chunk': formatted, + "unique_full_chunk_id": ":::".join(unique_full_chunk_id), + "chunk": formatted, }) try: From 9af30c1f62688e287fd575a4caa5c61c2f61c743 Mon Sep 17 00:00:00 2001 From: kolaente Date: Thu, 19 Dec 2024 15:11:31 +0100 Subject: [PATCH 19/23] fix: move batch embedding tables creation to embedding functions --- .../sql/idempotent/008-embedding.sql | 32 +++++++++++++++++++ .../sql/idempotent/013-vectorizer-api.sql | 24 ++------------ .../pgai/pgai/vectorizer/embedders/openai.py | 8 +++++ projects/pgai/pgai/vectorizer/vectorizer.py | 30 +++++++---------- 4 files changed, 54 insertions(+), 40 deletions(-) diff --git a/projects/extension/sql/idempotent/008-embedding.sql b/projects/extension/sql/idempotent/008-embedding.sql index e03e21ed5..90703e2e6 100644 --- a/projects/extension/sql/idempotent/008-embedding.sql +++ b/projects/extension/sql/idempotent/008-embedding.sql @@ -6,8 +6,20 @@ create or replace function ai.embedding_openai , dimensions pg_catalog.int4 , chat_user pg_catalog.text default null , api_key_name pg_catalog.text default 'OPENAI_API_KEY' +, use_batch_api pg_catalog.bool default false +, embedding_batch_schema pg_catalog.name default null +, embedding_batch_table pg_catalog.name default null +, embedding_batch_chunks_table pg_catalog.name default null ) returns pg_catalog.jsonb as $func$ +declare + _vectorizer_id pg_catalog.int4; +begin + _vectorizer_id = pg_catalog.nextval('ai.vectorizer_id_seq'::pg_catalog.regclass); + embedding_batch_schema = coalesce(embedding_batch_schema, 'ai'); + embedding_batch_table = coalesce(embedding_batch_table, pg_catalog.concat('_vectorizer_embedding_batches_', _vectorizer_id)); + embedding_batch_chunks_table = coalesce(embedding_batch_chunks_table, pg_catalog.concat('_vectorizer_embedding_batch_chunks_', _vectorizer_id)); + select json_object ( 'implementation': 'openai' , 'config_type': 'embedding' @@ -15,6 +27,10 @@ as $func$ , 'dimensions': dimensions , 'user': chat_user , 'api_key_name': api_key_name + , 'use_batch_api': use_batch_api + , 'embedding_batch_schema': embedding_batch_schema + , 'embedding_batch_table': embedding_batch_table + , 'embedding_batch_chunks_table': embedding_batch_chunks_table absent on null ) $func$ language sql immutable security invoker @@ -81,6 +97,9 @@ as $func$ declare _config_type pg_catalog.text; _implementation pg_catalog.text; + _embedding_batch_schema pg_catalog.text; + _embedding_batch_table pg_catalog.text; + _embedding_batch_chunks_table pg_catalog.text; begin if pg_catalog.jsonb_typeof(config) operator(pg_catalog.!=) 'object' then raise exception 'embedding config is not a jsonb object'; @@ -93,6 +112,19 @@ begin _implementation = config operator(pg_catalog.->>) 'implementation'; case _implementation when 'openai' then + -- make sure embedding batch table name is available + select (config operator (pg_catalog.->> 'embedding_batch_schema'))::text into _embedding_batch_schema; + select (config operator (pg_catalog.->> 'embedding_batch_table'))::text into _embedding_batch_table; + select (config operator (pg_catalog.->> 'embedding_batch_chunks_table'))::text into _embedding_batch_chunks_table; + if pg_catalog.to_regclass(pg_catalog.format('%I.%I', _embedding_batch_schema, _embedding_batch_table)) is not null then + raise exception 'an object named %.% already exists. specify an alternate embedding_batch_table explicitly', queue_schema, queue_table; + end if; + + -- make sure embedding batch chunks table name is available + if pg_catalog.to_regclass(pg_catalog.format('%I.%I', _embedding_batch_schema, _embedding_batch_chunks_table)) is not null then + raise exception 'an object named %.% already exists. specify an alternate embedding_batch_chunks_table explicitly', queue_schema, queue_table; + end if; + -- ok when 'ollama' then -- ok diff --git a/projects/extension/sql/idempotent/013-vectorizer-api.sql b/projects/extension/sql/idempotent/013-vectorizer-api.sql index 37ec35ffd..e81415abb 100644 --- a/projects/extension/sql/idempotent/013-vectorizer-api.sql +++ b/projects/extension/sql/idempotent/013-vectorizer-api.sql @@ -29,9 +29,6 @@ create or replace function ai.create_vectorizer , queue_table pg_catalog.name default null , grant_to pg_catalog.name[] default ai.grant_to() , enqueue_existing pg_catalog.bool default true -, embedding_batch_schema pg_catalog.name default null -, embedding_batch_table pg_catalog.name default null -, embedding_batch_chunks_table pg_catalog.name default null ) returns pg_catalog.int4 as $func$ declare @@ -227,29 +224,15 @@ begin scheduling = pg_catalog.jsonb_insert(scheduling, array['job_id'], pg_catalog.to_jsonb(_job_id)); end if; - embedding_batch_schema = coalesce(embedding_batch_schema, 'ai'); - embedding_batch_table = coalesce(embedding_batch_table, pg_catalog.concat('_vectorizer_embedding_batches_', _vectorizer_id)); - embedding_batch_chunks_table = coalesce(embedding_batch_chunks_table, pg_catalog.concat('_vectorizer_embedding_batch_chunks_', _vectorizer_id)); - -- create batch embedding tables select (embedding operator (pg_catalog.->> 'implementation'))::text into _implementation; if _implementation = 'openai' then - -- make sure embedding batch table name is available - if pg_catalog.to_regclass(pg_catalog.format('%I.%I', embedding_batch_schema, embedding_batch_table)) is not null then - raise exception 'an object named %.% already exists. specify an alternate embedding_batch_table explicitly', queue_schema, queue_table; - end if; - - -- make sure embedding batch chunks table name is available - if pg_catalog.to_regclass(pg_catalog.format('%I.%I', embedding_batch_schema, embedding_batch_chunks_table)) is not null then - raise exception 'an object named %.% already exists. specify an alternate embedding_batch_chunks_table explicitly', queue_schema, queue_table; - end if; - perform ai._vectorizer_create_embedding_batches_table - (embedding_batch_schema + (embedding_batch_schema , embedding_batch_table , embedding_batch_chunks_table , grant_to - ); + ); end if; insert into ai.vectorizer @@ -286,9 +269,6 @@ begin , 'formatting', formatting , 'scheduling', scheduling , 'processing', processing - , 'embedding_batch_schema', embedding_batch_schema - , 'embedding_batch_table', embedding_batch_table - , 'embedding_batch_chunks_table', embedding_batch_chunks_table ) ); diff --git a/projects/pgai/pgai/vectorizer/embedders/openai.py b/projects/pgai/pgai/vectorizer/embedders/openai.py index abb05c41a..0352b546a 100644 --- a/projects/pgai/pgai/vectorizer/embedders/openai.py +++ b/projects/pgai/pgai/vectorizer/embedders/openai.py @@ -41,12 +41,20 @@ class OpenAI(ApiKeyMixin, BaseModel, Embedder): model (str): The name of the OpenAI model used for embeddings. dimensions (int | None): Optional dimensions for the embeddings. user (str | None): Optional user identifier for OpenAI API usage. + use_batch (bool): Whether to use OpenAI Batch API. + embedding_batch_schema (str | None): The schema where the embedding batches are stored. + embedding_batch_table (str | None): The table where the embedding batches are stored. + embedding_batch_chunks_table (str | None): The table where the embedding batch chunks are stored. """ implementation: Literal["openai"] model: str dimensions: int | None = None user: str | None = None + use_batch: bool = False + embedding_batch_schema: str | None = None + embedding_batch_table: str | None = None + embedding_batch_chunks_table: str | None = None @cached_property def _openai_dimensions(self) -> int | openai.NotGiven: diff --git a/projects/pgai/pgai/vectorizer/vectorizer.py b/projects/pgai/pgai/vectorizer/vectorizer.py index 6f6aa36a3..439dd0ad0 100644 --- a/projects/pgai/pgai/vectorizer/vectorizer.py +++ b/projects/pgai/pgai/vectorizer/vectorizer.py @@ -75,9 +75,6 @@ class Config: processing: Processing settings such as batch size and concurrency. chunking: The chunking strategy. formatting: Formatting strategy to apply to the chunks. - embedding_batch_schema: The schema where the embedding batches are stored. - embedding_batch_table: The table where the embedding batches are stored. - embedding_batch_chunks_table: The table where the embedding batch chunks are stored. """ version: str @@ -87,9 +84,6 @@ class Config: LangChainCharacterTextSplitter | LangChainRecursiveCharacterTextSplitter ) = Field(..., discriminator="implementation") formatting: PythonTemplate | ChunkValue = Field(..., discriminator="implementation") - embedding_batch_schema: str | None - embedding_batch_table: str | None - embedding_batch_chunks_table: str | None @dataclass @@ -334,8 +328,8 @@ def fetch_batches_to_process_query(self) -> sql.Composed: SELECT openai_batch_id, output_file_id FROM {}.{} WHERE status not in('failed', 'processed', 'prepared') """).format( - self.vectorizer.config.embedding_batch_schema, - self.vectorizer.config.embedding_batch_table, + self.vectorizer.config.embedding.embedding_batch_schema, + self.vectorizer.config.embedding.embedding_batch_table, ) @cached_property @@ -349,8 +343,8 @@ def update_batch_embedding_query(self) -> sql.Composed: errors = %s WHERE openai_batch_id = %s """).format( - self.vectorizer.config.embedding_batch_schema, - self.vectorizer.config.embedding_batch_table, + self.vectorizer.config.embedding.embedding_batch_schema, + self.vectorizer.config.embedding.embedding_batch_table, ) @cached_property @@ -358,8 +352,8 @@ def update_batch_embedding_status_query(self) -> sql.Composed: return sql.SQL( "UPDATE {}.{} SET status = %s WHERE openai_batch_id = %s" ).format( - self.vectorizer.config.embedding_batch_schema, - self.vectorizer.config.embedding_batch_table, + self.vectorizer.config.embedding.embedding_batch_schema, + self.vectorizer.config.embedding.embedding_batch_table, ) @cached_property @@ -367,8 +361,8 @@ def fetch_chunks_for_batch_id_query(self) -> sql.Composed: return sql.SQL( "SELECT id, text FROM {}.{} WHERE embedding_batch_id = %s", ).format( - self.vectorizer.config.embedding_batch_schema, - self.vectorizer.config.embedding_batch_chunks_table, + self.vectorizer.config.embedding.embedding_batch_schema, + self.vectorizer.config.embedding.embedding_batch_chunks_table, ) @cached_property @@ -390,8 +384,8 @@ def insert_batch_embedding_query(self) -> sql.Composed: %s ) """).format( - self.vectorizer.config.embedding_batch_schema, - self.vectorizer.config.embedding_batch_table, + self.vectorizer.config.embedding.embedding_batch_schema, + self.vectorizer.config.embedding.embedding_batch_table, ) @cached_property @@ -407,8 +401,8 @@ def insert_batch_embedding_chunks_query(self) -> sql.Composed: %s ) """).format( - self.vectorizer.config.embedding_batch_schema, - self.vectorizer.config.embedding_batch_chunks_table, + self.vectorizer.config.embedding.embedding_batch_schema, + self.vectorizer.config.embedding.embedding_batch_chunks_table, ) def _pks_placeholders_tuples(self, items_count: int) -> sql.Composed: From cf869310deae62acadfc797846d8b68da73fe610 Mon Sep 17 00:00:00 2001 From: kolaente Date: Thu, 19 Dec 2024 15:15:13 +0100 Subject: [PATCH 20/23] chore: rename text to chunk to match store table --- projects/extension/sql/idempotent/016-openai-batch-api.sql | 2 +- projects/pgai/pgai/vectorizer/vectorizer.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/projects/extension/sql/idempotent/016-openai-batch-api.sql b/projects/extension/sql/idempotent/016-openai-batch-api.sql index 03f5a0211..f5c94ade9 100644 --- a/projects/extension/sql/idempotent/016-openai-batch-api.sql +++ b/projects/extension/sql/idempotent/016-openai-batch-api.sql @@ -42,7 +42,7 @@ begin ( $sql$create table %I.%I( id VARCHAR(255) PRIMARY KEY, embedding_batch_id VARCHAR(255) REFERENCES %I.%I (openai_batch_id), - text TEXT + chunk TEXT ))$sql$ , embedding_batch_schema , embedding_batch_chunks_table diff --git a/projects/pgai/pgai/vectorizer/vectorizer.py b/projects/pgai/pgai/vectorizer/vectorizer.py index 439dd0ad0..bd57fa7cf 100644 --- a/projects/pgai/pgai/vectorizer/vectorizer.py +++ b/projects/pgai/pgai/vectorizer/vectorizer.py @@ -359,7 +359,7 @@ def update_batch_embedding_status_query(self) -> sql.Composed: @cached_property def fetch_chunks_for_batch_id_query(self) -> sql.Composed: return sql.SQL( - "SELECT id, text FROM {}.{} WHERE embedding_batch_id = %s", + "SELECT id, chunk FROM {}.{} WHERE embedding_batch_id = %s", ).format( self.vectorizer.config.embedding.embedding_batch_schema, self.vectorizer.config.embedding.embedding_batch_chunks_table, @@ -394,7 +394,7 @@ def insert_batch_embedding_chunks_query(self) -> sql.Composed: INSERT INTO {}.{} ( id, embedding_batch_id, - text + chunk ) VALUES ( %s, %s, From 916f7b2089e1dfd3b415ef5a454b74db26f89f3a Mon Sep 17 00:00:00 2001 From: kolaente Date: Thu, 19 Dec 2024 15:23:38 +0100 Subject: [PATCH 21/23] feat: add total_attempts and next_attempt_after to openai batch table --- projects/extension/sql/idempotent/016-openai-batch-api.sql | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/projects/extension/sql/idempotent/016-openai-batch-api.sql b/projects/extension/sql/idempotent/016-openai-batch-api.sql index f5c94ade9..356569454 100644 --- a/projects/extension/sql/idempotent/016-openai-batch-api.sql +++ b/projects/extension/sql/idempotent/016-openai-batch-api.sql @@ -21,7 +21,9 @@ begin created_at TIMESTAMP(0) NOT NULL DEFAULT NOW(), expires_at TIMESTAMP(0), completed_at TIMESTAMP(0), - failed_at TIMESTAMP(0) + failed_at TIMESTAMP(0), + next_attempt_after TIMESTAMPTZ, + total_attempts BIGINT NOT NULL DEFAULT 0 ))$sql$ , embedding_batch_schema , embedding_batch_table From ef4c382a58c802e79e68dcee00e0645a3326a6d2 Mon Sep 17 00:00:00 2001 From: kolaente Date: Thu, 19 Dec 2024 15:23:54 +0100 Subject: [PATCH 22/23] feat: make fetching queries concurrently safe --- projects/pgai/pgai/vectorizer/vectorizer.py | 35 ++++++++++++++++----- 1 file changed, 28 insertions(+), 7 deletions(-) diff --git a/projects/pgai/pgai/vectorizer/vectorizer.py b/projects/pgai/pgai/vectorizer/vectorizer.py index bd57fa7cf..3329ece26 100644 --- a/projects/pgai/pgai/vectorizer/vectorizer.py +++ b/projects/pgai/pgai/vectorizer/vectorizer.py @@ -324,13 +324,34 @@ def insert_errors_query(self) -> sql.Composed: @cached_property def fetch_batches_to_process_query(self) -> sql.Composed: - return sql.SQL(""" - SELECT openai_batch_id, output_file_id FROM {}.{} - WHERE status not in('failed', 'processed', 'prepared') - """).format( - self.vectorizer.config.embedding.embedding_batch_schema, - self.vectorizer.config.embedding.embedding_batch_table, - ) + if not isinstance(self.vectorizer.config.embedding, OpenAI): + raise Exception("batch support is only available for openai") + + batch_schema = self.vectorizer.config.embedding.batch_schema + batch_table = self.vectorizer.config.embedding.batch_table + + return sql.SQL( + """ + WITH locked_rows AS ( + SELECT openai_batch_id + FROM {batch_table} + WHERE next_attempt_after is null or next_attempt_after < NOW() + ORDER BY created_at DESC + LIMIT 1 + FOR UPDATE SKIP LOCKED + ), + UPDATE + {batch_table} batches + SET + total_attempts = batches.total_attempts + 1, + next_attempt_after = %s + FRO + locked_rows l + WHERE + l.openai_batch_id = cfw.openai_batch_id + RETURNING l.openai_batch_id + """ + ).format(batch_table=sql.Identifier(batch_schema, batch_table)) @cached_property def update_batch_embedding_query(self) -> sql.Composed: From f96906dee339e2d7e53a5558e287bf416746598c Mon Sep 17 00:00:00 2001 From: kolaente Date: Mon, 23 Dec 2024 16:36:41 +0100 Subject: [PATCH 23/23] feat: make handling async embeddings more abstract --- .../sql/idempotent/016-openai-batch-api.sql | 4 +- .../pgai/pgai/vectorizer/embedders/openai.py | 103 +++++- projects/pgai/pgai/vectorizer/embeddings.py | 54 ++- projects/pgai/pgai/vectorizer/vectorizer.py | 311 +++++++++--------- 4 files changed, 305 insertions(+), 167 deletions(-) diff --git a/projects/extension/sql/idempotent/016-openai-batch-api.sql b/projects/extension/sql/idempotent/016-openai-batch-api.sql index 356569454..863863bb4 100644 --- a/projects/extension/sql/idempotent/016-openai-batch-api.sql +++ b/projects/extension/sql/idempotent/016-openai-batch-api.sql @@ -13,7 +13,7 @@ begin -- create the batches table select pg_catalog.format ( $sql$create table %I.%I( - openai_batch_id VARCHAR(255) PRIMARY KEY, + external_batch_id VARCHAR(255) PRIMARY KEY, input_file_id VARCHAR(255) NOT NULL, output_file_id VARCHAR(255), status VARCHAR(255) NOT NULL, @@ -43,7 +43,7 @@ begin select pg_catalog.format ( $sql$create table %I.%I( id VARCHAR(255) PRIMARY KEY, - embedding_batch_id VARCHAR(255) REFERENCES %I.%I (openai_batch_id), + embedding_batch_id VARCHAR(255) REFERENCES %I.%I (external_batch_id) ON DELETE CASCADE, chunk TEXT ))$sql$ , embedding_batch_schema diff --git a/projects/pgai/pgai/vectorizer/embedders/openai.py b/projects/pgai/pgai/vectorizer/embedders/openai.py index 0352b546a..cb5958fe4 100644 --- a/projects/pgai/pgai/vectorizer/embedders/openai.py +++ b/projects/pgai/pgai/vectorizer/embedders/openai.py @@ -10,6 +10,7 @@ from openai import resources from pydantic import BaseModel from typing_extensions import override +from psycopg import AsyncConnection from ..embeddings import ( ApiKeyMixin, @@ -23,6 +24,7 @@ Usage, logger, ) +from ..vectorizer import AsyncBatch TOKEN_CONTEXT_LENGTH_ERROR = "chunk exceeds model context length" @@ -68,9 +70,13 @@ def _openai_dimensions(self) -> int | openai.NotGiven: def _openai_user(self) -> str | openai.NotGiven: return self.user if self.user is not None else openai.NOT_GIVEN + @cached_property + def _client(self) -> resources.Client: + return openai.AsyncOpenAI(api_key=self._api_key, max_retries=3) + @cached_property def _embedder(self) -> resources.AsyncEmbeddings: - return openai.AsyncOpenAI(api_key=self._api_key, max_retries=3).embeddings + return self._client.embeddings @override def _max_chunks_per_batch(self) -> int: @@ -142,7 +148,7 @@ async def embed( async def create_and_submit_embedding_batch( self, documents: list[dict[str, Any]], - ) -> openai.types.Batch: + ) -> AsyncBatch: """ Creates a batch of embeddings using OpenAI's embeddings API as outlined in https://platform.openai.com/docs/guides/batch/batch-api?lang=python @@ -169,20 +175,25 @@ async def create_and_submit_embedding_batch( temp_file.close() - client = openai.OpenAI() # TODO there has to be a client already which I could use instead? - with open(temp_file.name, "rb") as file: - batch_input_file = client.files.create( + batch_input_file = self._client.files.create( file=file, purpose="batch", ) - return client.batches.create( + openai_batch = self._client.batches.create( input_file_id=batch_input_file.id, endpoint="/v1/embeddings", completion_window="24h", ) + batch = AsyncBatch() + batch.external_batch_id = openai_batch.id + batch.input_file_id = openai_batch.input_file_id + batch.status = openai_batch.status + + return batch + async def _filter_by_length_and_embed( self, model_token_length: int, encoded_documents: list[list[int]] ) -> Sequence[EmbeddingVector | ChunkEmbeddingError]: @@ -254,3 +265,83 @@ async def _encode(self, documents: list[str]) -> list[list[int]]: @cached_property def _encoder(self) -> tiktoken.Encoding: return tiktoken.encoding_for_model(self.model) + + def is_api_async(self) -> bool: + return self.use_batch + + async def fetch_async_embedding_status(self, batch: AsyncBatch) -> AsyncBatch: + openai_batch = self._client.batches.retrieve(batch.external_batch_id) + + batch.status = openai_batch.status + batch.completed_at = openai_batch.completed_at + batch.failed_at = openai_batch.failed_at + batch.errors = openai_batch.errors + + return batch + + async def process_async_embedding( + self, + conn: AsyncConnection, + batch: AsyncBatch, + ): + """ + Writes embeddings from an OpenAI batch embedding to the database. + + - Deletes existing embeddings for the items. + - Loads created embeddings from the batch. + - Writes created embeddings to the database. + - Logs any non-fatal errors encountered during embedding. + + Args: + conn (AsyncConnection): The database connection. + batch: The batch as stored in the queue table. + """ + openai_batch = self._client.batches.retrieve(batch.external_batch_id) + batch_file = self._client.files.content(openai_batch.output_file_id) + + batch_data = batch_file.text.strip().split("\n") + num_records = 0 + all_items = [] + all_records: list[EmbeddingRecord] = [] + + async with conn.cursor() as cursor: + await cursor.execute( + self.queries.fetch_chunks_for_batch_id_query + (batch.id,) + ) + embedding_batch_chunks = {row[0]: row[1] for row in await cursor.fetchall()} + + for line in batch_data: + json_line = json.loads(line) + if "custom_id" in json_line and "response" in json_line: + + custom_id = json_line["custom_id"] + pk_names, document_id, chunk_seq = custom_id.split(":::") + embedding_data = json_line["response"]["body"]["data"][0]["embedding"] + + resolved_id = document_id.split(",") + resolved_pk = pk_names.split(",") + item = {pk: id_value + for pk, id_value in zip(resolved_pk, resolved_id, strict=False)} + item[self.vectorizer.config.chunking.chunk_column] = embedding_batch_chunks[custom_id] + + all_items.append(item) + all_records.append([ + resolved_id + + [chunk_seq, embedding_batch_chunks[custom_id]] + + [np.array(embedding_data)]]) + + await self._delete_embeddings(conn, all_items) + for records in all_records: + await self._copy_embeddings(conn, records) + + return num_records + + + async def finalize_async_embedding( + self, + batch: AsyncBatch, + ): + openai_batch = self._client.batches.retrieve(batch.external_batch_id) + await self._client.files.delete(openai_batch.input_file_id) + await self._client.files.delete(openai_batch.output_file_id) diff --git a/projects/pgai/pgai/vectorizer/embeddings.py b/projects/pgai/pgai/vectorizer/embeddings.py index a40f91f02..96d722d3e 100644 --- a/projects/pgai/pgai/vectorizer/embeddings.py +++ b/projects/pgai/pgai/vectorizer/embeddings.py @@ -3,11 +3,14 @@ from abc import ABC, abstractmethod from collections.abc import Awaitable, Callable, Sequence from dataclasses import dataclass -from typing import Generic, TypeAlias, TypeVar +from typing import Generic, TypeAlias, TypeVar, Any +from psycopg import AsyncConnection import structlog from ddtrace import tracer +from .vectorizer import AsyncBatch + logger = structlog.get_logger() @@ -164,6 +167,55 @@ async def setup(self) -> None: # noqa: B027 empty on purpose Setup the embedder """ + @abstractmethod + def is_api_async(self) -> bool: + return False + + @abstractmethod + async def fetch_async_embedding_status(self, batch: AsyncBatch) -> AsyncBatch: + """ + Will receive a row from the batch embeddings queue table and should + check if the embedding has been processed and is ready to be stored. + + If it is ready, the status of the async batch should be set to "completed". + """ + + @abstractmethod + async def process_async_embedding( + self, + conn: AsyncConnection, + batch: AsyncBatch, + ): + """ + Writes embeddings from a batch embedding to the database. + + - Deletes existing embeddings for the items. + - Loads created embeddings from the batch. + - Writes created embeddings to the database. + - Logs any non-fatal errors encountered during embedding. + + Args: + conn (AsyncConnection): The database connection. + batch: The batch as retrieved from the database. + """ + + async def finalize_async_embedding( + self, + batch: AsyncBatch, + ): + """ + When the batch was processed, this method allows to clean up any + files from the external service. + """ + + @abstractmethod + async def create_and_submit_embedding_batch( + self, + documents: list[dict[str, Any]], + ) -> AsyncBatch: + """ + Receives a bunch of documents and creates a batch of documents for it with an external service. + """ class ApiKeyMixin: """ diff --git a/projects/pgai/pgai/vectorizer/vectorizer.py b/projects/pgai/pgai/vectorizer/vectorizer.py index 3329ece26..1fc4e5c2d 100644 --- a/projects/pgai/pgai/vectorizer/vectorizer.py +++ b/projects/pgai/pgai/vectorizer/vectorizer.py @@ -1,5 +1,4 @@ import asyncio -import json import os import threading import time @@ -10,7 +9,6 @@ from typing import Any, TypeAlias import numpy as np -import openai import psycopg import structlog from ddtrace import tracer @@ -86,6 +84,36 @@ class Config: formatting: PythonTemplate | ChunkValue = Field(..., discriminator="implementation") +@dataclass +class AsyncBatch: + """ + Represents a record in the external batch table. + + Attributes: + external_batch_id (str): Primary key of the batch. + input_file_id (str): The ID of the input file. This is mandatory. + status (str): The current status of the batch. + errors (dict | None): Dictionary representing error details in JSONB format. + created_at (datetime): The timestamp when the record was created (defaults to current time). + expires_at (datetime | None): The optional expiration timestamp for the batch. + completed_at (datetime | None): The optional timestamp when the batch processing was completed. + failed_at (datetime | None): The timestamp when the batch processing failed (if applicable). + next_attempt_after (datetime | None): The timestamp when the batch can be retried next. + total_attempts (int): Count of the total number of attempts made to process this batch. + """ + + external_batch_id: str + input_file_id: str + status: str + errors: dict | None = None + created_at: datetime = Field(default_factory=datetime.now) + expires_at: datetime | None = None + completed_at: datetime | None = None + failed_at: datetime | None = None + next_attempt_after: datetime | None = None + total_attempts: int = 0 + + @dataclass class Vectorizer: """ @@ -324,16 +352,13 @@ def insert_errors_query(self) -> sql.Composed: @cached_property def fetch_batches_to_process_query(self) -> sql.Composed: - if not isinstance(self.vectorizer.config.embedding, OpenAI): - raise Exception("batch support is only available for openai") - batch_schema = self.vectorizer.config.embedding.batch_schema batch_table = self.vectorizer.config.embedding.batch_table return sql.SQL( """ WITH locked_rows AS ( - SELECT openai_batch_id + SELECT external_batch_id FROM {batch_table} WHERE next_attempt_after is null or next_attempt_after < NOW() ORDER BY created_at DESC @@ -348,8 +373,8 @@ def fetch_batches_to_process_query(self) -> sql.Composed: FRO locked_rows l WHERE - l.openai_batch_id = cfw.openai_batch_id - RETURNING l.openai_batch_id + l.external_batch_id = cfw.external_batch_id + RETURNING l.external_batch_id """ ).format(batch_table=sql.Identifier(batch_schema, batch_table)) @@ -360,9 +385,18 @@ def update_batch_embedding_query(self) -> sql.Composed: status = %s completed_at = %s, failed_at = %s, - output_file_id = %s, errors = %s - WHERE openai_batch_id = %s + WHERE external_batch_id = %s + """).format( + self.vectorizer.config.embedding.embedding_batch_schema, + self.vectorizer.config.embedding.embedding_batch_table, + ) + + @cached_property + def delete_batch_embedding_from_queue_query(self) -> sql.Composed: + return sql.SQL(""" + DELETE FROM {}.{} + WHERE external_batch_id = %s """).format( self.vectorizer.config.embedding.embedding_batch_schema, self.vectorizer.config.embedding.embedding_batch_table, @@ -371,7 +405,7 @@ def update_batch_embedding_query(self) -> sql.Composed: @cached_property def update_batch_embedding_status_query(self) -> sql.Composed: return sql.SQL( - "UPDATE {}.{} SET status = %s WHERE openai_batch_id = %s" + "UPDATE {}.{} SET status = %s WHERE external_batch_id = %s" ).format( self.vectorizer.config.embedding.embedding_batch_schema, self.vectorizer.config.embedding.embedding_batch_table, @@ -390,7 +424,7 @@ def fetch_chunks_for_batch_id_query(self) -> sql.Composed: def insert_batch_embedding_query(self) -> sql.Composed: return sql.SQL(""" INSERT INTO {}.{} ( - openai_batch_id, + external_batch_id, input_file_id, output_file_id, status, @@ -572,10 +606,11 @@ async def run(self) -> int: await register_vector_async(conn) await self.vectorizer.config.embedding.setup() while True: + if self.vectorizer.config.embedding.is_api_async(): + res = await self._process_async_embeddings(conn) + return res + if not self._continue_processing(loops, res): - await self._check_and_process_openai_batches(conn) - # TODO how can we run this only after hitting the rate limit of the normal openai batch embedding api? - await self._do_openai_batch(conn) return res items_processed = await self._do_batch(conn) if items_processed == 0: @@ -583,106 +618,126 @@ async def run(self) -> int: res += items_processed loops += 1 - @tracer.wrap() - async def _do_openai_batch(self, conn: AsyncConnection) -> int: + async def _process_async_embeddings(self, conn): + async with conn.transaction(): + await self.check_and_store_async_batches(conn) + await self.create_async_batches(conn) + + async def check_and_store_async_batches(self, conn: AsyncConnection): """ - Creates embeddings using openai's batch processing api. This allows to process - very large amounts of data faster than with the embeddings api, because the - batch api has vastly higher rate limits. + Checks if chunks submitted with create_async_batches completed and + stores them when they are completed. + + This function is only called when is_api_async returns true. Args: conn (AsyncConnection): The asynchronous database connection. """ + async with conn.cursor() as cursor: + with conn.transaction(): + await cursor.execute(self.queries.fetch_batches_to_process_query) - # TODO do nothing when openai is not configured - - try: - async with conn.transaction(): - items = await self._fetch_work(conn) + for batch_row in await cursor.fetchall(): - await logger.adebug(f"Items pulled from queue for openai batch embedding: {len(items)}") + batch = AsyncBatch(**batch_row) + batch = self.vectorizer.config.embedding.fetch_async_embedding_status(batch) - # Filter out items that were deleted from the source table. - # We use the first primary key column, since they can only - # be null if the LEFT JOIN didn't find a match. - items = [ - i - for i in items - if i[self.vectorizer.source_pk[0].attname] is not None - ] + with conn.transaction(): + await conn.execute(self.queries.update_batch_embedding_query, ( + batch.status, + datetime.fromtimestamp(batch.completed_at, timezone.utc) + if batch.completed_at else None, + datetime.fromtimestamp(batch.failed_at, timezone.utc) + if batch.failed_at else None, + Jsonb(batch.errors), + batch.external_batch_id, + )) - if len(items) == 0: - return 0 + # batch has been processed successfully by the external api, that means we can + # collect the results and store them in the database. + if batch.status == "completed": - created_batch, documents = await self._generate_embedding_batch(items) + with conn.transaction(): + await self.vectorizer.config.embedding.write_embeddings_from_batch(conn, batch) - await conn.execute(self.queries.insert_batch_embedding_query, ( - created_batch.id, - created_batch.input_file_id, - created_batch.output_file_id, - created_batch.status, - created_batch.errors, - datetime.fromtimestamp(created_batch.expires_at, timezone.utc), - )) + batch.status = "processed" - for doc in documents: - await conn.execute( - self.queries.insert_batch_embedding_chunks_query, - ( - doc["unique_full_chunk_id"], - created_batch.id, - doc["chunk"] - )) + await cursor.execute( + self.queries.update_batch_embedding_status_query, + ( + batch.status, + batch.external_batch_id, + )) - # TODO how to delete submitted entries from the queue? + if batch.status == "processed": + with conn.transaction(): + await self.vectorizer.config.embedding.finalize_async_embedding(batch) + await cursor.execute( + self.queries.delete_batch_embedding_from_queue_query, + ( + batch.external_batch_id, + )) - return len(items) - except Exception as e: - async with conn.transaction(): - await self._insert_vectorizer_error( - conn, - ( - self.vectorizer.id, - VECTORIZER_FAILED, - Jsonb({"error_reason": str(e)}), - ), - ) - raise e - @tracer.wrap() - async def _check_and_process_openai_batches(self, conn: AsyncConnection): - async with ( - conn.transaction(), - conn.cursor() as cursor, - ): - # TODO how can I get the client? There has to be one created already that I can use? - client = openai.OpenAI() - await cursor.execute(self.queries.fetch_batches_to_process_query) - for batch_row in await cursor.fetchall(): - batch = client.batches.retrieve(batch_row[0]) + async def create_async_batches(self, conn: AsyncConnection) -> int: + """ + Submits chunks for async embedding processing. + This allows to process very large amounts of data faster than with the + embeddings api, because batch apis usually have vastly higher rate limits. - await conn.execute(self.queries.update_batch_embedding_query, ( - batch.status, - datetime.fromtimestamp(batch.completed_at, timezone.utc) - if batch.completed_at else None, - datetime.fromtimestamp(batch.failed_at, timezone.utc) - if batch.failed_at else None, - batch.output_file_id, - Jsonb(batch.errors), - batch_row[0], - )) + This function is only called when is_api_async returns true. - # batch has been processed successfully in openai, that means we can - # collect the results and store them in the database. - if batch.status == "completed": - await self._write_embeddings_from_batch(conn, batch, client) + Args: + conn (AsyncConnection): The asynchronous database connection. + """ + try: + items = await self._fetch_work(conn) + + await logger.adebug(f"Items pulled from queue for batch embedding: {len(items)}") + + # Filter out items that were deleted from the source table. + # We use the first primary key column, since they can only + # be null if the LEFT JOIN didn't find a match. + items = [ + i + for i in items + if i[self.vectorizer.source_pk[0].attname] is not None + ] + + if len(items) == 0: + return 0 + + created_batch, documents = await self._generate_embedding_batch(items) + + await conn.execute(self.queries.insert_batch_embedding_query, ( + created_batch.external_id, + created_batch.input_file_id, + created_batch.output_file_id, + created_batch.status, + created_batch.errors, + datetime.fromtimestamp(created_batch.expires_at, timezone.utc), + )) + + for doc in documents: + await conn.execute( + self.queries.insert_batch_embedding_chunks_query, + ( + doc["unique_full_chunk_id"], + created_batch.id, + doc["chunk"] + )) - await cursor.execute( - self.queries.update_batch_embedding_status_query, - ( - "processed", - batch_row[0], - )) + return len(items) + except Exception as e: + await self._insert_vectorizer_error( + conn, + ( + self.vectorizer.id, + VECTORIZER_FAILED, + Jsonb({"error_reason": str(e)}), + ), + ) + raise e @tracer.wrap() async def _do_batch(self, conn: AsyncConnection) -> int: @@ -835,66 +890,6 @@ async def _embed_and_write(self, conn: AsyncConnection, items: list[SourceRow]): return len(records) - @tracer.wrap() - async def _write_embeddings_from_batch( - self, - conn: AsyncConnection, - batch: openai.types.Batch, - client: OpenAI, - ): - """ - Writes embeddings from an OpenAI batch embedding to the database. - - - Deletes existing embeddings for the items. - - Loads created embeddings from the batch. - - Writes created embeddings to the database. - - Logs any non-fatal errors encountered during embedding. - - Args: - conn (AsyncConnection): The database connection. - batch: The batch as retrieved from OpenAI's api. - client: The OpenAI client to use. - """ - batch_file = client.files.content(batch.output_file_id) - - batch_data = batch_file.text.strip().split("\n") - num_records = 0 - all_items = [] - all_records: list[EmbeddingRecord] = [] - - async with conn.cursor() as cursor: - await cursor.execute( - self.queries.fetch_chunks_for_batch_id_query - (batch.id,) - ) - embedding_batch_chunks = {row[0]: row[1] for row in await cursor.fetchall()} - - for line in batch_data: - json_line = json.loads(line) - if "custom_id" in json_line and "response" in json_line: - - custom_id = json_line["custom_id"] - pk_names, document_id, chunk_seq = custom_id.split(":::") - embedding_data = json_line["response"]["body"]["data"][0]["embedding"] - - resolved_id = document_id.split(",") - resolved_pk = pk_names.split(",") - item = {pk: id_value - for pk, id_value in zip(resolved_pk, resolved_id, strict=False)} - item[self.vectorizer.config.chunking.chunk_column] = embedding_batch_chunks[custom_id] - - all_items.append(item) - all_records.append([ - resolved_id - + [chunk_seq, embedding_batch_chunks[custom_id]] - + [np.array(embedding_data)]]) - - await self._delete_embeddings(conn, all_items) - for records in all_records: - await self._copy_embeddings(conn, records) - - return num_records - async def _delete_embeddings(self, conn: AsyncConnection, items: list[SourceRow]): """ Deletes the embeddings for the given items from the target table. @@ -1005,7 +1000,7 @@ async def _generate_embeddings( async def _generate_embedding_batch( self, items: list[SourceRow] - ) -> tuple[openai.types.Batch, list[dict[str, Any]]]: + ) -> tuple[AsyncBatch, list[dict[str, Any]]]: documents: list[dict[str, Any]] = [] for item in items: pk = self._get_item_pk_values(item)