diff --git a/.github/workflows/python-CI.yaml b/.github/workflows/python-CI.yaml new file mode 100644 index 000000000..0f2077d92 --- /dev/null +++ b/.github/workflows/python-CI.yaml @@ -0,0 +1,26 @@ +name: Python CI + +on: + push: + branches: [main] + pull_request: + paths: + - "python/**" + +defaults: + run: + working-directory: ./python + +jobs: + ci: + name: CI Python + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: | + 3.8 + 3.11 + - run: pip install tox==4.11.4 + - run: tox run-parallel diff --git a/README.md b/README.md index 7749faa5a..8f46408fd 100644 --- a/README.md +++ b/README.md @@ -20,9 +20,10 @@ OpenInference provides a set of instrumentations for popular machine learning SD ## Python -| Package | Description | -| --------------------------------------------------------------------------------------------- | --------------------------------------------- | -| [`openinference-semantic-conventions`](./python/openinference-semantic-conventions/README.md) | Semantic conventions for tracing of LLM Apps. | +| Package | Description | +|--------------------------------------------------------------------------------------------------------------------|-----------------------------------------------| +| [`openinference-semantic-conventions`](./python/openinference-semantic-conventions/README.md) | Semantic conventions for tracing of LLM Apps. | +| [`openinference-instrumentation-openai`](./python/instrumentation/openinference-instrumentation-openai/README.rst) | OpenInference Instrumentation for OpenAI SDK. | ## JavaScript diff --git a/python/dev-requirements.txt b/python/dev-requirements.txt new file mode 100644 index 000000000..4abb5b680 --- /dev/null +++ b/python/dev-requirements.txt @@ -0,0 +1,3 @@ +pytest == 7.4.4 +ruff == 0.1.11 +mypy == 1.8.0 diff --git a/python/instrumentation/openinference-instrumentation-openai/examples/chat_completions.py b/python/instrumentation/openinference-instrumentation-openai/examples/chat_completions.py new file mode 100644 index 000000000..39c845f70 --- /dev/null +++ b/python/instrumentation/openinference-instrumentation-openai/examples/chat_completions.py @@ -0,0 +1,25 @@ +import openai +from openinference.instrumentation.openai import OpenAIInstrumentor +from opentelemetry import trace as trace_api +from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter +from opentelemetry.sdk import trace as trace_sdk +from opentelemetry.sdk.resources import Resource +from opentelemetry.sdk.trace.export import SimpleSpanProcessor + +resource = Resource(attributes={}) +tracer_provider = trace_sdk.TracerProvider(resource=resource) +span_exporter = OTLPSpanExporter(endpoint="http://127.0.0.1:6006/v1/traces") +span_processor = SimpleSpanProcessor(span_exporter=span_exporter) +tracer_provider.add_span_processor(span_processor=span_processor) +trace_api.set_tracer_provider(tracer_provider=tracer_provider) + +OpenAIInstrumentor().instrument() + + +if __name__ == "__main__": + response = openai.OpenAI().chat.completions.create( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "Write a haiku."}], + max_tokens=20, + ) + print(response.choices[0].message.content) diff --git a/python/instrumentation/openinference-instrumentation-openai/examples/chat_completions_async_stream.py b/python/instrumentation/openinference-instrumentation-openai/examples/chat_completions_async_stream.py new file mode 100644 index 000000000..1216e9ff3 --- /dev/null +++ 
b/python/instrumentation/openinference-instrumentation-openai/examples/chat_completions_async_stream.py @@ -0,0 +1,37 @@ +import asyncio + +import openai +from openinference.instrumentation.openai import OpenAIInstrumentor +from opentelemetry import trace as trace_api +from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter +from opentelemetry.sdk import trace as trace_sdk +from opentelemetry.sdk.resources import Resource +from opentelemetry.sdk.trace.export import SimpleSpanProcessor + +resource = Resource(attributes={}) +tracer_provider = trace_sdk.TracerProvider(resource=resource) +span_exporter = OTLPSpanExporter(endpoint="http://127.0.0.1:6006/v1/traces") +span_processor = SimpleSpanProcessor(span_exporter=span_exporter) +tracer_provider.add_span_processor(span_processor=span_processor) +trace_api.set_tracer_provider(tracer_provider=tracer_provider) + +OpenAIInstrumentor().instrument() + + +async def chat_completions(**kwargs): + client = openai.AsyncOpenAI() + async for chunk in await client.chat.completions.create(**kwargs): + if content := chunk.choices[0].delta.content: + print(content, end="") + print() + + +if __name__ == "__main__": + asyncio.run( + chat_completions( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "Write a haiku."}], + max_tokens=20, + stream=True, + ), + ) diff --git a/python/instrumentation/openinference-instrumentation-openai/examples/embeddings.py b/python/instrumentation/openinference-instrumentation-openai/examples/embeddings.py new file mode 100644 index 000000000..812d721e0 --- /dev/null +++ b/python/instrumentation/openinference-instrumentation-openai/examples/embeddings.py @@ -0,0 +1,24 @@ +import openai +from openinference.instrumentation.openai import OpenAIInstrumentor +from opentelemetry import trace as trace_api +from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter +from opentelemetry.sdk import trace as trace_sdk +from opentelemetry.sdk.resources import Resource +from opentelemetry.sdk.trace.export import SimpleSpanProcessor + +resource = Resource(attributes={}) +tracer_provider = trace_sdk.TracerProvider(resource=resource) +span_exporter = OTLPSpanExporter(endpoint="http://127.0.0.1:6006/v1/traces") +span_processor = SimpleSpanProcessor(span_exporter=span_exporter) +tracer_provider.add_span_processor(span_processor=span_processor) +trace_api.set_tracer_provider(tracer_provider=tracer_provider) + +OpenAIInstrumentor().instrument() + + +if __name__ == "__main__": + response = openai.OpenAI().embeddings.create( + model="text-embedding-ada-002", + input="hello world", + ) + print(response.data[0].embedding) diff --git a/python/instrumentation/openinference-instrumentation-openai/examples/with_httpx_instrumentor.py b/python/instrumentation/openinference-instrumentation-openai/examples/with_httpx_instrumentor.py new file mode 100644 index 000000000..9ba20b9fd --- /dev/null +++ b/python/instrumentation/openinference-instrumentation-openai/examples/with_httpx_instrumentor.py @@ -0,0 +1,29 @@ +from importlib import import_module + +from openinference.instrumentation.openai import OpenAIInstrumentor +from opentelemetry import trace as trace_api +from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter +from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor +from opentelemetry.sdk import trace as trace_sdk +from opentelemetry.sdk.resources import Resource +from opentelemetry.sdk.trace.export import SimpleSpanProcessor + +resource = 
Resource(attributes={}) +tracer_provider = trace_sdk.TracerProvider(resource=resource) +span_exporter = OTLPSpanExporter(endpoint="http://127.0.0.1:6006/v1/traces") +span_processor = SimpleSpanProcessor(span_exporter=span_exporter) +tracer_provider.add_span_processor(span_processor=span_processor) +trace_api.set_tracer_provider(tracer_provider=tracer_provider) + +HTTPXClientInstrumentor().instrument() +OpenAIInstrumentor().instrument() + + +if __name__ == "__main__": + openai = import_module("openai") + response = openai.OpenAI().chat.completions.create( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "Write a haiku."}], + max_tokens=20, + ) + print(response.choices[0].message.content) diff --git a/python/instrumentation/openinference-instrumentation-openai/pyproject.toml b/python/instrumentation/openinference-instrumentation-openai/pyproject.toml index 63784da91..33add988d 100644 --- a/python/instrumentation/openinference-instrumentation-openai/pyproject.toml +++ b/python/instrumentation/openinference-instrumentation-openai/pyproject.toml @@ -10,30 +10,37 @@ readme = "README.rst" license = "Apache-2.0" requires-python = ">=3.8, <3.12" authors = [ - { name = "OpenInference Authors", email = "oss@arize.com" }, + { name = "OpenInference Authors", email = "oss@arize.com" }, ] classifiers = [ - "Development Status :: 4 - Beta", - "Intended Audience :: Developers", - "License :: OSI Approved :: Apache Software License", - "Programming Language :: Python", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", ] dependencies = [ - "opentelemetry-api", - "opentelemetry-instrumentation", - "opentelemetry-semantic-conventions", - "openinference-semantic-conventions", - "wrapt", + "opentelemetry-api", + "opentelemetry-instrumentation", + "opentelemetry-semantic-conventions", + "openinference-semantic-conventions", + "wrapt", ] [project.optional-dependencies] +instruments = [ + "openai >= 1.0.0", +] test = [ - "openai == 1.0.0", + "openai == 1.0.0", + "opentelemetry-sdk", + "opentelemetry-instrumentation-httpx", + "respx", + "numpy", ] [project.urls] @@ -44,8 +51,8 @@ path = "src/openinference/instrumentation/openai/version.py" [tool.hatch.build.targets.sdist] include = [ - "/src", - "/tests", + "/src", + "/tests", ] [tool.hatch.build.targets.wheel] diff --git a/python/instrumentation/openinference-instrumentation-openai/src/openinference/instrumentation/openai/__init__.py b/python/instrumentation/openinference-instrumentation-openai/src/openinference/instrumentation/openai/__init__.py index e69de29bb..e3b4faf8c 100644 --- a/python/instrumentation/openinference-instrumentation-openai/src/openinference/instrumentation/openai/__init__.py +++ b/python/instrumentation/openinference-instrumentation-openai/src/openinference/instrumentation/openai/__init__.py @@ -0,0 +1,55 @@ +import logging +from importlib import import_module +from typing import Any, Collection + +from openinference.instrumentation.openai._request import ( + _AsyncRequest, + _Request, +) 
+from openinference.instrumentation.openai.package import _instruments +from openinference.instrumentation.openai.version import __version__ +from opentelemetry import trace as trace_api +from opentelemetry.instrumentation.instrumentor import BaseInstrumentor # type: ignore +from wrapt import wrap_function_wrapper + +logger = logging.getLogger(__name__) +logger.addHandler(logging.NullHandler()) + +_MODULE = "openai" + + +class OpenAIInstrumentor(BaseInstrumentor): # type: ignore + """ + An instrumentor for openai + """ + + __slots__ = ( + "_original_request", + "_original_async_request", + ) + + def instrumentation_dependencies(self) -> Collection[str]: + return _instruments + + def _instrument(self, **kwargs: Any) -> None: + if not (tracer_provider := kwargs.get("tracer_provider")): + tracer_provider = trace_api.get_tracer_provider() + tracer = trace_api.get_tracer(__name__, __version__, tracer_provider) + openai = import_module(_MODULE) + self._original_request = openai.OpenAI.request + self._original_async_request = openai.AsyncOpenAI.request + wrap_function_wrapper( + module=_MODULE, + name="OpenAI.request", + wrapper=_Request(tracer=tracer, openai=openai), + ) + wrap_function_wrapper( + module=_MODULE, + name="AsyncOpenAI.request", + wrapper=_AsyncRequest(tracer=tracer, openai=openai), + ) + + def _uninstrument(self, **kwargs: Any) -> None: + openai = import_module(_MODULE) + openai.OpenAI.request = self._original_request + openai.AsyncOpenAI.request = self._original_async_request diff --git a/python/instrumentation/openinference-instrumentation-openai/src/openinference/instrumentation/openai/_request.py b/python/instrumentation/openinference-instrumentation-openai/src/openinference/instrumentation/openai/_request.py new file mode 100644 index 000000000..7e16a6ae1 --- /dev/null +++ b/python/instrumentation/openinference-instrumentation-openai/src/openinference/instrumentation/openai/_request.py @@ -0,0 +1,379 @@ +import logging +from abc import ABC +from contextlib import contextmanager +from types import ModuleType +from typing import ( + Any, + Awaitable, + Callable, + Iterable, + Iterator, + Mapping, + Tuple, +) + +from openinference.instrumentation.openai._request_attributes_extractor import ( + _RequestAttributesExtractor, +) +from openinference.instrumentation.openai._response_accumulator import ( + _ChatCompletionAccumulator, + _CompletionAccumulator, +) +from openinference.instrumentation.openai._response_attributes_extractor import ( + _ResponseAttributesExtractor, +) +from openinference.instrumentation.openai._stream import ( + _ResponseAccumulator, + _Stream, +) +from openinference.instrumentation.openai._utils import ( + _as_input_attributes, + _as_output_attributes, + _finish_tracing, + _io_value_and_type, +) +from openinference.instrumentation.openai._with_span import _WithSpan +from openinference.semconv.trace import OpenInferenceSpanKindValues, SpanAttributes +from opentelemetry import context as context_api +from opentelemetry import trace as trace_api +from opentelemetry.context import _SUPPRESS_INSTRUMENTATION_KEY +from opentelemetry.trace import INVALID_SPAN +from opentelemetry.util.types import AttributeValue +from typing_extensions import TypeAlias + +__all__ = ( + "_Request", + "_AsyncRequest", +) + +logger = logging.getLogger(__name__) +logger.addHandler(logging.NullHandler()) + + +class _WithTracer(ABC): + def __init__(self, tracer: trace_api.Tracer, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self._tracer = tracer + + 
@contextmanager + def _start_as_current_span( + self, + span_name: str, + attributes: Iterable[Tuple[str, AttributeValue]], + extra_attributes: Iterable[Tuple[str, AttributeValue]], + ) -> Iterator[_WithSpan]: + # Because OTEL has a default limit of 128 attributes, we split our attributes into + # two tiers, where the addition of "extra_attributes" is deferred until the end + # and only after the "attributes" are added. + try: + span = self._tracer.start_span(name=span_name, attributes=dict(attributes)) + except Exception: + logger.exception("Failed to start span") + span = INVALID_SPAN + with trace_api.use_span( + span, + end_on_exit=False, + record_exception=False, + set_status_on_exception=False, + ) as span: + yield _WithSpan(span=span, extra_attributes=dict(extra_attributes)) + + +_RequestParameters: TypeAlias = Mapping[str, Any] + + +class _WithOpenAI(ABC): + __slots__ = ( + "_openai", + "_stream_types", + "_request_attributes_extractor", + "_response_attributes_extractor", + "_response_accumulator_factories", + ) + + def __init__(self, openai: ModuleType, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self._openai = openai + self._stream_types = (openai.Stream, openai.AsyncStream) + self._request_attributes_extractor = _RequestAttributesExtractor(openai=openai) + self._response_attributes_extractor = _ResponseAttributesExtractor(openai=openai) + self._response_accumulator_factories: Mapping[ + type, Callable[[_RequestParameters], _ResponseAccumulator] + ] = { + openai.types.Completion: lambda request_parameters: _CompletionAccumulator( + request_parameters=request_parameters, + completion_type=openai.types.Completion, + response_attributes_extractor=self._response_attributes_extractor, + ), + openai.types.chat.ChatCompletion: lambda request_parameters: _ChatCompletionAccumulator( + request_parameters=request_parameters, + chat_completion_type=openai.types.chat.ChatCompletion, + response_attributes_extractor=self._response_attributes_extractor, + ), + } + + def _get_span_kind(self, cast_to: type) -> str: + return ( + OpenInferenceSpanKindValues.EMBEDDING.value + if cast_to is self._openai.types.CreateEmbeddingResponse + else OpenInferenceSpanKindValues.LLM.value + ) + + def _get_attributes_from_request( + self, + cast_to: type, + request_parameters: Mapping[str, Any], + ) -> Iterator[Tuple[str, AttributeValue]]: + yield SpanAttributes.OPENINFERENCE_SPAN_KIND, self._get_span_kind(cast_to=cast_to) + try: + yield from _as_input_attributes(_io_value_and_type(request_parameters)) + except Exception: + logger.exception( + f"Failed to get input attributes from request parameters of " + f"type {type(request_parameters)}" + ) + + def _get_extra_attributes_from_request( + self, + cast_to: type, + request_parameters: Mapping[str, Any], + ) -> Iterator[Tuple[str, AttributeValue]]: + # Secondary attributes should be added after input and output to ensure + # that input and output are not dropped if there are too many attributes. 
+ try: + yield from self._request_attributes_extractor.get_attributes_from_request( + cast_to=cast_to, + request_parameters=request_parameters, + ) + except Exception: + logger.exception( + f"Failed to get extra attributes from request options of " + f"type {type(request_parameters)}" + ) + + def _is_streaming(self, response: Any) -> bool: + return isinstance(response, self._stream_types) + + def _finalize_response( + self, + response: Any, + with_span: _WithSpan, + cast_to: type, + request_parameters: Mapping[str, Any], + ) -> Any: + """ + Monkey-patch the response object to trace the stream, or finish tracing if the response is + not a stream. + """ + + if hasattr(response, "parse") and callable(response.parse): + # `.request()` may be called under `.with_raw_response` and it's necessary to call + # `.parse()` to get back the usual response types. + # E.g. see https://github.com/openai/openai-python/blob/f1c7d714914e3321ca2e72839fe2d132a8646e7f/src/openai/_base_client.py#L518 # noqa: E501 + try: + response.parse() + except Exception: + logger.exception(f"Failed to parse response of type {type(response)}") + if ( + self._is_streaming(response) + or hasattr( + # FIXME: Ideally we should not rely on a private attribute (but it may be impossible). + # The assumption here is that calling `.parse()` stores the stream object in `._parsed` + # and calling `.parse()` again will not overwrite the monkey-patched version. + # See https://github.com/openai/openai-python/blob/f1c7d714914e3321ca2e72839fe2d132a8646e7f/src/openai/_response.py#L65 # noqa: E501 + response, + "_parsed", + ) + # Note that we must have called `.parse()` beforehand, otherwise `._parsed` is None. + and self._is_streaming(response._parsed) + ): + # For streaming, we need an (optional) accumulator to process each chunk iteration. + try: + response_accumulator_factory = self._response_accumulator_factories.get(cast_to) + response_accumulator = ( + response_accumulator_factory(request_parameters) + if response_accumulator_factory + else None + ) + except Exception: + # Note that cast_to may not be hashable. + logger.exception(f"Failed to get response accumulator for {cast_to}") + response_accumulator = None + if hasattr(response, "_parsed") and self._is_streaming(parsed := response._parsed): + # Monkey-patch a private attribute assumed to be caching the output of `.parse()`. + response._parsed = _Stream( + stream=parsed, + with_span=with_span, + response_accumulator=response_accumulator, + ) + return response + return _Stream( + stream=response, + with_span=with_span, + response_accumulator=response_accumulator, + ) + _finish_tracing( + status_code=trace_api.StatusCode.OK, + with_span=with_span, + has_attributes=_ResponseAttributes( + request_parameters=request_parameters, + response=response, + response_attributes_extractor=self._response_attributes_extractor, + ), + ) + return response + + +class _Request(_WithTracer, _WithOpenAI): + def __call__( + self, + wrapped: Callable[..., Any], + instance: Any, + args: Tuple[type, Any], + kwargs: Mapping[str, Any], + ) -> Any: + if context_api.get_value(_SUPPRESS_INSTRUMENTATION_KEY): + return wrapped(*args, **kwargs) + try: + cast_to, request_parameters = _parse_request_args(args) + # E.g. 
cast_to = openai.types.chat.ChatCompletion => span_name = "ChatCompletion" + span_name: str = cast_to.__name__.split(".")[-1] + except Exception: + logger.exception("Failed to parse request args") + return wrapped(*args, **kwargs) + with self._start_as_current_span( + span_name=span_name, + attributes=self._get_attributes_from_request( + cast_to=cast_to, + request_parameters=request_parameters, + ), + extra_attributes=self._get_extra_attributes_from_request( + cast_to=cast_to, + request_parameters=request_parameters, + ), + ) as with_span: + try: + response = wrapped(*args, **kwargs) + except Exception as exception: + status_code = trace_api.StatusCode.ERROR + with_span.record_exception(exception) + with_span.finish_tracing(status_code=status_code) + raise + try: + response = self._finalize_response( + response=response, + with_span=with_span, + cast_to=cast_to, + request_parameters=request_parameters, + ) + except Exception: + logger.exception(f"Failed to finalize response of type {type(response)}") + with_span.finish_tracing(status_code=None) + return response + + +class _AsyncRequest(_WithTracer, _WithOpenAI): + async def __call__( + self, + wrapped: Callable[..., Awaitable[Any]], + instance: Any, + args: Tuple[type, Any], + kwargs: Mapping[str, Any], + ) -> Any: + if context_api.get_value(_SUPPRESS_INSTRUMENTATION_KEY): + return await wrapped(*args, **kwargs) + try: + cast_to, request_parameters = _parse_request_args(args) + # E.g. cast_to = openai.types.chat.ChatCompletion => span_name = "ChatCompletion" + span_name: str = cast_to.__name__.split(".")[-1] + except Exception: + logger.exception("Failed to parse request args") + return await wrapped(*args, **kwargs) + with self._start_as_current_span( + span_name=span_name, + attributes=self._get_attributes_from_request( + cast_to=cast_to, + request_parameters=request_parameters, + ), + extra_attributes=self._get_extra_attributes_from_request( + cast_to=cast_to, + request_parameters=request_parameters, + ), + ) as with_span: + try: + response = await wrapped(*args, **kwargs) + except Exception as exception: + status_code = trace_api.StatusCode.ERROR + with_span.record_exception(exception) + with_span.finish_tracing(status_code=status_code) + raise + try: + response = self._finalize_response( + response=response, + with_span=with_span, + cast_to=cast_to, + request_parameters=request_parameters, + ) + except Exception: + logger.exception(f"Failed to finalize response of type {type(response)}") + with_span.finish_tracing(status_code=None) + return response + + +def _parse_request_args(args: Tuple[type, Any]) -> Tuple[type, Mapping[str, Any]]: + # We don't use `signature(request).bind()` because `request` could have been monkey-patched + # (incorrectly) by others and the signature at runtime may not match the original. + # The targeted signature of `request` is here: + # https://github.com/openai/openai-python/blob/f1c7d714914e3321ca2e72839fe2d132a8646e7f/src/openai/_base_client.py#L846-L847 # noqa: E501 + cast_to: type = args[0] + request_parameters: Mapping[str, Any] = ( + json_data + # See https://github.com/openai/openai-python/blob/f1c7d714914e3321ca2e72839fe2d132a8646e7f/src/openai/_models.py#L427 # noqa: E501 + if hasattr(args[1], "json_data") and isinstance(json_data := args[1].json_data, Mapping) + else {} + ) + # FIXME: Because request parameters is just a Mapping, it can contain any value as long as it + # serializes correctly in an HTTP request body. 
For example, Enum values may be present if a + # third-party library puts them there. Enums can turn into their intended string values via + # `json.dumps` when the final HTTP request body is serialized, but can pose problems when we + # try to extract attributes. However, this round-trip seems expensive, so we opted to treat + # only the Enums that we know about: e.g. message role sometimes can be an Enum, so we will + # convert it only when it's encountered. + # try: + # request_parameters = json.loads(json.dumps(request_parameters)) + # except Exception: + # pass + return cast_to, request_parameters + + +class _ResponseAttributes: + __slots__ = ( + "_response", + "_request_parameters", + "_response_attributes_extractor", + ) + + def __init__( + self, + response: Any, + request_parameters: Mapping[str, Any], + response_attributes_extractor: _ResponseAttributesExtractor, + ) -> None: + if hasattr(response, "parse") and callable(response.parse): + # E.g. see https://github.com/openai/openai-python/blob/f1c7d714914e3321ca2e72839fe2d132a8646e7f/src/openai/_base_client.py#L518 # noqa: E501 + try: + response = response.parse() + except Exception: + logger.exception(f"Failed to parse response of type {type(response)}") + self._request_parameters = request_parameters + self._response = response + self._response_attributes_extractor = response_attributes_extractor + + def get_attributes(self) -> Iterator[Tuple[str, AttributeValue]]: + yield from _as_output_attributes(_io_value_and_type(self._response)) + + def get_extra_attributes(self) -> Iterator[Tuple[str, AttributeValue]]: + yield from self._response_attributes_extractor.get_attributes_from_response( + response=self._response, + request_parameters=self._request_parameters, + ) diff --git a/python/instrumentation/openinference-instrumentation-openai/src/openinference/instrumentation/openai/_request_attributes_extractor.py b/python/instrumentation/openinference-instrumentation-openai/src/openinference/instrumentation/openai/_request_attributes_extractor.py new file mode 100644 index 000000000..324346160 --- /dev/null +++ b/python/instrumentation/openinference-instrumentation-openai/src/openinference/instrumentation/openai/_request_attributes_extractor.py @@ -0,0 +1,166 @@ +import json +import logging +from enum import Enum +from types import ModuleType +from typing import ( + TYPE_CHECKING, + Any, + Iterable, + Iterator, + List, + Mapping, + Tuple, + Type, +) + +from openinference.instrumentation.openai._utils import _OPENAI_VERSION +from openinference.semconv.trace import MessageAttributes, SpanAttributes, ToolCallAttributes +from opentelemetry.util.types import AttributeValue + +if TYPE_CHECKING: + from openai.types import Completion, CreateEmbeddingResponse + from openai.types.chat import ChatCompletion + +__all__ = ("_RequestAttributesExtractor",) + +logger = logging.getLogger(__name__) +logger.addHandler(logging.NullHandler()) + + +class _RequestAttributesExtractor: + __slots__ = ( + "_openai", + "_chat_completion_type", + "_completion_type", + "_create_embedding_response_type", + ) + + def __init__(self, openai: ModuleType) -> None: + self._openai = openai + self._chat_completion_type: Type["ChatCompletion"] = openai.types.chat.ChatCompletion + self._completion_type: Type["Completion"] = openai.types.Completion + self._create_embedding_response_type: Type[ + "CreateEmbeddingResponse" + ] = openai.types.CreateEmbeddingResponse + + def get_attributes_from_request( + self, + cast_to: type, + request_parameters: Mapping[str, Any], + ) -> 
Iterator[Tuple[str, AttributeValue]]: + if not isinstance(request_parameters, Mapping): + return + if cast_to is self._chat_completion_type: + yield from _get_attributes_from_chat_completion_create_param(request_parameters) + elif cast_to is self._create_embedding_response_type: + yield from _get_attributes_from_embedding_create_param(request_parameters) + elif cast_to is self._completion_type: + yield from _get_attributes_from_completion_create_param(request_parameters) + else: + try: + yield SpanAttributes.LLM_INVOCATION_PARAMETERS, json.dumps(request_parameters) + except Exception: + logger.exception("Failed to serialize request options") + + +def _get_attributes_from_chat_completion_create_param( + params: Mapping[str, Any], +) -> Iterator[Tuple[str, AttributeValue]]: + # openai.types.chat.completion_create_params.CompletionCreateParamsBase + # See https://github.com/openai/openai-python/blob/f1c7d714914e3321ca2e72839fe2d132a8646e7f/src/openai/types/chat/completion_create_params.py#L28 # noqa: E501 + if not isinstance(params, Mapping): + return + invocation_params = dict(params) + invocation_params.pop("messages", None) + invocation_params.pop("functions", None) + invocation_params.pop("tools", None) + yield SpanAttributes.LLM_INVOCATION_PARAMETERS, json.dumps(invocation_params) + if (input_messages := params.get("messages")) and isinstance(input_messages, Iterable): + # Use reversed() to get the last message first. This is because OTEL has a default limit of + # 128 attributes per span, and flattening increases the number of attributes very quickly. + for index, input_message in reversed(list(enumerate(input_messages))): + for key, value in _get_attributes_from_message_param(input_message): + yield f"{SpanAttributes.LLM_INPUT_MESSAGES}.{index}.{key}", value + + +def _get_attributes_from_message_param( + message: Mapping[str, Any], +) -> Iterator[Tuple[str, AttributeValue]]: + # openai.types.chat.ChatCompletionMessageParam + # See https://github.com/openai/openai-python/blob/f1c7d714914e3321ca2e72839fe2d132a8646e7f/src/openai/types/chat/chat_completion_message_param.py#L15 # noqa: E501 + if not hasattr(message, "get"): + return + if role := message.get("role"): + yield ( + MessageAttributes.MESSAGE_ROLE, + role.value if isinstance(role, Enum) else role, + ) + if content := message.get("content"): + if isinstance(content, str): + yield MessageAttributes.MESSAGE_CONTENT, content + elif isinstance(content, List): + # See https://github.com/openai/openai-python/blob/f1c7d714914e3321ca2e72839fe2d132a8646e7f/src/openai/types/chat/chat_completion_user_message_param.py#L14 # noqa: E501 + try: + json_string = json.dumps(content) + except Exception: + logger.exception("Failed to serialize message content") + else: + yield MessageAttributes.MESSAGE_CONTENT, json_string + if name := message.get("name"): + yield MessageAttributes.MESSAGE_NAME, name + if (function_call := message.get("function_call")) and hasattr(function_call, "get"): + # See https://github.com/openai/openai-python/blob/f1c7d714914e3321ca2e72839fe2d132a8646e7f/src/openai/types/chat/chat_completion_assistant_message_param.py#L13 # noqa: E501 + if function_name := function_call.get("name"): + yield MessageAttributes.MESSAGE_FUNCTION_CALL_NAME, function_name + if function_arguments := function_call.get("arguments"): + yield ( + MessageAttributes.MESSAGE_FUNCTION_CALL_ARGUMENTS_JSON, + function_arguments, + ) + if ( + _OPENAI_VERSION >= (1, 1, 0) + and (tool_calls := message.get("tool_calls"),) + and isinstance(tool_calls, Iterable) + 
): + for index, tool_call in enumerate(tool_calls): + # See https://github.com/openai/openai-python/blob/f1c7d714914e3321ca2e72839fe2d132a8646e7f/src/openai/types/chat/chat_completion_message_tool_call_param.py#L23 # noqa: E501 + if not hasattr(tool_call, "get"): + continue + if (function := tool_call.get("function")) and hasattr(function, "get"): + # See https://github.com/openai/openai-python/blob/f1c7d714914e3321ca2e72839fe2d132a8646e7f/src/openai/types/chat/chat_completion_message_tool_call_param.py#L10 # noqa: E501 + if name := function.get("name"): + yield ( + f"{MessageAttributes.MESSAGE_TOOL_CALLS}.{index}." + f"{ToolCallAttributes.TOOL_CALL_FUNCTION_NAME}", + name, + ) + if arguments := function.get("arguments"): + yield ( + f"{MessageAttributes.MESSAGE_TOOL_CALLS}.{index}." + f"{ToolCallAttributes.TOOL_CALL_FUNCTION_ARGUMENTS_JSON}", + arguments, + ) + + +def _get_attributes_from_completion_create_param( + params: Mapping[str, Any], +) -> Iterator[Tuple[str, AttributeValue]]: + # openai.types.completion_create_params.CompletionCreateParamsBase + # See https://github.com/openai/openai-python/blob/f1c7d714914e3321ca2e72839fe2d132a8646e7f/src/openai/types/completion_create_params.py#L11 # noqa: E501 + if not isinstance(params, Mapping): + return + invocation_params = dict(params) + invocation_params.pop("prompt", None) + yield SpanAttributes.LLM_INVOCATION_PARAMETERS, json.dumps(invocation_params) + + +def _get_attributes_from_embedding_create_param( + params: Mapping[str, Any], +) -> Iterator[Tuple[str, AttributeValue]]: + # openai.types.EmbeddingCreateParams + # See https://github.com/openai/openai-python/blob/f1c7d714914e3321ca2e72839fe2d132a8646e7f/src/openai/types/embedding_create_params.py#L11 # noqa: E501 + if not isinstance(params, Mapping): + return + invocation_params = dict(params) + invocation_params.pop("input", None) + yield SpanAttributes.LLM_INVOCATION_PARAMETERS, json.dumps(invocation_params) diff --git a/python/instrumentation/openinference-instrumentation-openai/src/openinference/instrumentation/openai/_response_accumulator.py b/python/instrumentation/openinference-instrumentation-openai/src/openinference/instrumentation/openai/_response_accumulator.py new file mode 100644 index 000000000..04e43f740 --- /dev/null +++ b/python/instrumentation/openinference-instrumentation-openai/src/openinference/instrumentation/openai/_response_accumulator.py @@ -0,0 +1,265 @@ +import json +import warnings +from collections import defaultdict +from copy import deepcopy +from typing import ( + TYPE_CHECKING, + Any, + Callable, + DefaultDict, + Dict, + Iterable, + Iterator, + List, + Mapping, + Optional, + Protocol, + Tuple, + Type, +) + +from openinference.instrumentation.openai._utils import ( + _as_output_attributes, + _ValueAndType, +) +from openinference.semconv.trace import OpenInferenceMimeTypeValues +from opentelemetry.util.types import AttributeValue + +if TYPE_CHECKING: + from openai.types import Completion + from openai.types.chat import ChatCompletion, ChatCompletionChunk + +__all__ = ( + "_CompletionAccumulator", + "_ChatCompletionAccumulator", +) + + +class _CanGetAttributesFromResponse(Protocol): + def get_attributes_from_response( + self, + response: Any, + request_parameters: Mapping[str, Any], + ) -> Iterator[Tuple[str, AttributeValue]]: + ... 
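
The accumulator classes that follow rebuild a complete response from streamed chunks so that output attributes can be attached once the stream finishes. A minimal standalone sketch of the idea (the chunk dicts and the helper name here are made up for illustration, not real ChatCompletionChunk objects):

from collections import defaultdict
from typing import Any, Dict, Iterable, List

def merge_chat_chunks(chunks: Iterable[Dict[str, Any]]) -> Dict[int, str]:
    """Concatenate streamed delta fragments into full message contents, keyed by choice index."""
    fragments: Dict[int, List[str]] = defaultdict(list)
    for chunk in chunks:
        for choice in chunk.get("choices", ()):
            delta = choice.get("delta") or {}
            if content := delta.get("content"):
                fragments[choice["index"]].append(content)
    return {index: "".join(parts) for index, parts in sorted(fragments.items())}

chunks = [
    {"choices": [{"index": 0, "delta": {"role": "assistant", "content": "Hello"}}]},
    {"choices": [{"index": 0, "delta": {"content": ", world"}}]},
]
print(merge_chat_chunks(chunks))  # {0: 'Hello, world'}
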
+ + +class _ChatCompletionAccumulator: + __slots__ = ( + "_is_null", + "_values", + "_cached_result", + "_request_parameters", + "_response_attributes_extractor", + "_chat_completion_type", + ) + + def __init__( + self, + request_parameters: Mapping[str, Any], + chat_completion_type: Type["ChatCompletion"], + response_attributes_extractor: Optional[_CanGetAttributesFromResponse] = None, + ) -> None: + self._chat_completion_type = chat_completion_type + self._request_parameters = request_parameters + self._response_attributes_extractor = response_attributes_extractor + self._is_null = True + self._cached_result: Optional[Dict[str, Any]] = None + self._values = _ValuesAccumulator( + choices=_IndexedAccumulator( + lambda: _ValuesAccumulator( + message=_ValuesAccumulator( + content=_StringAccumulator(), + function_call=_ValuesAccumulator(arguments=_StringAccumulator()), + tool_calls=_IndexedAccumulator( + lambda: _ValuesAccumulator( + function=_ValuesAccumulator(arguments=_StringAccumulator()), + ) + ), + ), + ), + ), + ) + + def process_chunk(self, chunk: "ChatCompletionChunk") -> None: + self._is_null = False + self._cached_result = None + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + # `warnings=False` in `model_dump()` is only supported in Pydantic v2 + values = chunk.model_dump(exclude_unset=True) + for choice in values.get("choices", ()): + if delta := choice.pop("delta", None): + choice["message"] = delta + self._values += values + + def _result(self) -> Optional[Dict[str, Any]]: + if self._is_null: + return None + if not self._cached_result: + self._cached_result = dict(self._values) + return self._cached_result + + def get_attributes(self) -> Iterator[Tuple[str, AttributeValue]]: + if not (result := self._result()): + return + json_string = json.dumps(result) + yield from _as_output_attributes( + _ValueAndType(json_string, OpenInferenceMimeTypeValues.JSON) + ) + + def get_extra_attributes(self) -> Iterator[Tuple[str, AttributeValue]]: + if not (result := self._result()): + return + if self._response_attributes_extractor: + yield from self._response_attributes_extractor.get_attributes_from_response( + self._chat_completion_type.construct(**result), + self._request_parameters, + ) + + +class _CompletionAccumulator: + __slots__ = ( + "_is_null", + "_values", + "_cached_result", + "_request_parameters", + "_response_attributes_extractor", + "_completion_type", + ) + + def __init__( + self, + request_parameters: Mapping[str, Any], + completion_type: Type["Completion"], + response_attributes_extractor: Optional[_CanGetAttributesFromResponse] = None, + ) -> None: + self._completion_type = completion_type + self._request_parameters = request_parameters + self._response_attributes_extractor = response_attributes_extractor + self._is_null = True + self._cached_result: Optional[Dict[str, Any]] = None + self._values = _ValuesAccumulator( + choices=_IndexedAccumulator(lambda: _ValuesAccumulator(text=_StringAccumulator())), + ) + + def process_chunk(self, chunk: "Completion") -> None: + self._is_null = False + self._cached_result = None + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + # `warnings=False` in `model_dump()` is only supported in Pydantic v2 + values = chunk.model_dump(exclude_unset=True) + self._values += values + + def _result(self) -> Optional[Dict[str, Any]]: + if self._is_null: + return None + if not self._cached_result: + self._cached_result = dict(self._values) + return self._cached_result + + def get_attributes(self) -> 
Iterator[Tuple[str, AttributeValue]]: + if not (result := self._result()): + return + json_string = json.dumps(result) + yield from _as_output_attributes( + _ValueAndType(json_string, OpenInferenceMimeTypeValues.JSON) + ) + + def get_extra_attributes(self) -> Iterator[Tuple[str, AttributeValue]]: + if not (result := self._result()): + return + if self._response_attributes_extractor: + yield from self._response_attributes_extractor.get_attributes_from_response( + self._completion_type.construct(**result), + self._request_parameters, + ) + + +class _ValuesAccumulator: + __slots__ = ("_values",) + + def __init__(self, **values: Any) -> None: + self._values: Dict[str, Any] = values + + def __iter__(self) -> Iterator[Tuple[str, Any]]: + for key, value in self._values.items(): + if value is None: + continue + if isinstance(value, _ValuesAccumulator): + if dict_value := dict(value): + yield key, dict_value + elif isinstance(value, _IndexedAccumulator): + if list_value := list(value): + yield key, list_value + elif isinstance(value, _StringAccumulator): + if str_value := str(value): + yield key, str_value + else: + yield key, value + + def __iadd__(self, values: Optional[Mapping[str, Any]]) -> "_ValuesAccumulator": + if not values: + return self + for key in self._values.keys(): + if (value := values.get(key)) is None: + continue + self_value = self._values[key] + if isinstance(self_value, _ValuesAccumulator): + if isinstance(value, Mapping): + self_value += value + elif isinstance(self_value, _StringAccumulator): + if isinstance(value, str): + self_value += value + elif isinstance(self_value, _IndexedAccumulator): + if isinstance(value, Iterable): + for v in value: + self_value += v + else: + self_value += value + elif isinstance(self_value, List) and isinstance(value, Iterable): + self_value.extend(value) + else: + self._values[key] = value # replacement + for key in values.keys(): + if key in self._values or (value := values[key]) is None: + continue + value = deepcopy(value) + if isinstance(value, Mapping): + value = _ValuesAccumulator(**value) + self._values[key] = value # new entry + return self + + +class _StringAccumulator: + __slots__ = ("_fragments",) + + def __init__(self) -> None: + self._fragments: List[str] = [] + + def __str__(self) -> str: + return "".join(self._fragments) + + def __iadd__(self, value: Optional[str]) -> "_StringAccumulator": + if not value: + return self + self._fragments.append(value) + return self + + +class _IndexedAccumulator: + __slots__ = ("_indexed",) + + def __init__(self, factory: Callable[[], _ValuesAccumulator]) -> None: + self._indexed: DefaultDict[int, _ValuesAccumulator] = defaultdict(factory) + + def __iter__(self) -> Iterator[Dict[str, Any]]: + for _, values in sorted(self._indexed.items()): + yield dict(values) + + def __iadd__(self, values: Optional[Mapping[str, Any]]) -> "_IndexedAccumulator": + if not values or not hasattr(values, "get") or (index := values.get("index")) is None: + return self + self._indexed[index] += values + return self diff --git a/python/instrumentation/openinference-instrumentation-openai/src/openinference/instrumentation/openai/_response_attributes_extractor.py b/python/instrumentation/openinference-instrumentation-openai/src/openinference/instrumentation/openai/_response_attributes_extractor.py new file mode 100644 index 000000000..7d07f558a --- /dev/null +++ b/python/instrumentation/openinference-instrumentation-openai/src/openinference/instrumentation/openai/_response_attributes_extractor.py @@ -0,0 +1,237 @@ +import 
base64 +import logging +from importlib import import_module +from types import ModuleType +from typing import ( + TYPE_CHECKING, + Any, + Iterable, + Iterator, + Mapping, + Optional, + Sequence, + Tuple, + Type, +) + +from openinference.instrumentation.openai._utils import ( + _OPENAI_VERSION, + _get_texts, +) +from openinference.semconv.trace import ( + EmbeddingAttributes, + MessageAttributes, + SpanAttributes, + ToolCallAttributes, +) +from opentelemetry.util.types import AttributeValue + +if TYPE_CHECKING: + from openai.types import Completion, CreateEmbeddingResponse + from openai.types.chat import ChatCompletion + +__all__ = ("_ResponseAttributesExtractor",) + +logger = logging.getLogger(__name__) +logger.addHandler(logging.NullHandler()) + + +try: + _NUMPY: Optional[ModuleType] = import_module("numpy") +except ImportError: + _NUMPY = None + + +class _ResponseAttributesExtractor: + __slots__ = ( + "_openai", + "_chat_completion_type", + "_completion_type", + "_create_embedding_response_type", + ) + + def __init__(self, openai: ModuleType) -> None: + self._openai = openai + self._chat_completion_type: Type["ChatCompletion"] = openai.types.chat.ChatCompletion + self._completion_type: Type["Completion"] = openai.types.Completion + self._create_embedding_response_type: Type[ + "CreateEmbeddingResponse" + ] = openai.types.CreateEmbeddingResponse + + def get_attributes_from_response( + self, + response: Any, + request_parameters: Mapping[str, Any], + ) -> Iterator[Tuple[str, AttributeValue]]: + if isinstance(response, self._chat_completion_type): + yield from _get_attributes_from_chat_completion( + completion=response, + request_parameters=request_parameters, + ) + elif isinstance(response, self._create_embedding_response_type): + yield from _get_attributes_from_create_embedding_response( + response=response, + request_parameters=request_parameters, + ) + elif isinstance(response, self._completion_type): + yield from _get_attributes_from_completion( + completion=response, + request_parameters=request_parameters, + ) + else: + yield from () + + +def _get_attributes_from_chat_completion( + completion: "ChatCompletion", + request_parameters: Mapping[str, Any], +) -> Iterator[Tuple[str, AttributeValue]]: + # See https://github.com/openai/openai-python/blob/f1c7d714914e3321ca2e72839fe2d132a8646e7f/src/openai/types/chat/chat_completion.py#L40 # noqa: E501 + if model := getattr(completion, "model", None): + yield SpanAttributes.LLM_MODEL_NAME, model + if usage := getattr(completion, "usage", None): + yield from _get_attributes_from_completion_usage(usage) + if (choices := getattr(completion, "choices", None)) and isinstance(choices, Iterable): + for choice in choices: + if (index := getattr(choice, "index", None)) is None: + continue + if message := getattr(choice, "message", None): + for key, value in _get_attributes_from_chat_completion_message(message): + yield f"{SpanAttributes.LLM_OUTPUT_MESSAGES}.{index}.{key}", value + + +def _get_attributes_from_completion( + completion: "Completion", + request_parameters: Mapping[str, Any], +) -> Iterator[Tuple[str, AttributeValue]]: + # See https://github.com/openai/openai-python/blob/f1c7d714914e3321ca2e72839fe2d132a8646e7f/src/openai/types/completion.py#L13 # noqa: E501 + if model := getattr(completion, "model", None): + yield SpanAttributes.LLM_MODEL_NAME, model + if usage := getattr(completion, "usage", None): + yield from _get_attributes_from_completion_usage(usage) + if model_prompt := request_parameters.get("prompt"): + # FIXME: this step should 
move to request attributes extractor if decoding is not necessary. + # prompt: Required[Union[str, List[str], List[int], List[List[int]], None]] + # See https://github.com/openai/openai-python/blob/f1c7d714914e3321ca2e72839fe2d132a8646e7f/src/openai/types/completion_create_params.py#L38 # noqa: E501 + # FIXME: tokens (List[int], List[List[int]]) can't be decoded reliably because model + # names are not reliable (across OpenAI and Azure). + if prompts := list(_get_texts(model_prompt, model)): + yield SpanAttributes.LLM_PROMPTS, prompts + + +def _get_attributes_from_create_embedding_response( + response: "CreateEmbeddingResponse", + request_parameters: Mapping[str, Any], +) -> Iterator[Tuple[str, AttributeValue]]: + # See https://github.com/openai/openai-python/blob/f1c7d714914e3321ca2e72839fe2d132a8646e7f/src/openai/types/create_embedding_response.py#L20 # noqa: E501 + if usage := getattr(response, "usage", None): + yield from _get_attributes_from_embedding_usage(usage) + # See https://github.com/openai/openai-python/blob/f1c7d714914e3321ca2e72839fe2d132a8646e7f/src/openai/types/embedding_create_params.py#L23 # noqa: E501 + if model := getattr(response, "model"): + yield f"{SpanAttributes.EMBEDDING_MODEL_NAME}", model + if (data := getattr(response, "data", None)) and isinstance(data, Iterable): + for embedding in data: + if (index := getattr(embedding, "index", None)) is None: + continue + for key, value in _get_attributes_from_embedding(embedding): + yield f"{SpanAttributes.EMBEDDING_EMBEDDINGS}.{index}.{key}", value + embedding_input = request_parameters.get("input") + for index, text in enumerate(_get_texts(embedding_input, model)): + # FIXME: this step should move to request attributes extractor if decoding is not necessary. + # input: Required[Union[str, List[str], List[int], List[List[int]]]] + # See https://github.com/openai/openai-python/blob/f1c7d714914e3321ca2e72839fe2d132a8646e7f/src/openai/types/embedding_create_params.py#L12 # noqa: E501 + # FIXME: tokens (List[int], List[List[int]]) can't be decoded reliably because model + # names are not reliable (across OpenAI and Azure). + yield ( + ( + f"{SpanAttributes.EMBEDDING_EMBEDDINGS}.{index}." + f"{EmbeddingAttributes.EMBEDDING_TEXT}" + ), + text, + ) + + +def _get_attributes_from_embedding( + embedding: object, +) -> Iterator[Tuple[str, AttributeValue]]: + # openai.types.Embedding + # See https://github.com/openai/openai-python/blob/f1c7d714914e3321ca2e72839fe2d132a8646e7f/src/openai/types/embedding.py#L11 # noqa: E501 + if not (_vector := getattr(embedding, "embedding", None)): + return + if isinstance(_vector, Sequence) and len(_vector) and isinstance(_vector[0], float): + vector = list(_vector) + yield f"{EmbeddingAttributes.EMBEDDING_VECTOR}", vector + elif isinstance(_vector, str) and _vector and _NUMPY: + # FIXME: this step should be removed if decoding is not necessary. 
+ try: + # See https://github.com/openai/openai-python/blob/f1c7d714914e3321ca2e72839fe2d132a8646e7f/src/openai/resources/embeddings.py#L100 # noqa: E501 + vector = _NUMPY.frombuffer(base64.b64decode(_vector), dtype="float32").tolist() + except Exception: + logger.exception("Failed to decode embedding") + pass + else: + yield f"{EmbeddingAttributes.EMBEDDING_VECTOR}", vector + + +def _get_attributes_from_chat_completion_message( + message: object, +) -> Iterator[Tuple[str, AttributeValue]]: + # openai.types.chat.ChatCompletionMessage + # See https://github.com/openai/openai-python/blob/f1c7d714914e3321ca2e72839fe2d132a8646e7f/src/openai/types/chat/chat_completion_message.py#L25 # noqa: E501 + if role := getattr(message, "role", None): + yield MessageAttributes.MESSAGE_ROLE, role + if content := getattr(message, "content", None): + yield MessageAttributes.MESSAGE_CONTENT, content + if function_call := getattr(message, "function_call", None): + # See https://github.com/openai/openai-python/blob/f1c7d714914e3321ca2e72839fe2d132a8646e7f/src/openai/types/chat/chat_completion_message.py#L12 # noqa: E501 + if name := getattr(function_call, "name", None): + yield MessageAttributes.MESSAGE_FUNCTION_CALL_NAME, name + if arguments := getattr(function_call, "arguments", None): + yield MessageAttributes.MESSAGE_FUNCTION_CALL_ARGUMENTS_JSON, arguments + if ( + _OPENAI_VERSION >= (1, 1, 0) + and (tool_calls := getattr(message, "tool_calls", None)) + and isinstance(tool_calls, Iterable) + ): + # See https://github.com/openai/openai-python/blob/f1c7d714914e3321ca2e72839fe2d132a8646e7f/src/openai/types/chat/chat_completion_message_tool_call.py#L23 # noqa: E501 + for index, tool_call in enumerate(tool_calls): + if function := getattr(tool_call, "function", None): + # See https://github.com/openai/openai-python/blob/f1c7d714914e3321ca2e72839fe2d132a8646e7f/src/openai/types/chat/chat_completion_message_tool_call.py#L10 # noqa: E501 + if name := getattr(function, "name", None): + yield ( + ( + f"{MessageAttributes.MESSAGE_TOOL_CALLS}.{index}." + f"{ToolCallAttributes.TOOL_CALL_FUNCTION_NAME}" + ), + name, + ) + if arguments := getattr(function, "arguments", None): + yield ( + f"{MessageAttributes.MESSAGE_TOOL_CALLS}.{index}." 
+ f"{ToolCallAttributes.TOOL_CALL_FUNCTION_ARGUMENTS_JSON}", + arguments, + ) + + +def _get_attributes_from_completion_usage( + usage: object, +) -> Iterator[Tuple[str, AttributeValue]]: + # openai.types.CompletionUsage + # See https://github.com/openai/openai-python/blob/f1c7d714914e3321ca2e72839fe2d132a8646e7f/src/openai/types/completion_usage.py#L8 # noqa: E501 + if (total_tokens := getattr(usage, "total_tokens", None)) is not None: + yield SpanAttributes.LLM_TOKEN_COUNT_TOTAL, total_tokens + if (prompt_tokens := getattr(usage, "prompt_tokens", None)) is not None: + yield SpanAttributes.LLM_TOKEN_COUNT_PROMPT, prompt_tokens + if (completion_tokens := getattr(usage, "completion_tokens", None)) is not None: + yield SpanAttributes.LLM_TOKEN_COUNT_COMPLETION, completion_tokens + + +def _get_attributes_from_embedding_usage( + usage: object, +) -> Iterator[Tuple[str, AttributeValue]]: + # openai.types.create_embedding_response.Usage + # See https://github.com/openai/openai-python/blob/f1c7d714914e3321ca2e72839fe2d132a8646e7f/src/openai/types/create_embedding_response.py#L12 # noqa: E501 + if (total_tokens := getattr(usage, "total_tokens", None)) is not None: + yield SpanAttributes.LLM_TOKEN_COUNT_TOTAL, total_tokens + if (prompt_tokens := getattr(usage, "prompt_tokens", None)) is not None: + yield SpanAttributes.LLM_TOKEN_COUNT_PROMPT, prompt_tokens diff --git a/python/instrumentation/openinference-instrumentation-openai/src/openinference/instrumentation/openai/_stream.py b/python/instrumentation/openinference-instrumentation-openai/src/openinference/instrumentation/openai/_stream.py new file mode 100644 index 000000000..34576a9c9 --- /dev/null +++ b/python/instrumentation/openinference-instrumentation-openai/src/openinference/instrumentation/openai/_stream.py @@ -0,0 +1,146 @@ +import logging +from typing import ( + TYPE_CHECKING, + Any, + AsyncIterator, + Iterator, + Optional, + Protocol, + Tuple, + Union, +) + +from openinference.instrumentation.openai._utils import _finish_tracing +from openinference.instrumentation.openai._with_span import _WithSpan +from opentelemetry import trace as trace_api +from opentelemetry.util.types import AttributeValue +from wrapt import ObjectProxy + +if TYPE_CHECKING: + from openai import AsyncStream, Stream + +__all__ = ( + "_Stream", + "_ResponseAccumulator", +) + +logger = logging.getLogger(__name__) +logger.addHandler(logging.NullHandler()) + + +class _ResponseAccumulator(Protocol): + def process_chunk(self, chunk: Any) -> None: + ... + + def get_attributes(self) -> Iterator[Tuple[str, AttributeValue]]: + ... + + def get_extra_attributes(self) -> Iterator[Tuple[str, AttributeValue]]: + ... 
+ + +class _Stream(ObjectProxy): # type: ignore + __slots__ = ( + "_self_with_span", + "_self_iteration_count", + "_self_is_finished", + "_self_response_accumulator", + ) + + def __init__( + self, + stream: Union["Stream[Any]", "AsyncStream[Any]"], + with_span: _WithSpan, + response_accumulator: Optional[_ResponseAccumulator] = None, + ) -> None: + super().__init__(stream) + self._self_with_span = with_span + self._self_iteration_count = 0 + self._self_is_finished = with_span.is_finished + self._self_response_accumulator = response_accumulator + + def __iter__(self) -> Iterator[Any]: + return self + + def __next__(self) -> Any: + # pass through mistaken calls + if not hasattr(self.__wrapped__, "__next__"): + self.__wrapped__.__next__() + iteration_is_finished = False + status_code: Optional[trace_api.StatusCode] = None + try: + chunk: Any = self.__wrapped__.__next__() + except Exception as exception: + iteration_is_finished = True + if isinstance(exception, StopIteration): + status_code = trace_api.StatusCode.OK + else: + status_code = trace_api.StatusCode.ERROR + self._self_with_span.record_exception(exception) + raise + else: + self._process_chunk(chunk) + status_code = trace_api.StatusCode.OK + return chunk + finally: + if iteration_is_finished and not self._self_is_finished: + self._finish_tracing(status_code=status_code) + + def __aiter__(self) -> AsyncIterator[Any]: + return self + + async def __anext__(self) -> Any: + # pass through mistaken calls + if not hasattr(self.__wrapped__, "__anext__"): + self.__wrapped__.__anext__() + iteration_is_finished = False + status_code: Optional[trace_api.StatusCode] = None + try: + chunk: Any = await self.__wrapped__.__anext__() + except Exception as exception: + iteration_is_finished = True + if isinstance(exception, StopAsyncIteration): + status_code = trace_api.StatusCode.OK + else: + status_code = trace_api.StatusCode.ERROR + self._self_with_span.record_exception(exception) + raise + else: + self._process_chunk(chunk) + status_code = trace_api.StatusCode.OK + return chunk + finally: + if iteration_is_finished and not self._self_is_finished: + self._finish_tracing(status_code=status_code) + + def _process_chunk(self, chunk: Any) -> None: + if not self._self_iteration_count: + try: + self._self_with_span.add_event("First Token Stream Event") + except Exception: + logger.exception("Failed to add event to span") + self._self_iteration_count += 1 + if self._self_response_accumulator is not None: + try: + self._self_response_accumulator.process_chunk(chunk) + except Exception: + logger.exception("Failed to accumulate response") + + def _finish_tracing( + self, + status_code: Optional[trace_api.StatusCode] = None, + ) -> None: + _finish_tracing( + status_code=status_code, + with_span=self._self_with_span, + has_attributes=self, + ) + self._self_is_finished = True + + def get_attributes(self) -> Iterator[Tuple[str, AttributeValue]]: + if self._self_response_accumulator is not None: + yield from self._self_response_accumulator.get_attributes() + + def get_extra_attributes(self) -> Iterator[Tuple[str, AttributeValue]]: + if self._self_response_accumulator is not None: + yield from self._self_response_accumulator.get_extra_attributes() diff --git a/python/instrumentation/openinference-instrumentation-openai/src/openinference/instrumentation/openai/_utils.py b/python/instrumentation/openinference-instrumentation-openai/src/openinference/instrumentation/openai/_utils.py new file mode 100644 index 000000000..ab8ddb803 --- /dev/null +++ 
b/python/instrumentation/openinference-instrumentation-openai/src/openinference/instrumentation/openai/_utils.py @@ -0,0 +1,132 @@ +import json +import logging +import warnings +from importlib.metadata import version +from typing import ( + Any, + Iterator, + List, + Mapping, + NamedTuple, + Optional, + Protocol, + Sequence, + Tuple, + Union, + cast, +) + +from openinference.instrumentation.openai._with_span import _WithSpan +from openinference.semconv.trace import OpenInferenceMimeTypeValues, SpanAttributes +from opentelemetry import trace as trace_api +from opentelemetry.util.types import Attributes, AttributeValue + +logger = logging.getLogger(__name__) +logger.addHandler(logging.NullHandler()) + +_OPENAI_VERSION = tuple(map(int, version("openai").split(".")[:3])) + + +class _ValueAndType(NamedTuple): + value: str + type: OpenInferenceMimeTypeValues + + +def _io_value_and_type(obj: Any) -> _ValueAndType: + if hasattr(obj, "model_dump_json") and callable(obj.model_dump_json): + try: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + # `warnings=False` in `model_dump_json()` is only supported in Pydantic v2 + value = obj.model_dump_json(exclude_unset=True) + assert isinstance(value, str) + except Exception: + logger.exception("Failed to get model dump json") + else: + return _ValueAndType(value, OpenInferenceMimeTypeValues.JSON) + if not isinstance(obj, str) and isinstance(obj, (Sequence, Mapping)): + try: + value = json.dumps(obj) + except Exception: + logger.exception("Failed to dump json") + else: + return _ValueAndType(value, OpenInferenceMimeTypeValues.JSON) + return _ValueAndType(str(obj), OpenInferenceMimeTypeValues.TEXT) + + +def _as_input_attributes( + value_and_type: Optional[_ValueAndType], +) -> Iterator[Tuple[str, AttributeValue]]: + if not value_and_type: + return + yield SpanAttributes.INPUT_VALUE, value_and_type.value + # It's assumed to be TEXT by default, so we can skip to save one attribute. + if value_and_type.type is not OpenInferenceMimeTypeValues.TEXT: + yield SpanAttributes.INPUT_MIME_TYPE, value_and_type.type.value + + +def _as_output_attributes( + value_and_type: Optional[_ValueAndType], +) -> Iterator[Tuple[str, AttributeValue]]: + if not value_and_type: + return + yield SpanAttributes.OUTPUT_VALUE, value_and_type.value + # It's assumed to be TEXT by default, so we can skip to save one attribute. + if value_and_type.type is not OpenInferenceMimeTypeValues.TEXT: + yield SpanAttributes.OUTPUT_MIME_TYPE, value_and_type.type.value + + +class _HasAttributes(Protocol): + def get_attributes(self) -> Iterator[Tuple[str, AttributeValue]]: + ... + + def get_extra_attributes(self) -> Iterator[Tuple[str, AttributeValue]]: + ... 
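
`_io_value_and_type` above serializes mappings and sequences to JSON and falls back to plain text, and the `_as_*_attributes` helpers only emit a mime-type attribute for JSON because TEXT is the assumed default. A simplified sketch of that behavior for plain Python values (the mime-type strings here are illustrative; the real code uses the OpenInferenceMimeTypeValues enum):

import json
from typing import Any, Tuple

def io_value_and_type(obj: Any) -> Tuple[str, str]:
    """JSON for mappings and sequences (except strings), plain text for everything else."""
    if not isinstance(obj, str) and isinstance(obj, (dict, list, tuple)):
        return json.dumps(obj), "application/json"
    return str(obj), "text/plain"

print(io_value_and_type({"messages": [{"role": "user", "content": "hi"}]}))
print(io_value_and_type("Write a haiku."))
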
+ + +def _finish_tracing( + with_span: _WithSpan, + has_attributes: _HasAttributes, + status_code: Optional[trace_api.StatusCode] = None, +) -> None: + try: + attributes: Attributes = dict(has_attributes.get_attributes()) + except Exception: + logger.exception("Failed to get attributes") + attributes = None + try: + extra_attributes: Attributes = dict(has_attributes.get_extra_attributes()) + except Exception: + logger.exception("Failed to get extra attributes") + extra_attributes = None + try: + with_span.finish_tracing( + status_code=status_code, + attributes=attributes, + extra_attributes=extra_attributes, + ) + except Exception: + logger.exception("Failed to finish tracing") + + +def _get_texts( + model_input: Optional[Union[str, List[str], List[int], List[List[int]]]], + model: Optional[str], +) -> Iterator[str]: + if not model_input: + return + if isinstance(model_input, str): + text = model_input + yield text + return + if not isinstance(model_input, Sequence): + return + if any(not isinstance(item, str) for item in model_input): + # FIXME: We can't decode tokens (List[int]) reliably because the model name is not reliable, + # e.g. for text-embedding-ada-002 (cl100k_base), OpenAI returns "text-embedding-ada-002-v2", + # and Azure returns "ada", which refers to a different model (r50k_base). We could use the + # request model name instead, but that doesn't work for Azure because Azure uses the + # deployment name (which differs from the model name). + return + for text in cast(List[str], model_input): + yield text diff --git a/python/instrumentation/openinference-instrumentation-openai/src/openinference/instrumentation/openai/_with_span.py b/python/instrumentation/openinference-instrumentation-openai/src/openinference/instrumentation/openai/_with_span.py new file mode 100644 index 000000000..e09e18830 --- /dev/null +++ b/python/instrumentation/openinference-instrumentation-openai/src/openinference/instrumentation/openai/_with_span.py @@ -0,0 +1,82 @@ +import logging +from typing import Optional + +from opentelemetry import trace as trace_api +from opentelemetry.util.types import Attributes + +logger = logging.getLogger(__name__) +logger.addHandler(logging.NullHandler()) + + +class _WithSpan: + __slots__ = ( + "_span", + "_extra_attributes", + "_is_finished", + ) + + def __init__( + self, + span: trace_api.Span, + extra_attributes: Attributes = None, + ) -> None: + self._span = span + self._extra_attributes = extra_attributes + try: + self._is_finished = not self._span.is_recording() + except Exception: + logger.exception("Failed to check if span is recording") + self._is_finished = True + + @property + def is_finished(self) -> bool: + return self._is_finished + + def record_exception(self, exception: Exception) -> None: + if self._is_finished: + return + try: + self._span.record_exception(exception) + except Exception: + logger.exception("Failed to record exception on span") + + def add_event(self, name: str) -> None: + if self._is_finished: + return + try: + self._span.add_event(name) + except Exception: + logger.exception("Failed to add event to span") + + def finish_tracing( + self, + status_code: Optional[trace_api.StatusCode] = None, + attributes: Attributes = None, + extra_attributes: Attributes = None, + ) -> None: + if self._is_finished: + return + for mapping in ( + attributes, + self._extra_attributes, + extra_attributes, + ): + if not mapping: + continue + for key, value in mapping.items(): + if value is None: + continue + try: + self._span.set_attribute(key, value) + 
except Exception: + logger.exception("Failed to set attribute on span") + if status_code is not None: + try: + self._span.set_status(status_code) + except Exception: + logger.exception("Failed to set status code on span") + try: + self._span.end() + except Exception: + logger.exception("Failed to end span") + self._is_finished = True diff --git a/python/instrumentation/openinference-instrumentation-openai/tests/openinference/instrumentation/openai/test_instrumentor.py b/python/instrumentation/openinference-instrumentation-openai/tests/openinference/instrumentation/openai/test_instrumentor.py new file mode 100644 index 000000000..5d55e44fe --- /dev/null +++ b/python/instrumentation/openinference-instrumentation-openai/tests/openinference/instrumentation/openai/test_instrumentor.py @@ -0,0 +1,661 @@ +import asyncio +import json +import logging +import random +from contextlib import suppress +from importlib import import_module +from importlib.metadata import version +from itertools import count +from typing import ( + Any, + AsyncIterator, + Dict, + Generator, + Iterable, + Iterator, + List, + Mapping, + Sequence, + Tuple, + Union, + cast, +) + +import pytest +from httpx import AsyncByteStream, Response +from openinference.instrumentation.openai import OpenAIInstrumentor +from openinference.semconv.trace import ( + EmbeddingAttributes, + MessageAttributes, + OpenInferenceMimeTypeValues, + OpenInferenceSpanKindValues, + SpanAttributes, + ToolCallAttributes, +) +from opentelemetry import trace as trace_api +from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor +from opentelemetry.sdk import trace as trace_sdk +from opentelemetry.sdk.resources import Resource +from opentelemetry.sdk.trace import ReadableSpan +from opentelemetry.sdk.trace.export import SimpleSpanProcessor +from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter +from opentelemetry.util.types import AttributeValue +from respx import MockRouter + +for name, logger in logging.root.manager.loggerDict.items(): + if name.startswith("openinference.") and isinstance(logger, logging.Logger): + logger.setLevel(logging.DEBUG) + logger.handlers.clear() + logger.addHandler(logging.StreamHandler()) + + +@pytest.mark.parametrize("is_async", [False, True]) +@pytest.mark.parametrize("is_raw", [False, True]) +@pytest.mark.parametrize("is_stream", [False, True]) +@pytest.mark.parametrize("status_code", [200, 400]) +def test_chat_completions( + is_async: bool, + is_raw: bool, + is_stream: bool, + status_code: int, + respx_mock: MockRouter, + in_memory_span_exporter: InMemorySpanExporter, + completion_usage: Dict[str, Any], + model_name: str, + chat_completion_mock_stream: Tuple[List[bytes], List[Dict[str, Any]]], +) -> None: + input_messages: List[Dict[str, Any]] = get_messages() + output_messages: List[Dict[str, Any]] = ( + chat_completion_mock_stream[1] if is_stream else get_messages() + ) + invocation_parameters = { + "stream": is_stream, + "model": randstr(), + "temperature": random.random(), + "n": len(output_messages), + } + url = "https://api.openai.com/v1/chat/completions" + respx_kwargs: Dict[str, Any] = { + **( + {"stream": MockAsyncByteStream(chat_completion_mock_stream[0])} + if is_stream + else { + "json": { + "choices": [ + {"index": i, "message": message, "finish_reason": "stop"} + for i, message in enumerate(output_messages) + ], + "model": model_name, + "usage": completion_usage, + } + } + ), + } + respx_mock.post(url).mock(return_value=Response(status_code=status_code, 
**respx_kwargs)) + create_kwargs = {"messages": input_messages, **invocation_parameters} + openai = import_module("openai") + completions = ( + openai.AsyncOpenAI(api_key="sk-").chat.completions + if is_async + else openai.OpenAI(api_key="sk-").chat.completions + ) + create = completions.with_raw_response.create if is_raw else completions.create + with suppress(openai.BadRequestError): + if is_async: + + async def task() -> None: + response = await create(**create_kwargs) + response = response.parse() if is_raw else response + if is_stream: + async for _ in response: + pass + + asyncio.run(task()) + else: + response = create(**create_kwargs) + response = response.parse() if is_raw else response + if is_stream: + for _ in response: + pass + spans = in_memory_span_exporter.get_finished_spans() + assert len(spans) == 2 # first span should be from the httpx instrumentor + span: ReadableSpan = spans[1] + if status_code == 200: + assert span.status.is_ok + elif status_code == 400: + assert not span.status.is_ok and not span.status.is_unset + assert len(span.events) == 1 + event = span.events[0] + assert event.name == "exception" + attributes = dict(cast(Mapping[str, AttributeValue], span.attributes)) + assert attributes.pop(OPENINFERENCE_SPAN_KIND, None) == OpenInferenceSpanKindValues.LLM.value + assert isinstance(attributes.pop(INPUT_VALUE, None), str) + assert ( + OpenInferenceMimeTypeValues(attributes.pop(INPUT_MIME_TYPE, None)) + == OpenInferenceMimeTypeValues.JSON + ) + assert ( + json.loads(cast(str, attributes.pop(LLM_INVOCATION_PARAMETERS, None))) + == invocation_parameters + ) + for prefix, messages in ( + (LLM_INPUT_MESSAGES, input_messages), + *(((LLM_OUTPUT_MESSAGES, output_messages),) if status_code == 200 else ()), + ): + for i, message in enumerate(messages): + assert attributes.pop(message_role(prefix, i), None) == message.get("role") + assert attributes.pop(message_content(prefix, i), None) == message.get("content") + if function_call := message.get("function_call"): + assert attributes.pop( + message_function_call_name(prefix, i), None + ) == function_call.get("name") + assert attributes.pop( + message_function_call_arguments(prefix, i), None + ) == function_call.get("arguments") + if _openai_version() >= (1, 1, 0) and (tool_calls := message.get("tool_calls")): + for j, tool_call in enumerate(tool_calls): + if function := tool_call.get("function"): + assert attributes.pop( + tool_call_function_name(prefix, i, j), None + ) == function.get("name") + assert attributes.pop( + tool_call_function_arguments(prefix, i, j), None + ) == function.get("arguments") + if status_code == 200: + assert isinstance(attributes.pop(OUTPUT_VALUE, None), str) + assert ( + OpenInferenceMimeTypeValues(attributes.pop(OUTPUT_MIME_TYPE, None)) + == OpenInferenceMimeTypeValues.JSON + ) + if not is_stream: + # Usage is not available for streaming in general. + assert attributes.pop(LLM_TOKEN_COUNT_TOTAL, None) == completion_usage["total_tokens"] + assert attributes.pop(LLM_TOKEN_COUNT_PROMPT, None) == completion_usage["prompt_tokens"] + assert ( + attributes.pop(LLM_TOKEN_COUNT_COMPLETION, None) + == completion_usage["completion_tokens"] + ) + # We left out model_name from our mock stream. 
+ assert attributes.pop(LLM_MODEL_NAME, None) == model_name + assert attributes == {} # test should account for all span attributes + + +@pytest.mark.parametrize("is_async", [False, True]) +@pytest.mark.parametrize("is_raw", [False, True]) +@pytest.mark.parametrize("is_stream", [False, True]) +@pytest.mark.parametrize("status_code", [200, 400]) +def test_completions( + is_async: bool, + is_raw: bool, + is_stream: bool, + status_code: int, + respx_mock: MockRouter, + in_memory_span_exporter: InMemorySpanExporter, + completion_usage: Dict[str, Any], + model_name: str, + completion_mock_stream: Tuple[List[bytes], List[str]], +) -> None: + prompt: List[str] = get_texts() + output_texts: List[str] = completion_mock_stream[1] if is_stream else get_texts() + invocation_parameters = { + "stream": is_stream, + "model": randstr(), + "temperature": random.random(), + "n": len(output_texts), + } + url = "https://api.openai.com/v1/completions" + respx_kwargs: Dict[str, Any] = { + **( + {"stream": MockAsyncByteStream(completion_mock_stream[0])} + if is_stream + else { + "json": { + "choices": [ + {"index": i, "text": text, "finish_reason": "stop"} + for i, text in enumerate(output_texts) + ], + "model": model_name, + "usage": completion_usage, + } + } + ), + } + respx_mock.post(url).mock(return_value=Response(status_code=status_code, **respx_kwargs)) + create_kwargs = {"prompt": prompt, **invocation_parameters} + openai = import_module("openai") + completions = ( + openai.AsyncOpenAI(api_key="sk-").completions + if is_async + else openai.OpenAI(api_key="sk-").completions + ) + create = completions.with_raw_response.create if is_raw else completions.create + with suppress(openai.BadRequestError): + if is_async: + + async def task() -> None: + response = await create(**create_kwargs) + response = response.parse() if is_raw else response + if is_stream: + async for _ in response: + pass + + asyncio.run(task()) + else: + response = create(**create_kwargs) + response = response.parse() if is_raw else response + if is_stream: + for _ in response: + pass + spans = in_memory_span_exporter.get_finished_spans() + assert len(spans) == 2 # first span should be from the httpx instrumentor + span: ReadableSpan = spans[1] + if status_code == 200: + assert span.status.is_ok + elif status_code == 400: + assert not span.status.is_ok and not span.status.is_unset + assert len(span.events) == 1 + event = span.events[0] + assert event.name == "exception" + attributes = dict(cast(Mapping[str, AttributeValue], span.attributes)) + assert attributes.pop(OPENINFERENCE_SPAN_KIND, None) == OpenInferenceSpanKindValues.LLM.value + assert ( + json.loads(cast(str, attributes.pop(LLM_INVOCATION_PARAMETERS, None))) + == invocation_parameters + ) + assert isinstance(attributes.pop(INPUT_VALUE, None), str) + assert isinstance(attributes.pop(INPUT_MIME_TYPE, None), str) + if status_code == 200: + assert isinstance(attributes.pop(OUTPUT_VALUE, None), str) + assert isinstance(attributes.pop(OUTPUT_MIME_TYPE, None), str) + assert list(cast(Sequence[str], attributes.pop(LLM_PROMPTS, None))) == prompt + if not is_stream: + # Usage is not available for streaming in general. + assert attributes.pop(LLM_TOKEN_COUNT_TOTAL, None) == completion_usage["total_tokens"] + assert attributes.pop(LLM_TOKEN_COUNT_PROMPT, None) == completion_usage["prompt_tokens"] + assert ( + attributes.pop(LLM_TOKEN_COUNT_COMPLETION, None) + == completion_usage["completion_tokens"] + ) + # We left out model_name from our mock stream. 
+ assert attributes.pop(LLM_MODEL_NAME, None) == model_name + assert attributes == {} # test should account for all span attributes + + +@pytest.mark.parametrize("is_async", [False, True]) +@pytest.mark.parametrize("is_raw", [False, True]) +@pytest.mark.parametrize("status_code", [200, 400]) +@pytest.mark.parametrize("encoding_format", ["float", "base64"]) +@pytest.mark.parametrize("input_text", ["hello", ["hello", "world"]]) +def test_embeddings( + is_async: bool, + is_raw: bool, + encoding_format: str, + input_text: Union[str, List[str]], + status_code: int, + respx_mock: MockRouter, + in_memory_span_exporter: InMemorySpanExporter, + model_name: str, +) -> None: + invocation_parameters = { + "model": randstr(), + "encoding_format": encoding_format, + } + embedding_model_name = randstr() + embedding_usage = { + "prompt_tokens": random.randint(10, 100), + "total_tokens": random.randint(10, 100), + } + output_embeddings = [("AACAPwAAAEA=", (1.0, 2.0)), ((2.0, 3.0), (2.0, 3.0))] + url = "https://api.openai.com/v1/embeddings" + respx_mock.post(url).mock( + return_value=Response( + status_code=status_code, + json={ + "object": "list", + "data": [ + {"object": "embedding", "index": i, "embedding": embedding[0]} + for i, embedding in enumerate(output_embeddings) + ], + "model": embedding_model_name, + "usage": embedding_usage, + }, + ) + ) + create_kwargs = {"input": input_text, **invocation_parameters} + openai = import_module("openai") + completions = ( + openai.AsyncOpenAI(api_key="sk-").embeddings + if is_async + else openai.OpenAI(api_key="sk-").embeddings + ) + create = completions.with_raw_response.create if is_raw else completions.create + with suppress(openai.BadRequestError): + if is_async: + + async def task() -> None: + response = await create(**create_kwargs) + _ = response.parse() if is_raw else response + + asyncio.run(task()) + else: + response = create(**create_kwargs) + _ = response.parse() if is_raw else response + spans = in_memory_span_exporter.get_finished_spans() + assert len(spans) == 2 # first span should be from the httpx instrumentor + span: ReadableSpan = spans[1] + if status_code == 200: + assert span.status.is_ok + elif status_code == 400: + assert not span.status.is_ok and not span.status.is_unset + assert len(span.events) == 1 + event = span.events[0] + assert event.name == "exception" + attributes = dict(cast(Mapping[str, AttributeValue], span.attributes)) + assert ( + attributes.pop(OPENINFERENCE_SPAN_KIND, None) == OpenInferenceSpanKindValues.EMBEDDING.value + ) + assert ( + json.loads(cast(str, attributes.pop(LLM_INVOCATION_PARAMETERS, None))) + == invocation_parameters + ) + assert isinstance(attributes.pop(INPUT_VALUE, None), str) + assert isinstance(attributes.pop(INPUT_MIME_TYPE, None), str) + if status_code == 200: + assert isinstance(attributes.pop(OUTPUT_VALUE, None), str) + assert isinstance(attributes.pop(OUTPUT_MIME_TYPE, None), str) + assert attributes.pop(EMBEDDING_MODEL_NAME, None) == embedding_model_name + assert attributes.pop(LLM_TOKEN_COUNT_TOTAL, None) == embedding_usage["total_tokens"] + assert attributes.pop(LLM_TOKEN_COUNT_PROMPT, None) == embedding_usage["prompt_tokens"] + for i, text in enumerate(input_text if isinstance(input_text, list) else [input_text]): + assert attributes.pop(f"{EMBEDDING_EMBEDDINGS}.{i}.{EMBEDDING_TEXT}", None) == text + for i, embedding in enumerate(output_embeddings): + assert ( + attributes.pop(f"{EMBEDDING_EMBEDDINGS}.{i}.{EMBEDDING_VECTOR}", None) + == embedding[1] + ) + assert attributes == {} # test should 
account for all span attributes + + +@pytest.fixture(scope="module") +def in_memory_span_exporter() -> InMemorySpanExporter: + return InMemorySpanExporter() + + +@pytest.fixture(scope="module") +def tracer_provider(in_memory_span_exporter: InMemorySpanExporter) -> trace_api.TracerProvider: + resource = Resource(attributes={}) + tracer_provider = trace_sdk.TracerProvider(resource=resource) + span_processor = SimpleSpanProcessor(span_exporter=in_memory_span_exporter) + tracer_provider.add_span_processor(span_processor=span_processor) + HTTPXClientInstrumentor().instrument(tracer_provider=tracer_provider) + return tracer_provider + + +@pytest.fixture(autouse=True) +def instrument( + tracer_provider: trace_api.TracerProvider, + in_memory_span_exporter: InMemorySpanExporter, +) -> Generator[None, None, None]: + OpenAIInstrumentor().instrument(tracer_provider=tracer_provider) + yield + OpenAIInstrumentor().uninstrument() + in_memory_span_exporter.clear() + + +@pytest.fixture(scope="module") +def seed() -> Iterator[int]: + """ + Use rolling seeds to help debugging, because the rolling pseudo-random values + allow conditional breakpoints to be hit precisely (and repeatably). + """ + return count() + + +@pytest.fixture(autouse=True) +def set_seed(seed: Iterator[int]) -> Iterator[None]: + random.seed(next(seed)) + yield + + +@pytest.fixture +def completion_usage() -> Dict[str, Any]: + prompt_tokens = random.randint(1, 1000) + completion_tokens = random.randint(1, 1000) + return { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": prompt_tokens + completion_tokens, + } + + +@pytest.fixture +def model_name() -> str: + return randstr() + + +@pytest.fixture +def input_messages() -> List[Dict[str, Any]]: + return [{"role": randstr(), "content": randstr()} for _ in range(2)] + + +@pytest.fixture +def chat_completion_mock_stream() -> Tuple[List[bytes], List[Dict[str, Any]]]: + return ( + [ + b'data: {"choices": [{"delta": {"role": "assistant"}, "index": 0}]}\n\n', + b'data: {"choices": [{"delta": {"tool_calls": [{"index": 0, "id": "call_amGrubFmr2FSPHeC5OPgwcNs", "function": {"arguments": "", "name": "get_current_weather"}, "type": "function"}]}, "index": 0}]}\n\n', # noqa: E501 + b'data: {"choices": [{"delta": {"content": ""}, "index": 0}]}\n\n', + b'data: {"choices": [{"delta": {"tool_calls": [{"index": 0, "function": {"arguments": "{\\"lo"}}]}, "index": 0}]}\n\n', # noqa: E501 + b'data: {"choices": [{"delta": {"content": "{\\"lo"}, "index": 0}]}\n\n', + b'data: {"choices": [{"delta": {"tool_calls": [{"index": 0, "function": {"arguments": "catio"}}]}, "index": 0}]}\n\n', # noqa: E501 + b'data: {"choices": [{"delta": {"content": "catio"}, "index": 0}]}\n\n', + b'data: {"choices": [{"delta": {"tool_calls": [{"index": 0, "function": {"arguments": "n\\": \\"B"}}]}, "index": 0}]}\n\n', # noqa: E501 + b'data: {"choices": [{"delta": {"content": "n\\": \\"B"}, "index": 0}]}\n\n', + b'data: {"choices": [{"delta": {"tool_calls": [{"index": 0, "function": {"arguments": "osto"}}]}, "index": 0}]}\n\n', # noqa: E501 + b'data: {"choices": [{"delta": {"content": "osto"}, "index": 0}]}\n\n', + b'data: {"choices": [{"delta": {"tool_calls": [{"index": 0, "function": {"arguments": "n, MA"}}]}, "index": 0}]}\n\n', # noqa: E501 + b'data: {"choices": [{"delta": {"content": "n, MA"}, "index": 0}]}\n\n', + b'data: {"choices": [{"delta": {"tool_calls": [{"index": 0, "function": {"arguments": "\\", \\"un"}}]}, "index": 0}]}\n\n', # noqa: E501 + b'data: {"choices": [{"delta": {"content": 
"\\", \\"un"}, "index": 0}]}\n\n', + b'data: {"choices": [{"delta": {"tool_calls": [{"index": 0, "function": {"arguments": "it\\":"}}]}, "index": 0}]}\n\n', # noqa: E501 + b'data: {"choices": [{"delta": {"content": "it\\":"}, "index": 0}]}\n\n', + b'data: {"choices": [{"delta": {"tool_calls": [{"index": 0, "function": {"arguments": " \\"fah"}}]}, "index": 0}]}\n\n', # noqa: E501 + b'data: {"choices": [{"delta": {"content": " \\"fah"}, "index": 0}]}\n\n', + b'data: {"choices": [{"delta": {"tool_calls": [{"index": 0, "function": {"arguments": "renhei"}}]}, "index": 0}]}\n\n', # noqa: E501 + b'data: {"choices": [{"delta": {"content": "renhei"}, "index": 0}]}\n\n', + b'data: {"choices": [{"delta": {"tool_calls": [{"index": 0, "function": {"arguments": "t\\"}"}}]}, "index": 0}]}\n\n', # noqa: E501 + b'data: {"choices": [{"delta": {"content": "t\\"}"}, "index": 0}]}\n\n', + b'data: {"choices": [{"delta": {"tool_calls": [{"index": 1, "id": "call_6QTP4mLSYYzZwt3ZWj77vfZf", "function": {"arguments": "", "name": "get_current_weather"}, "type": "function"}]}, "index": 0}]}\n\n', # noqa: E501 + b'data: {"choices": [{"delta": {"role": "assistant"}, "index": 1}]}\n\n', + b'data: {"choices": [{"delta": {"tool_calls": [{"index": 1, "function": {"arguments": "{\\"lo"}}]}, "index": 0}]}\n\n', # noqa: E501 + b'data: {"choices": [{"delta": {"content": "{\\"lo"}, "index": 1}]}\n\n', + b'data: {"choices": [{"delta": {"tool_calls": [{"index": 1, "function": {"arguments": "catio"}}]}, "index": 0}]}\n\n', # noqa: E501 + b'data: {"choices": [{"delta": {"content": "catio"}, "index": 1}]}\n\n', + b'data: {"choices": [{"delta": {"tool_calls": [{"index": 1, "function": {"arguments": "n\\": \\"S"}}]}, "index": 0}]}\n\n', # noqa: E501 + b'data: {"choices": [{"delta": {"content": "n\\": \\"S"}, "index": 1}]}\n\n', + b'data: {"choices": [{"delta": {"tool_calls": [{"index": 1, "function": {"arguments": "an F"}}]}, "index": 0}]}\n\n', # noqa: E501 + b'data: {"choices": [{"delta": {"content": "an F"}, "index": 1}]}\n\n', + b'data: {"choices": [{"delta": {"tool_calls": [{"index": 1, "function": {"arguments": "ranci"}}]}, "index": 0}]}\n\n', # noqa: E501 + b'data: {"choices": [{"delta": {"content": "ranci"}, "index": 1}]}\n\n', + b'data: {"choices": [{"delta": {"tool_calls": [{"index": 1, "function": {"arguments": "sco, C"}}]}, "index": 0}]}\n\n', # noqa: E501 + b'data: {"choices": [{"delta": {"content": "sco, C"}, "index": 1}]}\n\n', + b'data: {"choices": [{"delta": {"tool_calls": [{"index": 1, "function": {"arguments": "A\\", "}}]}, "index": 0}]}\n\n', # noqa: E501 + b'data: {"choices": [{"delta": {"content": "A\\", "}, "index": 1}]}\n\n', + b'data: {"choices": [{"delta": {"tool_calls": [{"index": 1, "function": {"arguments": "\\"unit"}}]}, "index": 0}]}\n\n', # noqa: E501 + b'data: {"choices": [{"delta": {"content": "\\"unit"}, "index": 1}]}\n\n', + b'data: {"choices": [{"delta": {"tool_calls": [{"index": 1, "function": {"arguments": "\\": \\"fa"}}]}, "index": 0}]}\n\n', # noqa: E501 + b'data: {"choices": [{"delta": {"content": "\\": \\"fa"}, "index": 1}]}\n\n', + b'data: {"choices": [{"delta": {"tool_calls": [{"index": 1, "function": {"arguments": "hren"}}]}, "index": 0}]}\n\n', # noqa: E501 + b'data: {"choices": [{"delta": {"content": "hren"}, "index": 1}]}\n\n', + b'data: {"choices": [{"delta": {"tool_calls": [{"index": 1, "function": {"arguments": "heit\\""}}]}, "index": 0}]}\n\n', # noqa: E501 + b'data: {"choices": [{"delta": {"content": "heit\\""}, "index": 1}]}\n\n', + b'data: {"choices": [{"delta": {"tool_calls": 
[{"index": 1, "function": {"arguments": "}"}}]}, "index": 0}]}\n\n', # noqa: E501 + b'data: {"choices": [{"delta": {"content": "}"}, "index": 1}]}\n\n', + b'data: {"choices": [{"finish_reason": "tool_calls", "index": 0}]}\n\n', # noqa: E501 + b"data: [DONE]\n", + ], + [ + { + "role": "assistant", + "content": '{"location": "Boston, MA", "unit": "fahrenheit"}', + "tool_calls": [ + { + "id": "call_amGrubFmr2FSPHeC5OPgwcNs", + "function": { + "arguments": '{"location": "Boston, MA", "unit": "fahrenheit"}', + "name": "get_current_weather", + }, + "type": "function", + }, + { + "id": "call_6QTP4mLSYYzZwt3ZWj77vfZf", + "function": { + "arguments": '{"location": "San Francisco, CA", "unit": "fahrenheit"}', + "name": "get_current_weather", + }, + "type": "function", + }, + ], + }, + { + "role": "assistant", + "content": '{"location": "San Francisco, CA", "unit": "fahrenheit"}', + }, + ], + ) + + +@pytest.fixture +def completion_mock_stream() -> Tuple[List[bytes], List[str]]: + return ( + [ + b'data: {"choices": [{"text": "", "index": 0}]}\n\n', + b'data: {"choices": [{"text": "{\\"lo", "index": 1}]}\n\n', + b'data: {"choices": [{"text": "{\\"lo", "index": 0}]}\n\n', + b'data: {"choices": [{"text": "catio", "index": 1}]}\n\n', + b'data: {"choices": [{"text": "catio", "index": 0}]}\n\n', + b'data: {"choices": [{"text": "n\\": \\"S", "index": 1}]}\n\n', + b'data: {"choices": [{"text": "n\\": \\"B", "index": 0}]}\n\n', + b'data: {"choices": [{"text": "an F", "index": 1}]}\n\n', + b'data: {"choices": [{"text": "osto", "index": 0}]}\n\n', + b'data: {"choices": [{"text": "ranci", "index": 1}]}\n\n', + b'data: {"choices": [{"text": "n, MA", "index": 0}]}\n\n', + b'data: {"choices": [{"text": "sco, C", "index": 1}]}\n\n', + b'data: {"choices": [{"text": "\\", \\"un", "index": 0}]}\n\n', + b'data: {"choices": [{"text": "A\\", ", "index": 1}]}\n\n', + b'data: {"choices": [{"text": "it\\":", "index": 0}]}\n\n', + b'data: {"choices": [{"text": "\\"unit", "index": 1}]}\n\n', + b'data: {"choices": [{"text": " \\"fah", "index": 0}]}\n\n', + b'data: {"choices": [{"text": "\\": \\"fa", "index": 1}]}\n\n', + b'data: {"choices": [{"text": "renhei", "index": 0}]}\n\n', + b'data: {"choices": [{"text": "hren", "index": 1}]}\n\n', + b'data: {"choices": [{"text": "t\\"}", "index": 0}]}\n\n', + b'data: {"choices": [{"text": "heit\\"", "index": 1}]}\n\n', + b'data: {"choices": [{"text": "}", "index": 1}]}\n\n', + b"data: [DONE]\n", + ], + [ + '{"location": "Boston, MA", "unit": "fahrenheit"}', + '{"location": "San Francisco, CA", "unit": "fahrenheit"}', + ], + ) + + +class MockAsyncByteStream(AsyncByteStream): + def __init__(self, byte_stream: Iterable[bytes]): + self._byte_stream = byte_stream + + def __iter__(self) -> Iterator[bytes]: + for byte_string in self._byte_stream: + yield byte_string + + async def __aiter__(self) -> AsyncIterator[bytes]: + for byte_string in self._byte_stream: + yield byte_string + + +def randstr() -> str: + return str(random.random()) + + +def get_texts() -> List[str]: + return [randstr() for _ in range(2)] + + +def get_messages() -> List[Dict[str, Any]]: + messages: List[Dict[str, Any]] = [ + *[{"role": randstr(), "content": randstr()} for _ in range(2)], + *[ + {"role": randstr(), "function_call": {"arguments": randstr(), "name": randstr()}} + for _ in range(2) + ], + *( + [ + { + "role": randstr(), + "tool_calls": [ + {"function": {"arguments": randstr(), "name": randstr()}} for _ in range(2) + ], + } + for _ in range(2) + ] + if _openai_version() >= (1, 1, 0) + else [] + ), + ] + 
random.shuffle(messages) + return messages + + +def _openai_version() -> Tuple[int, int, int]: + return cast(Tuple[int, int, int], tuple(map(int, version("openai").split(".")[:3]))) + + +def message_role(prefix: str, i: int) -> str: + return f"{prefix}.{i}.{MESSAGE_ROLE}" + + +def message_content(prefix: str, i: int) -> str: + return f"{prefix}.{i}.{MESSAGE_CONTENT}" + + +def message_function_call_name(prefix: str, i: int) -> str: + return f"{prefix}.{i}.{MESSAGE_FUNCTION_CALL_NAME}" + + +def message_function_call_arguments(prefix: str, i: int) -> str: + return f"{prefix}.{i}.{MESSAGE_FUNCTION_CALL_ARGUMENTS_JSON}" + + +def tool_call_function_name(prefix: str, i: int, j: int) -> str: + return f"{prefix}.{i}.{MESSAGE_TOOL_CALLS}.{j}.{TOOL_CALL_FUNCTION_NAME}" + + +def tool_call_function_arguments(prefix: str, i: int, j: int) -> str: + return f"{prefix}.{i}.{MESSAGE_TOOL_CALLS}.{j}.{TOOL_CALL_FUNCTION_ARGUMENTS_JSON}" + + +OPENINFERENCE_SPAN_KIND = SpanAttributes.OPENINFERENCE_SPAN_KIND +INPUT_VALUE = SpanAttributes.INPUT_VALUE +INPUT_MIME_TYPE = SpanAttributes.INPUT_MIME_TYPE +OUTPUT_VALUE = SpanAttributes.OUTPUT_VALUE +OUTPUT_MIME_TYPE = SpanAttributes.OUTPUT_MIME_TYPE +LLM_INVOCATION_PARAMETERS = SpanAttributes.LLM_INVOCATION_PARAMETERS +LLM_MODEL_NAME = SpanAttributes.LLM_MODEL_NAME +LLM_TOKEN_COUNT_TOTAL = SpanAttributes.LLM_TOKEN_COUNT_TOTAL +LLM_TOKEN_COUNT_PROMPT = SpanAttributes.LLM_TOKEN_COUNT_PROMPT +LLM_TOKEN_COUNT_COMPLETION = SpanAttributes.LLM_TOKEN_COUNT_COMPLETION +LLM_INPUT_MESSAGES = SpanAttributes.LLM_INPUT_MESSAGES +LLM_OUTPUT_MESSAGES = SpanAttributes.LLM_OUTPUT_MESSAGES +LLM_PROMPTS = SpanAttributes.LLM_PROMPTS +MESSAGE_ROLE = MessageAttributes.MESSAGE_ROLE +MESSAGE_CONTENT = MessageAttributes.MESSAGE_CONTENT +MESSAGE_FUNCTION_CALL_NAME = MessageAttributes.MESSAGE_FUNCTION_CALL_NAME +MESSAGE_FUNCTION_CALL_ARGUMENTS_JSON = MessageAttributes.MESSAGE_FUNCTION_CALL_ARGUMENTS_JSON +MESSAGE_TOOL_CALLS = MessageAttributes.MESSAGE_TOOL_CALLS +TOOL_CALL_FUNCTION_NAME = ToolCallAttributes.TOOL_CALL_FUNCTION_NAME +TOOL_CALL_FUNCTION_ARGUMENTS_JSON = ToolCallAttributes.TOOL_CALL_FUNCTION_ARGUMENTS_JSON +EMBEDDING_EMBEDDINGS = SpanAttributes.EMBEDDING_EMBEDDINGS +EMBEDDING_MODEL_NAME = SpanAttributes.EMBEDDING_MODEL_NAME +EMBEDDING_VECTOR = EmbeddingAttributes.EMBEDDING_VECTOR +EMBEDDING_TEXT = EmbeddingAttributes.EMBEDDING_TEXT diff --git a/python/mypy.ini b/python/mypy.ini new file mode 100644 index 000000000..533084ef6 --- /dev/null +++ b/python/mypy.ini @@ -0,0 +1,5 @@ +[mypy] +strict = true + +[mypy-wrapt] +ignore_missing_imports = True diff --git a/python/ruff.toml b/python/ruff.toml new file mode 100644 index 000000000..afeb5281c --- /dev/null +++ b/python/ruff.toml @@ -0,0 +1,12 @@ +line-length = 100 +exclude = [ + ".git", + ".tox", + "dist", +] +ignore-init-module-imports = true +select = ["E", "F", "W", "I"] +target-version = "py38" + +[lint.isort] +force-single-line = false diff --git a/python/tox.ini b/python/tox.ini new file mode 100644 index 000000000..cb9edf390 --- /dev/null +++ b/python/tox.ini @@ -0,0 +1,28 @@ +[tox] +isolated_build = True +skipsdist = True +envlist = + py3{8,11}-ci-semconv + py3{8,11}-ci-{openai,openai-latest} + +[testenv] +package = wheel +wheel_build_env = .pkg +deps = + -r dev-requirements.txt +changedir = + semconv: openinference-semantic-conventions/ + openai: instrumentation/openinference-instrumentation-openai/ +commands_pre = + semconv: pip install {toxinidir}/openinference-semantic-conventions + openai: pip install 
{toxinidir}/instrumentation/openinference-instrumentation-openai[test] + openai-latest: pip install -U openai +commands = + ruff: ruff format . --config {toxinidir}/ruff.toml + ruff: ruff check . --fix --config {toxinidir}/ruff.toml + mypy: mypy --config-file {toxinidir}/mypy.ini --explicit-package-bases {posargs:src} + test: pytest {posargs:tests} + ci: ruff format . --diff --config {toxinidir}/ruff.toml + ci: ruff check . --diff --config {toxinidir}/ruff.toml + ci: mypy --config-file {toxinidir}/mypy.ini --explicit-package-bases src + ci: pytest tests
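As a closing illustrative aside (not part of the diff): the `chat_completion_mock_stream` and `completion_mock_stream` fixtures in the test module above emulate OpenAI's server-sent-events framing, where each chunk is a `data: <json>` event terminated by a blank line and the stream ends with `data: [DONE]`. The sketch below shows one way such chunks could be decoded back into dicts; the helper name `iter_sse_json` is hypothetical and not part of this change.

import json
from typing import Any, Dict, Iterable, Iterator


def iter_sse_json(chunks: Iterable[bytes]) -> Iterator[Dict[str, Any]]:
    # Each mock chunk is one SSE event of the form b'data: <json>\n\n'.
    for chunk in chunks:
        payload = chunk.decode("utf-8").strip()
        if payload.startswith("data: "):
            payload = payload[len("data: "):]
        if payload == "[DONE]":  # sentinel that terminates the stream
            return
        yield json.loads(payload)

For example, feeding it the first element of the `chat_completion_mock_stream` tuple yields the per-choice deltas that the instrumentation's response accumulator merges back into the whole messages listed in the tuple's second element.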