Skip to content

Commit

Permalink
fix: handle multiple embedding events
Browse files Browse the repository at this point in the history
  • Loading branch information
RogerHYang committed Dec 12, 2024
1 parent 4e5d219 commit 583e77c
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ dynamic = ["version"]
description = "OpenInference LlamaIndex Instrumentation"
readme = "README.md"
license = "Apache-2.0"
requires-python = ">=3.8, <3.13"
requires-python = ">=3.8, <3.14"
authors = [
{ name = "OpenInference Authors", email = "oss@arize.com" },
]
Expand All @@ -23,6 +23,7 @@ classifiers = [
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
]
dependencies = [
"opentelemetry-api",
Expand Down Expand Up @@ -66,6 +67,7 @@ packages = ["src/openinference"]

[tool.pytest.ini_options]
asyncio_mode = "auto"
asyncio_default_fixture_loop_scope = "function"
testpaths = [
"tests",
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import json
import logging
import weakref
from collections import defaultdict
from dataclasses import dataclass
from enum import Enum, auto
from functools import singledispatch, singledispatchmethod
Expand All @@ -15,6 +16,7 @@
TYPE_CHECKING,
Any,
AsyncGenerator,
DefaultDict,
Dict,
Generator,
Iterable,
Expand Down Expand Up @@ -183,6 +185,7 @@ def __init__(
self._attributes = {}
self._end_time = None
self._last_updated_at = time()
self._list_attr_len: DefaultDict[str, int] = defaultdict(int)

def __setitem__(self, key: str, value: AttributeValue) -> None:
self._attributes[key] = value
Expand Down Expand Up @@ -362,9 +365,12 @@ def _(self, event: EmbeddingStartEvent) -> None:

@_process_event.register
def _(self, event: EmbeddingEndEvent) -> None:
    """Record the texts and vectors of an embedding-end event as span attributes.

    A span may receive several embedding-end events. Indices therefore
    continue from the number of embeddings already recorded on this span
    instead of restarting at zero, so later events do not overwrite the
    attributes written by earlier ones.
    """
    # Resume numbering where the previous embedding event (if any) left off;
    # the defaultdict yields 0 for the first event.
    index = self._list_attr_len[EMBEDDING_EMBEDDINGS]
    for text, vector in zip(event.chunks, event.embeddings):
        self[f"{EMBEDDING_EMBEDDINGS}.{index}.{EMBEDDING_TEXT}"] = text
        self[f"{EMBEDDING_EMBEDDINGS}.{index}.{EMBEDDING_VECTOR}"] = vector
        index += 1
    # Persist the running count so the next event appends after these entries.
    self._list_attr_len[EMBEDDING_EMBEDDINGS] = index

@_process_event.register
def _(self, event: StreamChatStartEvent) -> None:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import pytest
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter


@pytest.fixture
def in_memory_span_exporter() -> InMemorySpanExporter:
    """Return a new ``InMemorySpanExporter`` for tests to read finished spans from."""
    exporter = InMemorySpanExporter()
    return exporter


@pytest.fixture
def tracer_provider(
    in_memory_span_exporter: InMemorySpanExporter,
) -> TracerProvider:
    """Return a ``TracerProvider`` whose spans are sent to the in-memory exporter.

    The exporter is attached through a ``SimpleSpanProcessor``.
    """
    provider = TracerProvider()
    processor = SimpleSpanProcessor(in_memory_span_exporter)
    provider.add_span_processor(processor)
    return provider
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from itertools import product
from typing import Iterator

import pytest
from llama_index.core.instrumentation import get_dispatcher
from llama_index.core.instrumentation.events.embedding import EmbeddingEndEvent
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter

from openinference.instrumentation.llama_index import LlamaIndexInstrumentor
from openinference.semconv.trace import EmbeddingAttributes, SpanAttributes

# Dispatcher for this module; `foo` uses it to open a span and emit embedding events.
dispatcher = get_dispatcher(__name__)


@dispatcher.span  # type: ignore[misc,unused-ignore]
def foo(m: int, n: int) -> None:
    """Emit ``m`` embedding-end events, each carrying ``n`` chunks and vectors.

    For event ``i`` and item ``j`` the chunk text is ``"{i}-{j}"`` and the
    embedding vector is ``[float(i), float(j)]``.
    """
    for batch in range(m):
        items = range(n)
        event = EmbeddingEndEvent(
            chunks=[f"{batch}-{item}" for item in items],
            embeddings=[[float(batch), float(item)] for item in items],
        )
        dispatcher.event(event)


async def test_multiple_embedding_events(
    in_memory_span_exporter: InMemorySpanExporter,
) -> None:
    """Check that embeddings from several events are numbered consecutively.

    ``foo`` emits ``m`` events of ``n`` embeddings each inside one span; the
    first finished span must expose every text and vector under a running
    index ``k`` that keeps increasing across events.
    """
    m, n = 3, 2
    foo(m, n)
    span = in_memory_span_exporter.get_finished_spans()[0]
    assert span.attributes
    k = 0
    for i in range(m):
        for j in range(n):
            prefix = f"{EMBEDDING_EMBEDDINGS}.{k}"
            expected_text = f"{i}-{j}"
            expected_vector = (float(i), float(j))
            assert span.attributes[f"{prefix}.{EMBEDDING_TEXT}"] == expected_text
            assert span.attributes[f"{prefix}.{EMBEDDING_VECTOR}"] == expected_vector
            k += 1


@pytest.fixture(autouse=True)
def instrument(
    tracer_provider: TracerProvider,
    in_memory_span_exporter: InMemorySpanExporter,
) -> Iterator[None]:
    """Instrument LlamaIndex with the test tracer provider around every test.

    Instrumentation is applied before the test body runs and removed after it.
    """
    # Setup: route LlamaIndex instrumentation to the test tracer provider.
    on_setup = LlamaIndexInstrumentor()
    on_setup.instrument(tracer_provider=tracer_provider)
    yield
    # Teardown: remove the instrumentation again.
    on_teardown = LlamaIndexInstrumentor()
    on_teardown.uninstrument()


# Short aliases for the semantic-convention attribute names that the test
# combines into the expected span attribute keys.
EMBEDDING_EMBEDDINGS = SpanAttributes.EMBEDDING_EMBEDDINGS
EMBEDDING_TEXT = EmbeddingAttributes.EMBEDDING_TEXT
EMBEDDING_VECTOR = EmbeddingAttributes.EMBEDDING_VECTOR
2 changes: 1 addition & 1 deletion python/tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ envlist =
py3{9,12}-ci-{mistralai,mistralai-latest}
py3{8,12}-ci-{openai,openai-latest}
py3{8,12}-ci-{vertexai,vertexai-latest}
py3{8,12}-ci-{llama_index,llama_index-latest}
py3{8,13}-ci-{llama_index,llama_index-latest}
py3{9,12}-ci-{dspy,dspy-latest}
py3{9,12}-ci-{langchain,langchain-latest}
; py3{9,12}-ci-{guardrails,guardrails-latest}
Expand Down

0 comments on commit 583e77c

Please sign in to comment.