Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: start including basic aws emf metrics #272

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 72 additions & 27 deletions projects/pgai/pgai/vectorizer/vectorizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
import numpy as np
import psycopg
import structlog
from aws_embedded_metrics import metric_scope
from aws_embedded_metrics.config import get_config
from aws_embedded_metrics.config.configuration import Configuration
from aws_embedded_metrics.logger.metrics_logger import MetricsLogger
from ddtrace import tracer
from pgvector.psycopg import register_vector_async # type: ignore
from psycopg import AsyncConnection, sql
Expand Down Expand Up @@ -77,7 +81,7 @@ class Config:
embedding: OpenAI | Ollama
processing: ProcessingDefault
chunking: (
LangChainCharacterTextSplitter | LangChainRecursiveCharacterTextSplitter
LangChainCharacterTextSplitter | LangChainRecursiveCharacterTextSplitter
) = Field(..., discriminator="implementation")
formatting: PythonTemplate | ChunkValue = Field(..., discriminator="implementation")

Expand Down Expand Up @@ -254,9 +258,9 @@ def fetch_work_query(self) -> sql.Composed:
xs
for x in self.vectorizer.source_pk
for xs in [
sql.Literal(x.attname),
sql.Identifier(x.attname),
]
sql.Literal(x.attname),
sql.Identifier(x.attname),
]
]
),
delete_join_predicates=sql.SQL(" AND ").join(
Expand Down Expand Up @@ -420,6 +424,10 @@ async def print_stats(self):
)


def _metrics_enabled(metrics_config) -> bool:
return metrics_config is not None and not metrics_config.disable_metric_extraction


class Worker:
"""
Responsible for processing items from the work queue and generating embeddings.
Expand All @@ -439,24 +447,28 @@ class Worker:
_continue_processing: Callable[[int, int], bool]

def __init__(
    self,
    db_url: str,
    vectorizer: Vectorizer,
    continue_processing: None | Callable[[int, int], bool] = None,
):
    """
    Stores the connection URL and vectorizer, and builds the query helper.

    Args:
        db_url: Database URL, stored unchanged on the instance.
        vectorizer: The vectorizer this worker operates on; also used to
            build the `VectorizerQueryBuilder` kept in `self.queries`.
        continue_processing: Optional callable taking two ints and
            returning a bool. When omitted (or falsy), a callable that
            always returns True is used instead.
    """
    self.db_url = db_url
    self.vectorizer = vectorizer
    self.queries = VectorizerQueryBuilder(vectorizer)
    # Fall back to "always continue" so callers never have to None-check.
    self._continue_processing = continue_processing or (lambda _loops, _res: True)

async def run(self) -> int:
@metric_scope
async def run(self, metrics:MetricsLogger) -> int:
"""
Embedding loop. Continuously fetches tasks from the work queue and
processes them within the context of a transaction.

Returns:
int: The number of tasks processed from the work queue.
"""
metrics_config = None
if metrics is not None:
metrics_config = get_config()
res = 0
loops = 0

Expand All @@ -465,14 +477,16 @@ async def run(self) -> int:
while True:
if not self._continue_processing(loops, res):
return res
items_processed = await self._do_batch(conn)
items_processed = await self._do_batch(conn, metrics, metrics_config)
if items_processed == 0:
return res
res += items_processed
loops += 1

@tracer.wrap()
async def _do_batch(self, conn: AsyncConnection) -> int:
@metric_scope
async def _do_batch(self, conn: AsyncConnection, metrics: MetricsLogger,
metrics_config: Configuration) -> int:
"""
Processes a batch of tasks. Fetches items from the queue, filters out
deleted items, generates embeddings, and writes them to the database.
Expand Down Expand Up @@ -508,9 +522,17 @@ async def _do_batch(self, conn: AsyncConnection) -> int:

num_chunks = await self._embed_and_write(conn, items)

time_spent = time.perf_counter() - start_time
processing_stats.add_request_time(
time.perf_counter() - start_time, num_chunks
time_spent, num_chunks
)
if _metrics_enabled(metrics_config):
metrics.put_dimensions(
{"provider": self.vectorizer.config.embedding.implementation}
)
metrics.put_metric("embedding_generation_batch_total_duration",
time_spent*1000, "Milliseconds")

await processing_stats.print_stats()

return len(items)
Expand Down Expand Up @@ -595,7 +617,9 @@ async def _get_queue_table_oid(self, conn: AsyncConnection) -> int:
return self._queue_table_oid

@tracer.wrap()
async def _embed_and_write(self, conn: AsyncConnection, items: list[SourceRow]):
async def _embed_and_write(
self, conn: AsyncConnection, items: list[SourceRow], metrics: MetricsLogger,
metrics_config:Configuration):
"""
Embeds the items and writes them to the database.

Expand All @@ -612,9 +636,11 @@ async def _embed_and_write(self, conn: AsyncConnection, items: list[SourceRow]):
Returns:
int: The number of records written to the database.
"""

