diff --git a/.release-please-manifest.json b/.release-please-manifest.json index 91f07b2e3..fcbc1fc44 100644 --- a/.release-please-manifest.json +++ b/.release-please-manifest.json @@ -1 +1 @@ -{"python/openinference-semantic-conventions":"0.1.2","python/instrumentation/openinference-instrumentation-openai":"0.1.0","python/instrumentation/openinference-instrumentation-llama-index":"0.1.0","python/instrumentation/openinference-instrumentation-dspy":"0.1.0"} +{"python/openinference-semantic-conventions":"0.1.2","python/instrumentation/openinference-instrumentation-openai":"0.1.0","python/instrumentation/openinference-instrumentation-llama-index":"0.1.0","python/instrumentation/openinference-instrumentation-dspy":"0.1.0","python/instrumentation/openinference-instrumentation-langchain":"0.1.0"} diff --git a/README.md b/README.md index ecdd93c00..b1af179d3 100644 --- a/README.md +++ b/README.md @@ -22,13 +22,14 @@ OpenInference provides a set of instrumentations for popular machine learning SD ## Python -| Package | Description | Version | -| ---------------------------------------------------------------------------------------------------------------------------- | --------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| [`openinference-semantic-conventions`](./python/openinference-semantic-conventions/README.md) | Semantic conventions for tracing of LLM Apps. | [![PyPI Version](https://img.shields.io/pypi/v/openinference-semantic-conventions.svg)](https://pypi.python.org/pypi/openinference-semantic-conventions) | -| [`openinference-instrumentation-openai`](./python/instrumentation/openinference-instrumentation-openai/README.rst) | OpenInference Instrumentation for OpenAI SDK. | [![PyPI Version](https://img.shields.io/pypi/v/openinference-instrumentation-openai.svg)](https://pypi.python.org/pypi/openinference-instrumentation-openai) | -| [`openinference-instrumentation-llama-index`](./python/instrumentation/openinference-instrumentation-llama-index/README.rst) | OpenInference Instrumentation for LlamaIndex. | [![PyPI Version](https://img.shields.io/pypi/v/openinference-instrumentation-llama-index.svg)](https://pypi.python.org/pypi/openinference-instrumentation-llama-index) | -| [`openinference-instrumentation-dspy`](./python/instrumentation/openinference-instrumentation-dspy/README.rst) | OpenInference Instrumentation for DSPy. | [![PyPI Version](https://img.shields.io/pypi/v/openinference-instrumentation-dspy.svg)](https://pypi.python.org/pypi/openinference-instrumentation-dspy) | -| [`openinference-instrumentation-bedrock`](./python/instrumentation/openinference-instrumentation-bedrock/README.rst) | OpenInference Instrumentation for AWS Bedrock. | [![PyPI Version](https://img.shields.io/pypi/v/openinference-instrumentation-bedrock.svg)](https://pypi.python.org/pypi/openinference-instrumentation-bedrock) | +| Package | Description | Version | +| ---------------------------------------------------------------------------------------------------------------------------- |------------------------------------------------| ---------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| [`openinference-semantic-conventions`](./python/openinference-semantic-conventions/README.md) | Semantic conventions for tracing of LLM Apps. 
| [![PyPI Version](https://img.shields.io/pypi/v/openinference-semantic-conventions.svg)](https://pypi.python.org/pypi/openinference-semantic-conventions) | +| [`openinference-instrumentation-openai`](./python/instrumentation/openinference-instrumentation-openai/README.rst) | OpenInference Instrumentation for OpenAI SDK. | [![PyPI Version](https://img.shields.io/pypi/v/openinference-instrumentation-openai.svg)](https://pypi.python.org/pypi/openinference-instrumentation-openai) | +| [`openinference-instrumentation-llama-index`](./python/instrumentation/openinference-instrumentation-llama-index/README.rst) | OpenInference Instrumentation for LlamaIndex. | [![PyPI Version](https://img.shields.io/pypi/v/openinference-instrumentation-llama-index.svg)](https://pypi.python.org/pypi/openinference-instrumentation-llama-index) | +| [`openinference-instrumentation-dspy`](./python/instrumentation/openinference-instrumentation-dspy/README.rst) | OpenInference Instrumentation for DSPy. | [![PyPI Version](https://img.shields.io/pypi/v/openinference-instrumentation-dspy.svg)](https://pypi.python.org/pypi/openinference-instrumentation-dspy) | +| [`openinference-instrumentation-bedrock`](./python/instrumentation/openinference-instrumentation-bedrock/README.rst) | OpenInference Instrumentation for AWS Bedrock. | [![PyPI Version](https://img.shields.io/pypi/v/openinference-instrumentation-bedrock.svg)](https://pypi.python.org/pypi/openinference-instrumentation-bedrock) | +| [`openinference-instrumentation-langchain`](./python/instrumentation/openinference-instrumentation-langchain/README.rst) | OpenInference Instrumentation for LangChain. | [![PyPI Version](https://img.shields.io/pypi/v/openinference-instrumentation-langchain.svg)](https://pypi.python.org/pypi/openinference-instrumentation-langchain) | ## JavaScript diff --git a/python/instrumentation/openinference-instrumentation-langchain/CHANGELOG.md b/python/instrumentation/openinference-instrumentation-langchain/CHANGELOG.md new file mode 100644 index 000000000..a71235a93 --- /dev/null +++ b/python/instrumentation/openinference-instrumentation-langchain/CHANGELOG.md @@ -0,0 +1,8 @@ +# Changelog + +## 0.1.0 (2024-01-26) + + +### Features + +* langchain instrumentor ([#138](https://github.com/Arize-ai/openinference/issues/138)) ([61094f6](https://github.com/Arize-ai/openinference/commit/61094f606dc0a6961fc566d8d45b27967a14c59c)) diff --git a/python/instrumentation/openinference-instrumentation-langchain/LICENSE b/python/instrumentation/openinference-instrumentation-langchain/LICENSE new file mode 100644 index 000000000..0223315cc --- /dev/null +++ b/python/instrumentation/openinference-instrumentation-langchain/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. 
For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + +2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + +4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + +6. Trademarks. 
This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS + +APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + +Copyright The OpenTelemetry Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
diff --git a/python/instrumentation/openinference-instrumentation-langchain/README.rst b/python/instrumentation/openinference-instrumentation-langchain/README.rst new file mode 100644 index 000000000..9c59b747e --- /dev/null +++ b/python/instrumentation/openinference-instrumentation-langchain/README.rst @@ -0,0 +1,14 @@ +OpenInference LangChain Instrumentation +============================================= + +|pypi| + +.. |pypi| image:: https://badge.fury.io/py/openinference-instrumentation-langchain.svg + :target: https://pypi.org/project/openinference-instrumentation-langchain/ + +Installation +------------ + +:: + + pip install openinference-instrumentation-langchain diff --git a/python/instrumentation/openinference-instrumentation-langchain/examples/openai_chat_stream.py b/python/instrumentation/openinference-instrumentation-langchain/examples/openai_chat_stream.py new file mode 100644 index 000000000..9442346b0 --- /dev/null +++ b/python/instrumentation/openinference-instrumentation-langchain/examples/openai_chat_stream.py @@ -0,0 +1,22 @@ +from langchain_openai import ChatOpenAI +from openinference.instrumentation.langchain import LangChainInstrumentor +from opentelemetry import trace as trace_api +from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter +from opentelemetry.sdk import trace as trace_sdk +from opentelemetry.sdk.resources import Resource +from opentelemetry.sdk.trace.export import ConsoleSpanExporter, SimpleSpanProcessor + +resource = Resource(attributes={}) +tracer_provider = trace_sdk.TracerProvider(resource=resource) +trace_api.set_tracer_provider(tracer_provider=tracer_provider) +span_otlp_exporter = OTLPSpanExporter(endpoint="http://127.0.0.1:6006/v1/traces") +tracer_provider.add_span_processor(SimpleSpanProcessor(span_exporter=span_otlp_exporter)) +span_console_exporter = ConsoleSpanExporter() +tracer_provider.add_span_processor(SimpleSpanProcessor(span_exporter=span_console_exporter)) + +LangChainInstrumentor().instrument() + + +if __name__ == "__main__": + for chunk in ChatOpenAI(model_name="gpt-3.5-turbo").stream("Write a haiku."): + print(chunk.content, end="", flush=True) diff --git a/python/instrumentation/openinference-instrumentation-langchain/examples/requirements.txt b/python/instrumentation/openinference-instrumentation-langchain/examples/requirements.txt new file mode 100644 index 000000000..ac433c014 --- /dev/null +++ b/python/instrumentation/openinference-instrumentation-langchain/examples/requirements.txt @@ -0,0 +1,4 @@ +langchain_openai +opentelemetry-sdk +opentelemetry-exporter-otlp +openinference-instrumentation-langchain diff --git a/python/instrumentation/openinference-instrumentation-langchain/pyproject.toml b/python/instrumentation/openinference-instrumentation-langchain/pyproject.toml new file mode 100644 index 000000000..b42086e43 --- /dev/null +++ b/python/instrumentation/openinference-instrumentation-langchain/pyproject.toml @@ -0,0 +1,63 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "openinference-instrumentation-langchain" +dynamic = ["version"] +description = "OpenInference LangChain Instrumentation" +readme = "README.rst" +license = "Apache-2.0" +requires-python = ">=3.8, <3.12" +authors = [ + { name = "OpenInference Authors", email = "oss@arize.com" }, +] +classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python", + "Programming 
Language :: Python :: 3", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", +] +dependencies = [ + "opentelemetry-api", + "opentelemetry-instrumentation", + "opentelemetry-semantic-conventions", + "openinference-semantic-conventions", +] + +[project.optional-dependencies] +instruments = [ + "langchain_core >= 0.1.0", +] +test = [ + "langchain_core == 0.1.8", + "langchain == 0.1.0", + "langchain_openai == 0.0.2", + "langchain-community == 0.0.10", + "opentelemetry-sdk", + "openinference-instrumentation-openai", + "respx", +] +type-check = [ + "langchain_core == 0.1.0", +] + +[project.urls] +Homepage = "https://github.com/Arize-ai/openinference/tree/main/python/instrumentation/openinference-instrumentation-langchain" + +[tool.hatch.version] +path = "src/openinference/instrumentation/langchain/version.py" + +[tool.hatch.build.targets.sdist] +include = [ + "/src", + "/tests", +] + +[tool.hatch.build.targets.wheel] +packages = ["src/openinference"] diff --git a/python/instrumentation/openinference-instrumentation-langchain/src/openinference/instrumentation/langchain/__init__.py b/python/instrumentation/openinference-instrumentation-langchain/src/openinference/instrumentation/langchain/__init__.py new file mode 100644 index 000000000..1ea8d72f8 --- /dev/null +++ b/python/instrumentation/openinference-instrumentation-langchain/src/openinference/instrumentation/langchain/__init__.py @@ -0,0 +1,61 @@ +import logging +from typing import TYPE_CHECKING, Any, Callable, Collection + +from openinference.instrumentation.langchain._tracer import OpenInferenceTracer +from openinference.instrumentation.langchain.package import _instruments +from openinference.instrumentation.langchain.version import __version__ +from opentelemetry import trace as trace_api +from opentelemetry.instrumentation.instrumentor import BaseInstrumentor # type: ignore +from wrapt import wrap_function_wrapper + +if TYPE_CHECKING: + from langchain_core.callbacks import BaseCallbackManager + +logger = logging.getLogger(__name__) +logger.addHandler(logging.NullHandler()) + + +class LangChainInstrumentor(BaseInstrumentor): # type: ignore + """ + An instrumentor for LangChain + """ + + def instrumentation_dependencies(self) -> Collection[str]: + return _instruments + + def _instrument(self, **kwargs: Any) -> None: + if not (tracer_provider := kwargs.get("tracer_provider")): + tracer_provider = trace_api.get_tracer_provider() + tracer = trace_api.get_tracer(__name__, __version__, tracer_provider) + wrap_function_wrapper( + module="langchain_core.callbacks", + name="BaseCallbackManager.__init__", + wrapper=_BaseCallbackManagerInit(tracer=tracer), + ) + + def _uninstrument(self, **kwargs: Any) -> None: + pass + + +class _BaseCallbackManagerInit: + __slots__ = ("_tracer",) + + def __init__(self, tracer: trace_api.Tracer): + self._tracer = tracer + + def __call__( + self, + wrapped: Callable[..., None], + instance: "BaseCallbackManager", + args: Any, + kwargs: Any, + ) -> None: + wrapped(*args, **kwargs) + for handler in instance.inheritable_handlers: + # Handlers may be copied when new managers are created, so we + # don't want to keep adding. E.g. see the following location. 
+ # https://github.com/langchain-ai/langchain/blob/5c2538b9f7fb64afed2a918b621d9d8681c7ae32/libs/core/langchain_core/callbacks/manager.py#L1876 # noqa: E501 + if isinstance(handler, OpenInferenceTracer): + break + else: + instance.add_handler(OpenInferenceTracer(tracer=self._tracer), True) diff --git a/python/instrumentation/openinference-instrumentation-langchain/src/openinference/instrumentation/langchain/_tracer.py b/python/instrumentation/openinference-instrumentation-langchain/src/openinference/instrumentation/langchain/_tracer.py new file mode 100644 index 000000000..0a75a314b --- /dev/null +++ b/python/instrumentation/openinference-instrumentation-langchain/src/openinference/instrumentation/langchain/_tracer.py @@ -0,0 +1,524 @@ +import json +import logging +from copy import deepcopy +from datetime import datetime +from enum import Enum +from itertools import chain +from typing import ( + Any, + Callable, + Dict, + Iterable, + Iterator, + List, + Mapping, + NamedTuple, + Optional, + Sequence, + Tuple, +) +from uuid import UUID + +from langchain_core.tracers.base import BaseTracer +from langchain_core.tracers.schemas import Run +from openinference.semconv.trace import ( + DocumentAttributes, + EmbeddingAttributes, + MessageAttributes, + OpenInferenceMimeTypeValues, + OpenInferenceSpanKindValues, + RerankerAttributes, + SpanAttributes, + ToolCallAttributes, +) +from opentelemetry import context as context_api +from opentelemetry import trace as trace_api +from opentelemetry.semconv.trace import SpanAttributes as OTELSpanAttributes +from opentelemetry.util.types import AttributeValue + +logger = logging.getLogger(__name__) +logger.addHandler(logging.NullHandler()) + + +class _Run(NamedTuple): + span: trace_api.Span + token: object # token for OTEL context API + + +class OpenInferenceTracer(BaseTracer): + __slots__ = ("_tracer", "_runs") + + def __init__(self, tracer: trace_api.Tracer, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self._tracer = tracer + self._runs: Dict[UUID, _Run] = {} + # run_inline=True so the handler is not run in a thread. E.g. see the following location. + # https://github.com/langchain-ai/langchain/blob/5c2538b9f7fb64afed2a918b621d9d8681c7ae32/libs/core/langchain_core/callbacks/manager.py#L321 # noqa: E501 + self.run_inline = True + + def _start_trace(self, run: Run) -> None: + super()._start_trace(run) + span = self._tracer.start_span(run.name) + token = context_api.attach(trace_api.set_span_in_context(span)) + self._runs[run.id] = _Run(span=span, token=token) + + def _end_trace(self, run: Run) -> None: + if event_data := self._runs.pop(run.id, None): + # FIXME: find out why sometimes token fails to detach, e.g. 
when it's async + context_api.detach(event_data.token) + span = event_data.span + try: + _update_span(span, run) + except Exception: + logger.exception("Failed to update span with run data.") + span.end() + super()._end_trace(run) + + def _persist_run(self, run: Run) -> None: + pass + + def on_llm_error(self, error: BaseException, *args: Any, run_id: UUID, **kwargs: Any) -> Run: + if event_data := self._runs.get(run_id): + _record_exception(event_data.span, error) + return super().on_llm_error(error, *args, run_id=run_id, **kwargs) + + def on_chain_error(self, error: BaseException, *args: Any, run_id: UUID, **kwargs: Any) -> Run: + if event_data := self._runs.get(run_id): + _record_exception(event_data.span, error) + return super().on_chain_error(error, *args, run_id=run_id, **kwargs) + + def on_retriever_error( + self, error: BaseException, *args: Any, run_id: UUID, **kwargs: Any + ) -> Run: + if event_data := self._runs.get(run_id): + _record_exception(event_data.span, error) + return super().on_retriever_error(error, *args, run_id=run_id, **kwargs) + + def on_tool_error(self, error: BaseException, *args: Any, run_id: UUID, **kwargs: Any) -> Run: + if event_data := self._runs.get(run_id): + _record_exception(event_data.span, error) + return super().on_tool_error(error, *args, run_id=run_id, **kwargs) + + +def _record_exception(span: trace_api.Span, error: BaseException) -> None: + if isinstance(error, Exception): + span.record_exception(error) + else: + span.add_event( + name="exception", + attributes={ + OTELSpanAttributes.EXCEPTION_MESSAGE: str(error), + OTELSpanAttributes.EXCEPTION_TYPE: error.__class__.__name__, + }, + ) + + +def _update_span(span: trace_api.Span, run: Run) -> None: + if run.error is None: + span.set_status(trace_api.StatusCode.OK) + else: + span.set_status(trace_api.Status(trace_api.StatusCode.ERROR, run.error)) + span_kind = ( + OpenInferenceSpanKindValues.AGENT + if "agent" in run.name.lower() + else _langchain_run_type_to_span_kind(run.run_type) + ) + span.set_attribute(OPENINFERENCE_SPAN_KIND, span_kind.value) + span.set_attributes( + dict( + _flatten( + chain( + _as_input(_convert_io(run.inputs)), + _as_output(_convert_io(run.outputs)), + _prompts(run.inputs), + _input_messages(run.inputs), + _output_messages(run.outputs), + _prompt_template(run.serialized), + _invocation_parameters(run), + _model_name(run.extra), + _token_counts(run.outputs), + _function_calls(run.outputs), + _tools(run), + _retrieval_documents(run), + ) + ) + ) + ) + + +def _langchain_run_type_to_span_kind(run_type: str) -> OpenInferenceSpanKindValues: + try: + return OpenInferenceSpanKindValues(run_type.upper()) + except ValueError: + return OpenInferenceSpanKindValues.UNKNOWN + + +def _serialize_json(obj: Any) -> str: + if isinstance(obj, datetime): + return obj.isoformat() + return str(obj) + + +def stop_on_exception( + wrapped: Callable[..., Iterator[Tuple[str, Any]]], +) -> Callable[..., Iterator[Tuple[str, Any]]]: + def wrapper(*args: Any, **kwargs: Any) -> Iterator[Tuple[str, Any]]: + try: + yield from wrapped(*args, **kwargs) + except Exception: + logger.exception("Failed to get attribute.") + + return wrapper + + +@stop_on_exception +def _flatten(key_values: Iterable[Tuple[str, Any]]) -> Iterator[Tuple[str, AttributeValue]]: + for key, value in key_values: + if value is None: + continue + if isinstance(value, Mapping): + for sub_key, sub_value in _flatten(value.items()): + yield f"{key}.{sub_key}", sub_value + elif isinstance(value, List) and any(isinstance(item, Mapping) for item in 
value): + for index, sub_mapping in enumerate(value): + for sub_key, sub_value in _flatten(sub_mapping.items()): + yield f"{key}.{index}.{sub_key}", sub_value + else: + if isinstance(value, Enum): + value = value.value + yield key, value + + +@stop_on_exception +def _as_input(values: Iterable[str]) -> Iterator[Tuple[str, str]]: + return zip((INPUT_VALUE, INPUT_MIME_TYPE), values) + + +@stop_on_exception +def _as_output(values: Iterable[str]) -> Iterator[Tuple[str, str]]: + return zip((OUTPUT_VALUE, OUTPUT_MIME_TYPE), values) + + +def _convert_io(obj: Optional[Mapping[str, Any]]) -> Iterator[str]: + if not obj: + return + assert isinstance(obj, dict), f"expected dict, found {type(obj)}" + if len(obj) == 1 and isinstance(value := next(iter(obj.values())), str): + yield value + else: + yield json.dumps(obj, default=_serialize_json) + yield OpenInferenceMimeTypeValues.JSON.value + + +@stop_on_exception +def _prompts(inputs: Optional[Mapping[str, Any]]) -> Iterator[Tuple[str, List[str]]]: + """Yields prompts if present.""" + if not inputs: + return + assert hasattr(inputs, "get"), f"expected Mapping, found {type(inputs)}" + if prompts := inputs.get("prompts"): + yield LLM_PROMPTS, prompts + + +@stop_on_exception +def _input_messages( + inputs: Optional[Mapping[str, Any]], +) -> Iterator[Tuple[str, List[Dict[str, Any]]]]: + """Yields chat messages if present.""" + if not inputs: + return + assert hasattr(inputs, "get"), f"expected Mapping, found {type(inputs)}" + # There may be more than one set of messages. We'll use just the first set. + if not (multiple_messages := inputs.get("messages")): + return + assert isinstance( + multiple_messages, Iterable + ), f"expected Iterable, found {type(multiple_messages)}" + # This will only get the first set of messages. + if not (first_messages := next(iter(multiple_messages), None)): + return + assert isinstance(first_messages, Iterable), f"expected Iterable, found {type(first_messages)}" + parsed_messages = [] + for message_data in first_messages: + assert hasattr(message_data, "get"), f"expected Mapping, found {type(message_data)}" + parsed_messages.append(dict(_parse_message_data(message_data))) + if parsed_messages: + yield LLM_INPUT_MESSAGES, parsed_messages + + +@stop_on_exception +def _output_messages( + outputs: Optional[Mapping[str, Any]], +) -> Iterator[Tuple[str, List[Dict[str, Any]]]]: + """Yields chat messages if present.""" + if not outputs: + return + assert hasattr(outputs, "get"), f"expected Mapping, found {type(outputs)}" + # There may be more than one set of generations. We'll use just the first set. + if not (multiple_generations := outputs.get("generations")): + return + assert isinstance( + multiple_generations, Iterable + ), f"expected Iterable, found {type(multiple_generations)}" + # This will only get the first set of generations. 
+ if not (first_generations := next(iter(multiple_generations), None)): + return + assert isinstance( + first_generations, Iterable + ), f"expected Iterable, found {type(first_generations)}" + parsed_messages = [] + for generation in first_generations: + assert hasattr(generation, "get"), f"expected Mapping, found {type(generation)}" + if message_data := generation.get("message"): + assert hasattr(message_data, "get"), f"expected Mapping, found {type(message_data)}" + parsed_messages.append(dict(_parse_message_data(message_data))) + if parsed_messages: + yield LLM_OUTPUT_MESSAGES, parsed_messages + + +@stop_on_exception +def _parse_message_data(message_data: Optional[Mapping[str, Any]]) -> Iterator[Tuple[str, Any]]: + """Parses message data to grab message role, content, etc.""" + if not message_data: + return + assert hasattr(message_data, "get"), f"expected Mapping, found {type(message_data)}" + id_ = message_data.get("id") + assert isinstance(id_, List), f"expected list, found {type(id_)}" + message_class_name = id_[-1] + if message_class_name.startswith("HumanMessage"): + role = "user" + elif message_class_name.startswith("AIMessage"): + role = "assistant" + elif message_class_name.startswith("SystemMessage"): + role = "system" + elif message_class_name.startswith("FunctionMessage"): + role = "function" + elif message_class_name.startswith("ChatMessage"): + role = message_data["kwargs"]["role"] + else: + raise ValueError(f"Cannot parse message of type: {message_class_name}") + yield MESSAGE_ROLE, role + if kwargs := message_data.get("kwargs"): + assert hasattr(kwargs, "get"), f"expected Mapping, found {type(kwargs)}" + if content := kwargs.get("content"): + assert isinstance(content, str), f"expected str, found {type(content)}" + yield MESSAGE_CONTENT, content + if additional_kwargs := kwargs.get("additional_kwargs"): + assert hasattr( + additional_kwargs, "get" + ), f"expected Mapping, found {type(additional_kwargs)}" + if function_call := additional_kwargs.get("function_call"): + assert hasattr( + function_call, "get" + ), f"expected Mapping, found {type(function_call)}" + if name := function_call.get("name"): + assert isinstance(name, str), f"expected str, found {type(name)}" + yield MESSAGE_FUNCTION_CALL_NAME, name + if arguments := function_call.get("arguments"): + assert isinstance(arguments, str), f"expected str, found {type(arguments)}" + yield MESSAGE_FUNCTION_CALL_ARGUMENTS_JSON, arguments + if tool_calls := additional_kwargs.get("tool_calls"): + assert isinstance( + tool_calls, Iterable + ), f"expected Iterable, found {type(tool_calls)}" + message_tool_calls = [] + for tool_call in tool_calls: + if message_tool_call := dict(_get_tool_call(tool_call)): + message_tool_calls.append(message_tool_call) + if message_tool_calls: + yield MESSAGE_TOOL_CALLS, message_tool_calls + + +@stop_on_exception +def _get_tool_call(tool_call: Optional[Mapping[str, Any]]) -> Iterator[Tuple[str, Any]]: + if not tool_call: + return + assert hasattr(tool_call, "get"), f"expected Mapping, found {type(tool_call)}" + if function := tool_call.get("function"): + assert hasattr(function, "get"), f"expected Mapping, found {type(function)}" + if name := function.get("name"): + assert isinstance(name, str), f"expected str, found {type(name)}" + yield TOOL_CALL_FUNCTION_NAME, name + if arguments := function.get("arguments"): + assert isinstance(arguments, str), f"expected str, found {type(arguments)}" + yield TOOL_CALL_FUNCTION_ARGUMENTS_JSON, arguments + + +@stop_on_exception +def 
_prompt_template(serialized: Optional[Mapping[str, Any]]) -> Iterator[Tuple[str, Any]]:
+    """
+    A best-effort attempt to locate the PromptTemplate object among the
+    keyword arguments of a serialized object, e.g. an LLMChain object.
+    """
+    if not serialized:
+        return
+    assert hasattr(serialized, "get"), f"expected Mapping, found {type(serialized)}"
+    if not (kwargs := serialized.get("kwargs")):
+        return
+    assert isinstance(kwargs, dict), f"expected dict, found {type(kwargs)}"
+    for obj in kwargs.values():
+        # Skip anything that is not a mapping or has no `id` field.
+        if not hasattr(obj, "get") or not (id_ := obj.get("id")):
+            continue
+        # The `id` field of the object is a list indicating the path to the
+        # object's class in the LangChain package, e.g. `PromptTemplate` in
+        # the `langchain.prompts.prompt` module is represented as
+        # ["langchain", "prompts", "prompt", "PromptTemplate"]
+        assert isinstance(id_, Sequence), f"expected list, found {type(id_)}"
+        if id_[-1].endswith("PromptTemplate"):
+            if not (kwargs := obj.get("kwargs")):
+                continue
+            assert hasattr(kwargs, "get"), f"expected Mapping, found {type(kwargs)}"
+            if not (template := kwargs.get("template", "")):
+                continue
+            yield LLM_PROMPT_TEMPLATE, template
+            if input_variables := kwargs.get("input_variables"):
+                assert isinstance(
+                    input_variables, list
+                ), f"expected list, found {type(input_variables)}"
+                yield LLM_PROMPT_TEMPLATE_VARIABLES, input_variables
+            break
+
+
+@stop_on_exception
+def _invocation_parameters(run: Run) -> Iterator[Tuple[str, str]]:
+    """Yields invocation parameters if present."""
+    if run.run_type.lower() != "llm":
+        return
+    if not (extra := run.extra):
+        return
+    assert hasattr(extra, "get"), f"expected Mapping, found {type(extra)}"
+    if invocation_parameters := extra.get("invocation_params"):
+        assert isinstance(
+            invocation_parameters, Mapping
+        ), f"expected Mapping, found {type(invocation_parameters)}"
+        yield LLM_INVOCATION_PARAMETERS, json.dumps(invocation_parameters)
+
+
+@stop_on_exception
+def _model_name(extra: Optional[Mapping[str, Any]]) -> Iterator[Tuple[str, str]]:
+    """Yields model name if present."""
+    if not extra:
+        return
+    assert hasattr(extra, "get"), f"expected Mapping, found {type(extra)}"
+    if not (invocation_params := extra.get("invocation_params")):
+        return
+    for key in ["model_name", "model"]:
+        if name := invocation_params.get(key):
+            yield LLM_MODEL_NAME, name
+            return
+
+
+@stop_on_exception
+def _token_counts(outputs: Optional[Mapping[str, Any]]) -> Iterator[Tuple[str, int]]:
+    """Yields token count information if present."""
+    if not outputs:
+        return
+    assert hasattr(outputs, "get"), f"expected Mapping, found {type(outputs)}"
+    if not (llm_output := outputs.get("llm_output")):
+        return
+    assert hasattr(llm_output, "get"), f"expected Mapping, found {type(llm_output)}"
+    if not (token_usage := llm_output.get("token_usage")):
+        return
+    assert hasattr(token_usage, "get"), f"expected Mapping, found {type(token_usage)}"
+    for attribute_name, key in [
+        (LLM_TOKEN_COUNT_PROMPT, "prompt_tokens"),
+        (LLM_TOKEN_COUNT_COMPLETION, "completion_tokens"),
+        (LLM_TOKEN_COUNT_TOTAL, "total_tokens"),
+    ]:
+        if (token_count := token_usage.get(key)) is not None:
+            yield attribute_name, token_count
+
+
+@stop_on_exception
+def _function_calls(outputs: Optional[Mapping[str, Any]]) -> Iterator[Tuple[str, str]]:
+    """Yields function call information if present."""
+    if not outputs:
+        return
+    assert hasattr(outputs, "get"), f"expected Mapping, found {type(outputs)}"
+    try:
+        function_call_data = deepcopy(
outputs["generations"][0][0]["message"]["kwargs"]["additional_kwargs"]["function_call"] + ) + function_call_data["arguments"] = json.loads(function_call_data["arguments"]) + yield LLM_FUNCTION_CALL, json.dumps(function_call_data) + except Exception: + pass + + +@stop_on_exception +def _tools(run: Run) -> Iterator[Tuple[str, str]]: + """Yields tool attributes if present.""" + if run.run_type.lower() != "tool": + return + if not (serialized := run.serialized): + return + assert hasattr(serialized, "get"), f"expected Mapping, found {type(serialized)}" + if name := serialized.get("name"): + yield TOOL_NAME, name + if description := serialized.get("description"): + yield TOOL_DESCRIPTION, description + + +@stop_on_exception +def _retrieval_documents(run: Run) -> Iterator[Tuple[str, List[Mapping[str, Any]]]]: + if run.run_type.lower() != "retriever": + return + if not (outputs := run.outputs): + return + assert hasattr(outputs, "get"), f"expected Mapping, found {type(outputs)}" + documents = outputs.get("documents") + assert isinstance(documents, Iterable), f"expected Iterable, found {type(documents)}" + yield RETRIEVAL_DOCUMENTS, [dict(_as_document(document)) for document in documents] + + +@stop_on_exception +def _as_document(document: Any) -> Iterator[Tuple[str, Any]]: + if page_content := getattr(document, "page_content", None): + assert isinstance(page_content, str), f"expected str, found {type(page_content)}" + yield DOCUMENT_CONTENT, page_content + if metadata := getattr(document, "metadata", None): + assert isinstance(metadata, Mapping), f"expected Mapping, found {type(metadata)}" + yield DOCUMENT_METADATA, metadata + + +DOCUMENT_CONTENT = DocumentAttributes.DOCUMENT_CONTENT +DOCUMENT_ID = DocumentAttributes.DOCUMENT_ID +DOCUMENT_METADATA = DocumentAttributes.DOCUMENT_METADATA +DOCUMENT_SCORE = DocumentAttributes.DOCUMENT_SCORE +EMBEDDING_EMBEDDINGS = SpanAttributes.EMBEDDING_EMBEDDINGS +EMBEDDING_MODEL_NAME = SpanAttributes.EMBEDDING_MODEL_NAME +EMBEDDING_TEXT = EmbeddingAttributes.EMBEDDING_TEXT +EMBEDDING_VECTOR = EmbeddingAttributes.EMBEDDING_VECTOR +INPUT_MIME_TYPE = SpanAttributes.INPUT_MIME_TYPE +INPUT_VALUE = SpanAttributes.INPUT_VALUE +LLM_FUNCTION_CALL = SpanAttributes.LLM_FUNCTION_CALL +LLM_INPUT_MESSAGES = SpanAttributes.LLM_INPUT_MESSAGES +LLM_INVOCATION_PARAMETERS = SpanAttributes.LLM_INVOCATION_PARAMETERS +LLM_MODEL_NAME = SpanAttributes.LLM_MODEL_NAME +LLM_OUTPUT_MESSAGES = SpanAttributes.LLM_OUTPUT_MESSAGES +LLM_PROMPTS = SpanAttributes.LLM_PROMPTS +LLM_PROMPT_TEMPLATE = SpanAttributes.LLM_PROMPT_TEMPLATE +LLM_PROMPT_TEMPLATE_VARIABLES = SpanAttributes.LLM_PROMPT_TEMPLATE_VARIABLES +LLM_TOKEN_COUNT_COMPLETION = SpanAttributes.LLM_TOKEN_COUNT_COMPLETION +LLM_TOKEN_COUNT_PROMPT = SpanAttributes.LLM_TOKEN_COUNT_PROMPT +LLM_TOKEN_COUNT_TOTAL = SpanAttributes.LLM_TOKEN_COUNT_TOTAL +MESSAGE_CONTENT = MessageAttributes.MESSAGE_CONTENT +MESSAGE_FUNCTION_CALL_ARGUMENTS_JSON = MessageAttributes.MESSAGE_FUNCTION_CALL_ARGUMENTS_JSON +MESSAGE_FUNCTION_CALL_NAME = MessageAttributes.MESSAGE_FUNCTION_CALL_NAME +MESSAGE_NAME = MessageAttributes.MESSAGE_NAME +MESSAGE_ROLE = MessageAttributes.MESSAGE_ROLE +MESSAGE_TOOL_CALLS = MessageAttributes.MESSAGE_TOOL_CALLS +OPENINFERENCE_SPAN_KIND = SpanAttributes.OPENINFERENCE_SPAN_KIND +OUTPUT_MIME_TYPE = SpanAttributes.OUTPUT_MIME_TYPE +OUTPUT_VALUE = SpanAttributes.OUTPUT_VALUE +RERANKER_INPUT_DOCUMENTS = RerankerAttributes.RERANKER_INPUT_DOCUMENTS +RERANKER_MODEL_NAME = RerankerAttributes.RERANKER_MODEL_NAME +RERANKER_OUTPUT_DOCUMENTS = 
RerankerAttributes.RERANKER_OUTPUT_DOCUMENTS +RERANKER_QUERY = RerankerAttributes.RERANKER_QUERY +RERANKER_TOP_K = RerankerAttributes.RERANKER_TOP_K +RETRIEVAL_DOCUMENTS = SpanAttributes.RETRIEVAL_DOCUMENTS +TOOL_CALL_FUNCTION_ARGUMENTS_JSON = ToolCallAttributes.TOOL_CALL_FUNCTION_ARGUMENTS_JSON +TOOL_CALL_FUNCTION_NAME = ToolCallAttributes.TOOL_CALL_FUNCTION_NAME +TOOL_DESCRIPTION = SpanAttributes.TOOL_DESCRIPTION +TOOL_NAME = SpanAttributes.TOOL_NAME +TOOL_PARAMETERS = SpanAttributes.TOOL_PARAMETERS diff --git a/python/instrumentation/openinference-instrumentation-langchain/src/openinference/instrumentation/langchain/package.py b/python/instrumentation/openinference-instrumentation-langchain/src/openinference/instrumentation/langchain/package.py new file mode 100644 index 000000000..e2bf7cec0 --- /dev/null +++ b/python/instrumentation/openinference-instrumentation-langchain/src/openinference/instrumentation/langchain/package.py @@ -0,0 +1,2 @@ +_instruments = ("langchain_core >= 0.1.0",) +_supports_metrics = False diff --git a/python/instrumentation/openinference-instrumentation-langchain/src/openinference/instrumentation/langchain/py.typed b/python/instrumentation/openinference-instrumentation-langchain/src/openinference/instrumentation/langchain/py.typed new file mode 100644 index 000000000..e69de29bb diff --git a/python/instrumentation/openinference-instrumentation-langchain/src/openinference/instrumentation/langchain/version.py b/python/instrumentation/openinference-instrumentation-langchain/src/openinference/instrumentation/langchain/version.py new file mode 100644 index 000000000..3dc1f76bc --- /dev/null +++ b/python/instrumentation/openinference-instrumentation-langchain/src/openinference/instrumentation/langchain/version.py @@ -0,0 +1 @@ +__version__ = "0.1.0" diff --git a/python/instrumentation/openinference-instrumentation-langchain/tests/openinference/instrumentation/langchain/test_instrumentor.py b/python/instrumentation/openinference-instrumentation-langchain/tests/openinference/instrumentation/langchain/test_instrumentor.py new file mode 100644 index 000000000..c21e11c7f --- /dev/null +++ b/python/instrumentation/openinference-instrumentation-langchain/tests/openinference/instrumentation/langchain/test_instrumentor.py @@ -0,0 +1,353 @@ +import asyncio +import logging +import random +from contextlib import suppress +from itertools import count +from typing import Any, AsyncIterator, Dict, Generator, Iterable, Iterator, List, Tuple + +import numpy as np +import openai +import pytest +from httpx import AsyncByteStream, Response, SyncByteStream +from langchain.chains import RetrievalQA +from langchain_community.embeddings import FakeEmbeddings +from langchain_community.retrievers import KNNRetriever +from langchain_openai import ChatOpenAI +from openinference.instrumentation.langchain import LangChainInstrumentor +from openinference.instrumentation.openai import OpenAIInstrumentor +from openinference.semconv.trace import ( + DocumentAttributes, + EmbeddingAttributes, + MessageAttributes, + OpenInferenceMimeTypeValues, + OpenInferenceSpanKindValues, + SpanAttributes, + ToolCallAttributes, +) +from opentelemetry import trace as trace_api +from opentelemetry.sdk import trace as trace_sdk +from opentelemetry.sdk.resources import Resource +from opentelemetry.sdk.trace.export import SimpleSpanProcessor +from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter +from opentelemetry.semconv.trace import SpanAttributes as OTELSpanAttributes +from respx 
import MockRouter + +for name, logger in logging.root.manager.loggerDict.items(): + if name.startswith("openinference.") and isinstance(logger, logging.Logger): + logger.setLevel(logging.DEBUG) + logger.handlers.clear() + logger.addHandler(logging.StreamHandler()) + + +@pytest.mark.parametrize("is_async", [False, True]) +@pytest.mark.parametrize("is_stream", [False, True]) +@pytest.mark.parametrize("status_code", [200, 400]) +def test_callback_llm( + is_async: bool, + is_stream: bool, + status_code: int, + respx_mock: MockRouter, + in_memory_span_exporter: InMemorySpanExporter, + documents: List[str], + chat_completion_mock_stream: Tuple[List[bytes], List[Dict[str, Any]]], + model_name: str, + completion_usage: Dict[str, Any], +) -> None: + question = randstr() + output_messages: List[Dict[str, Any]] = ( + chat_completion_mock_stream[1] if is_stream else [{"role": randstr(), "content": randstr()}] + ) + url = "https://api.openai.com/v1/chat/completions" + respx_kwargs: Dict[str, Any] = { + **( + {"stream": MockByteStream(chat_completion_mock_stream[0])} + if is_stream + else { + "json": { + "choices": [ + {"index": i, "message": message, "finish_reason": "stop"} + for i, message in enumerate(output_messages) + ], + "model": model_name, + "usage": completion_usage, + } + } + ), + } + respx_mock.post(url).mock(return_value=Response(status_code=status_code, **respx_kwargs)) + chat_model = ChatOpenAI(model_name="gpt-3.5-turbo", streaming=is_stream) # type: ignore + retriever = KNNRetriever( + index=np.ones((len(documents), 2)), + texts=documents, + embeddings=FakeEmbeddings(size=2), + ) + rqa = RetrievalQA.from_chain_type(llm=chat_model, retriever=retriever) + with suppress(openai.BadRequestError): + if is_async: + asyncio.run(rqa.ainvoke({"query": question})) + else: + rqa.invoke({"query": question}) + + spans = in_memory_span_exporter.get_finished_spans() + spans_by_name = {span.name: span for span in spans} + + assert (rqa_span := spans_by_name.pop("RetrievalQA")) is not None + assert rqa_span.parent is None + rqa_attributes = dict(rqa_span.attributes or {}) + assert rqa_attributes.pop(OPENINFERENCE_SPAN_KIND, None) == CHAIN.value + assert rqa_attributes.pop(INPUT_VALUE, None) == question + if status_code == 200: + assert rqa_span.status.status_code == trace_api.StatusCode.OK + assert rqa_attributes.pop(OUTPUT_VALUE, None) == output_messages[0]["content"] + elif status_code == 400: + assert rqa_span.status.status_code == trace_api.StatusCode.ERROR + assert rqa_span.events[0].name == "exception" + assert (rqa_span.events[0].attributes or {}).get( + OTELSpanAttributes.EXCEPTION_TYPE + ) == "BadRequestError" + assert rqa_attributes == {} + + assert (sd_span := spans_by_name.pop("StuffDocumentsChain")) is not None + assert sd_span.parent is not None + assert sd_span.parent.span_id == rqa_span.context.span_id + assert sd_span.context.trace_id == rqa_span.context.trace_id + sd_attributes = dict(sd_span.attributes or {}) + assert sd_attributes.pop(OPENINFERENCE_SPAN_KIND, None) == CHAIN.value + assert sd_attributes.pop(INPUT_VALUE, None) is not None + assert sd_attributes.pop(INPUT_MIME_TYPE, None) == JSON.value + if status_code == 200: + assert sd_span.status.status_code == trace_api.StatusCode.OK + assert sd_attributes.pop(OUTPUT_VALUE, None) == output_messages[0]["content"] + elif status_code == 400: + assert sd_span.status.status_code == trace_api.StatusCode.ERROR + assert sd_span.events[0].name == "exception" + assert (sd_span.events[0].attributes or {}).get( + 
OTELSpanAttributes.EXCEPTION_TYPE + ) == "BadRequestError" + assert sd_attributes == {} + + assert (retriever_span := spans_by_name.pop("Retriever")) is not None + assert retriever_span.parent is not None + assert retriever_span.parent.span_id == rqa_span.context.span_id + assert retriever_span.context.trace_id == rqa_span.context.trace_id + retriever_attributes = dict(retriever_span.attributes or {}) + assert retriever_attributes.pop(OPENINFERENCE_SPAN_KIND, None) == RETRIEVER.value + assert retriever_attributes.pop(INPUT_VALUE, None) == question + assert retriever_attributes.pop(OUTPUT_VALUE, None) is not None + assert retriever_attributes.pop(OUTPUT_MIME_TYPE, None) == JSON.value + for i, text in enumerate(documents): + assert ( + retriever_attributes.pop(f"{RETRIEVAL_DOCUMENTS}.{i}.{DOCUMENT_CONTENT}", None) == text + ) + assert retriever_attributes == {} + + assert (llm_span := spans_by_name.pop("LLMChain", None)) is not None + assert llm_span.parent is not None + assert llm_span.parent.span_id == sd_span.context.span_id + assert llm_span.context.trace_id == sd_span.context.trace_id + llm_attributes = dict(llm_span.attributes or {}) + assert llm_attributes.pop(OPENINFERENCE_SPAN_KIND, None) == CHAIN.value + assert llm_attributes.pop(INPUT_VALUE, None) is not None + assert llm_attributes.pop(INPUT_MIME_TYPE, None) == JSON.value + if status_code == 200: + assert llm_attributes.pop(OUTPUT_VALUE, None) == output_messages[0]["content"] + elif status_code == 400: + assert llm_span.status.status_code == trace_api.StatusCode.ERROR + assert llm_span.events[0].name == "exception" + assert (llm_span.events[0].attributes or {}).get( + OTELSpanAttributes.EXCEPTION_TYPE + ) == "BadRequestError" + assert llm_attributes == {} + + assert (oai_span := spans_by_name.pop("ChatOpenAI", None)) is not None + assert oai_span.parent is not None + assert oai_span.parent.span_id == llm_span.context.span_id + assert oai_span.context.trace_id == llm_span.context.trace_id + oai_attributes = dict(oai_span.attributes or {}) + assert oai_attributes.pop(OPENINFERENCE_SPAN_KIND, None) == LLM.value + assert oai_attributes.pop(LLM_MODEL_NAME, None) is not None + assert oai_attributes.pop(LLM_INVOCATION_PARAMETERS, None) is not None + assert oai_attributes.pop(INPUT_VALUE, None) is not None + assert oai_attributes.pop(INPUT_MIME_TYPE, None) == JSON.value + assert oai_attributes.pop(LLM_PROMPTS, None) is not None + if status_code == 200: + assert oai_span.status.status_code == trace_api.StatusCode.OK + assert oai_attributes.pop(OUTPUT_VALUE, None) is not None + assert oai_attributes.pop(OUTPUT_MIME_TYPE, None) == JSON.value + assert ( + oai_attributes.pop(f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_ROLE}", None) + == output_messages[0]["role"] + ) + assert ( + oai_attributes.pop(f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_CONTENT}", None) + == output_messages[0]["content"] + ) + if not is_stream: + assert ( + oai_attributes.pop(LLM_TOKEN_COUNT_TOTAL, None) == completion_usage["total_tokens"] + ) + assert ( + oai_attributes.pop(LLM_TOKEN_COUNT_PROMPT, None) + == completion_usage["prompt_tokens"] + ) + assert ( + oai_attributes.pop(LLM_TOKEN_COUNT_COMPLETION, None) + == completion_usage["completion_tokens"] + ) + elif status_code == 400: + assert oai_span.status.status_code == trace_api.StatusCode.ERROR + assert oai_span.events[0].name == "exception" + assert (oai_span.events[0].attributes or {}).get( + OTELSpanAttributes.EXCEPTION_TYPE + ) == "BadRequestError" + assert oai_attributes == {} + + # The remaining span is from the openai 
instrumentor. + openai_span = spans_by_name.popitem()[1] + assert openai_span.parent is not None + if is_async: + # FIXME: it's unclear why the context fails to propagate. + assert openai_span.parent.span_id == llm_span.context.span_id + assert openai_span.context.trace_id == llm_span.context.trace_id + else: + assert openai_span.parent.span_id == oai_span.context.span_id + assert openai_span.context.trace_id == oai_span.context.trace_id + + assert spans_by_name == {} + + +@pytest.fixture +def documents() -> List[str]: + return [randstr(), randstr()] + + +@pytest.fixture +def chat_completion_mock_stream() -> Tuple[List[bytes], List[Dict[str, Any]]]: + return ( + [ + b'data: {"choices": [{"delta": {"role": "assistant"}, "index": 0}]}\n\n', + b'data: {"choices": [{"delta": {"content": "A"}, "index": 0}]}\n\n', + b'data: {"choices": [{"delta": {"content": "B"}, "index": 0}]}\n\n', + b'data: {"choices": [{"delta": {"content": "C"}, "index": 0}]}\n\n', + b"data: [DONE]\n", + ], + [{"role": "assistant", "content": "ABC"}], + ) + + +@pytest.fixture(scope="module") +def in_memory_span_exporter() -> InMemorySpanExporter: + return InMemorySpanExporter() + + +@pytest.fixture(scope="module") +def tracer_provider(in_memory_span_exporter: InMemorySpanExporter) -> trace_api.TracerProvider: + resource = Resource(attributes={}) + tracer_provider = trace_sdk.TracerProvider(resource=resource) + span_processor = SimpleSpanProcessor(span_exporter=in_memory_span_exporter) + tracer_provider.add_span_processor(span_processor=span_processor) + return tracer_provider + + +@pytest.fixture(autouse=True) +def instrument( + tracer_provider: trace_api.TracerProvider, + in_memory_span_exporter: InMemorySpanExporter, +) -> Generator[None, None, None]: + LangChainInstrumentor().instrument(tracer_provider=tracer_provider) + OpenAIInstrumentor().instrument(tracer_provider=tracer_provider) + yield + OpenAIInstrumentor().uninstrument() + LangChainInstrumentor().uninstrument() + in_memory_span_exporter.clear() + + +@pytest.fixture(autouse=True) +def openai_api_key(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("OPENAI_API_KEY", "sk-") + + +@pytest.fixture(scope="module") +def seed() -> Iterator[int]: + """ + Use rolling seeds to help debugging, because the rolling pseudo-random values + allow conditional breakpoints to be hit precisely (and repeatably). 
+ """ + return count() + + +@pytest.fixture(autouse=True) +def set_seed(seed: Iterator[int]) -> Iterator[None]: + random.seed(next(seed)) + yield + + +@pytest.fixture +def completion_usage() -> Dict[str, Any]: + prompt_tokens = random.randint(1, 1000) + completion_tokens = random.randint(1, 1000) + return { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": prompt_tokens + completion_tokens, + } + + +@pytest.fixture +def model_name() -> str: + return randstr() + + +def randstr() -> str: + return str(random.random()) + + +class MockByteStream(SyncByteStream, AsyncByteStream): + def __init__(self, byte_stream: Iterable[bytes]): + self._byte_stream = byte_stream + + def __iter__(self) -> Iterator[bytes]: + for byte_string in self._byte_stream: + yield byte_string + + async def __aiter__(self) -> AsyncIterator[bytes]: + for byte_string in self._byte_stream: + yield byte_string + + +DOCUMENT_CONTENT = DocumentAttributes.DOCUMENT_CONTENT +DOCUMENT_ID = DocumentAttributes.DOCUMENT_ID +DOCUMENT_METADATA = DocumentAttributes.DOCUMENT_METADATA +EMBEDDING_EMBEDDINGS = SpanAttributes.EMBEDDING_EMBEDDINGS +EMBEDDING_MODEL_NAME = SpanAttributes.EMBEDDING_MODEL_NAME +EMBEDDING_TEXT = EmbeddingAttributes.EMBEDDING_TEXT +EMBEDDING_VECTOR = EmbeddingAttributes.EMBEDDING_VECTOR +INPUT_MIME_TYPE = SpanAttributes.INPUT_MIME_TYPE +INPUT_VALUE = SpanAttributes.INPUT_VALUE +LLM_INPUT_MESSAGES = SpanAttributes.LLM_INPUT_MESSAGES +LLM_INVOCATION_PARAMETERS = SpanAttributes.LLM_INVOCATION_PARAMETERS +LLM_MODEL_NAME = SpanAttributes.LLM_MODEL_NAME +LLM_OUTPUT_MESSAGES = SpanAttributes.LLM_OUTPUT_MESSAGES +LLM_PROMPTS = SpanAttributes.LLM_PROMPTS +LLM_TOKEN_COUNT_COMPLETION = SpanAttributes.LLM_TOKEN_COUNT_COMPLETION +LLM_TOKEN_COUNT_PROMPT = SpanAttributes.LLM_TOKEN_COUNT_PROMPT +LLM_TOKEN_COUNT_TOTAL = SpanAttributes.LLM_TOKEN_COUNT_TOTAL +MESSAGE_CONTENT = MessageAttributes.MESSAGE_CONTENT +MESSAGE_FUNCTION_CALL_ARGUMENTS_JSON = MessageAttributes.MESSAGE_FUNCTION_CALL_ARGUMENTS_JSON +MESSAGE_FUNCTION_CALL_NAME = MessageAttributes.MESSAGE_FUNCTION_CALL_NAME +MESSAGE_ROLE = MessageAttributes.MESSAGE_ROLE +MESSAGE_TOOL_CALLS = MessageAttributes.MESSAGE_TOOL_CALLS +OPENINFERENCE_SPAN_KIND = SpanAttributes.OPENINFERENCE_SPAN_KIND +OUTPUT_MIME_TYPE = SpanAttributes.OUTPUT_MIME_TYPE +OUTPUT_VALUE = SpanAttributes.OUTPUT_VALUE +RETRIEVAL_DOCUMENTS = SpanAttributes.RETRIEVAL_DOCUMENTS +TOOL_CALL_FUNCTION_ARGUMENTS_JSON = ToolCallAttributes.TOOL_CALL_FUNCTION_ARGUMENTS_JSON +TOOL_CALL_FUNCTION_NAME = ToolCallAttributes.TOOL_CALL_FUNCTION_NAME +LLM_PROMPT_TEMPLATE = SpanAttributes.LLM_PROMPT_TEMPLATE +LLM_PROMPT_TEMPLATE_VARIABLES = SpanAttributes.LLM_PROMPT_TEMPLATE_VARIABLES + +CHAIN = OpenInferenceSpanKindValues.CHAIN +LLM = OpenInferenceSpanKindValues.LLM +RETRIEVER = OpenInferenceSpanKindValues.RETRIEVER + +JSON = OpenInferenceMimeTypeValues.JSON diff --git a/python/tox.ini b/python/tox.ini index 7586c931e..9f26ad13d 100644 --- a/python/tox.ini +++ b/python/tox.ini @@ -7,6 +7,8 @@ envlist = py3{8,11}-ci-{openai,openai-latest} py3{8,11}-ci-{llama_index,llama_index-latest} py3{9,11}-ci-{dspy,dspy-latest} + py3{8,11}-ci-{langchain,langchain-latest} + py38-mypy-langchain_core [testenv] package = wheel @@ -19,6 +21,8 @@ changedir = openai: instrumentation/openinference-instrumentation-openai/ llama_index: instrumentation/openinference-instrumentation-llama-index/ dspy: instrumentation/openinference-instrumentation-dspy/ + langchain: 
instrumentation/openinference-instrumentation-langchain/ + langchain_core: instrumentation/openinference-instrumentation-langchain/src commands_pre = semconv: pip install {toxinidir}/openinference-semantic-conventions bedrock: pip install {toxinidir}/instrumentation/openinference-instrumentation-bedrock[test] @@ -29,6 +33,9 @@ commands_pre = llama_index-latest: pip install -U llama_index dspy: pip install {toxinidir}/instrumentation/openinference-instrumentation-dspy[test] dspy-latest: pip install -U dspy-ai + langchain: pip install {toxinidir}/instrumentation/openinference-instrumentation-langchain[test] + langchain-latest: pip install -U langchain langchain_core langchain_openai langchain_community + langchain_core: pip install {toxinidir}/instrumentation/openinference-instrumentation-langchain[type-check] commands = ruff: ruff format . --config {toxinidir}/ruff.toml ruff: ruff check . --fix --config {toxinidir}/ruff.toml diff --git a/release-please-config.json b/release-please-config.json index a46605695..383b03f11 100644 --- a/release-please-config.json +++ b/release-please-config.json @@ -25,6 +25,10 @@ "python/instrumentation/openinference-instrumentation-bedrock": { "package-name": "python-openinference-instrumentation-bedrock", "release-type": "python" + }, + "python/instrumentation/openinference-instrumentation-langchain": { + "package-name": "python-openinference-instrumentation-langchain", + "release-type": "python" } } }
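The new package's README.rst stops at installation, so the following is a minimal, self-contained sketch (not part of the diff) for exercising the instrumentor locally and inspecting the spans it emits. It reuses only pieces that already appear above: LangChainInstrumentor, the OpenTelemetry SDK tracer provider, InMemorySpanExporter, SimpleSpanProcessor, and the KNNRetriever/FakeEmbeddings pair from the test suite. No OpenAI key or network access is needed; the file name, document strings, and query are made up for illustration.

    # verify_langchain_instrumentation.py -- hypothetical helper, not included in this PR.
    import numpy as np
    from langchain_community.embeddings import FakeEmbeddings
    from langchain_community.retrievers import KNNRetriever
    from openinference.instrumentation.langchain import LangChainInstrumentor
    from opentelemetry.sdk import trace as trace_sdk
    from opentelemetry.sdk.trace.export import SimpleSpanProcessor
    from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter

    # Collect spans in memory instead of shipping them over OTLP.
    exporter = InMemorySpanExporter()
    tracer_provider = trace_sdk.TracerProvider()
    tracer_provider.add_span_processor(SimpleSpanProcessor(span_exporter=exporter))

    # Attach the OpenInferenceTracer to every LangChain callback manager.
    LangChainInstrumentor().instrument(tracer_provider=tracer_provider)

    # A purely local retriever call is enough to produce a retriever span.
    retriever = KNNRetriever(
        index=np.ones((2, 2)),
        texts=["first document", "second document"],
        embeddings=FakeEmbeddings(size=2),
    )
    retriever.get_relevant_documents("what do the documents say?")

    # Print the OpenInference attributes (span kind, retrieved document
    # contents, etc.) that the tests above assert on.
    for span in exporter.get_finished_spans():
        print(span.name, dict(span.attributes or {}))

Running this should print a span named "Retriever" whose attributes mirror the RETRIEVAL_DOCUMENTS assertions in test_instrumentor.py.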