await self._delete_embeddings(conn, items)
records, errors = await self._generate_embeddings(items)
records, errors = await self._generate_embeddings(
items, metrics, metrics_config
)

# await self._insert_embeddings(conn, records)
await self._copy_embeddings(conn, records)
if errors:
Expand All @@ -636,9 +662,9 @@ async def _delete_embeddings(self, conn: AsyncConnection, items: list[SourceRow]

@tracer.wrap()
async def _copy_embeddings(
self,
conn: AsyncConnection,
records: list[EmbeddingRecord],
self,
conn: AsyncConnection,
records: list[EmbeddingRecord],
):
"""
Inserts embeddings into the embedding table using COPY FROM STDIN WITH
Expand All @@ -657,9 +683,9 @@ async def _copy_embeddings(
await copy.write_row(record)

async def _insert_vectorizer_errors(
self,
conn: AsyncConnection,
records: list[VectorizerErrorRecord],
self,
conn: AsyncConnection,
records: list[VectorizerErrorRecord],
):
"""
Inserts vectorizer errors into the errors table.
Expand All @@ -672,9 +698,9 @@ async def _insert_vectorizer_errors(
await cursor.executemany(self.queries.insert_errors_query, records)

async def _insert_vectorizer_error(
self,
conn: AsyncConnection,
record: VectorizerErrorRecord,
self,
conn: AsyncConnection,
record: VectorizerErrorRecord,
):
"""
Inserts a single vectorizer error into the errors table.
Expand All @@ -689,8 +715,10 @@ async def _insert_vectorizer_error(
def _get_item_pk_values(self, item: SourceRow) -> list[Any]:
    """
    Collects the item's values for every name in `self.queries.pk_attnames`.

    Args:
        item: The row to read from; indexed by each primary-key attribute
            name.

    Returns:
        list[Any]: The looked-up values, in the same order as
            `self.queries.pk_attnames`.
    """
    return [item[pk] for pk in self.queries.pk_attnames]

@metric_scope
async def _generate_embeddings(
self, items: list[SourceRow]
self, items: list[SourceRow], metrics: MetricsLogger,
metrics_config: Configuration
) -> tuple[list[EmbeddingRecord], list[VectorizerErrorRecord]]:
"""
Generates the embeddings for the given items.
Expand All @@ -713,16 +741,33 @@ async def _generate_embeddings(
documents.append(formatted)

try:
start_time = time.perf_counter()
embeddings = await self.vectorizer.config.embedding.embed(documents)
time_spent = time.perf_counter() - start_time
if _metrics_enabled(metrics_config):
metrics.put_dimensions(
{"provider": self.vectorizer.config.embedding.implementation}
)
metrics.put_metric("embeddings_generation_success", 1, "Count")
metrics.put_metric(
"embeddings_generation_duration", time_spent*1000, "Milliseconds"
)

except Exception as e:
if _metrics_enabled(metrics_config):
metrics.put_dimensions(
{"provider": self.vectorizer.config.embedding.implementation}
)
metrics.put_metric("embeddings_generation_error", 1, "Count")
metrics.set_property("reason", str(e))
raise EmbeddingProviderError() from e

assert len(embeddings) == len(records_without_embeddings)

records: list[EmbeddingRecord] = []
errors: list[VectorizerErrorRecord] = []
for record, embedding in zip(
records_without_embeddings, embeddings, strict=True
records_without_embeddings, embeddings, strict=True
):
if isinstance(embedding, ChunkEmbeddingError):
errors.append(self._vectorizer_error_record(record, embedding))
Expand All @@ -731,7 +776,7 @@ async def _generate_embeddings(
return records, errors

def _vectorizer_error_record(
self, record: EmbeddingRecord, chunk_error: ChunkEmbeddingError
self, record: EmbeddingRecord, chunk_error: ChunkEmbeddingError
) -> VectorizerErrorRecord:
return (
self.vectorizer.id,
Expand Down
3 changes: 2 additions & 1 deletion projects/pgai/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ dependencies = [
"typing_extensions>=4.0,<5.0",
"datadog_lambda>=6.9,<7.0",
"pytimeparse>=1.1,<2.0",
"aws-embedded-metrics>=3.2.0",
]
classifiers = [
"License :: OSI Approved :: PostgreSQL License",
Expand Down Expand Up @@ -107,4 +108,4 @@ dev-dependencies = [
"testcontainers==4.8.1",
"build==1.2.2.post1",
"twine==5.1.1",
]
]
1 change: 1 addition & 0 deletions projects/pgai/tests/vectorizer/test_vectorizer_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ def test_process_vectorizer(
array_fill(0, ARRAY[1536])::vector)
""")
monkeypatch.setenv("OPENAI_API_KEY", "test-key")
monkeypatch.setenv("AWS_EMF_DISABLE_METRIC_EXTRACTION", "true")

# When running the worker with cassette matching original test params
cassette = (
Expand Down
Loading
Loading