fix(llama-index): capture tool calls from anthropic chat response (#1177)
RogerHYang authored Dec 20, 2024
1 parent 8ca8826 commit e1ba6a5
Showing 10 changed files with 178 additions and 6 deletions.
@@ -43,8 +43,10 @@ test = [
"llama-index == 0.11.0",
"llama-index-core >= 0.11.0",
"llama-index-llms-openai",
"llama_index.llms.anthropic",
"llama-index-llms-groq",
"pytest-vcr",
"anthropic<0.41",
"llama-index-multi-modal-llms-openai>=0.1.7",
"openinference-instrumentation-openai",
"opentelemetry-sdk",
@@ -86,6 +88,8 @@ exclude = [
ignore_missing_imports = true
module = [
"wrapt",
"llama_index.llms.anthropic",
"llama_index.llms.openai",
]

[tool.ruff]
@@ -825,7 +825,19 @@ def handle(self, event: BaseEvent, **kwargs: Any) -> Any:


def _get_tool_call(tool_call: object) -> Iterator[Tuple[str, Any]]:
if function := getattr(tool_call, "function", None):
if isinstance(tool_call, dict):
if tool_call_id := tool_call.get("id"):
yield TOOL_CALL_ID, tool_call_id
if name := tool_call.get("name"):
yield TOOL_CALL_FUNCTION_NAME, name
if arguments := tool_call.get("input"):
if isinstance(arguments, str):
yield TOOL_CALL_FUNCTION_ARGUMENTS_JSON, arguments
elif isinstance(arguments, dict):
yield TOOL_CALL_FUNCTION_ARGUMENTS_JSON, safe_json_dumps(arguments)
elif function := getattr(tool_call, "function", None):
if tool_call_id := getattr(tool_call, "id", None):
yield TOOL_CALL_ID, tool_call_id
if name := getattr(function, "name", None):
yield TOOL_CALL_FUNCTION_NAME, name
if arguments := getattr(function, "arguments", None):
@@ -1032,6 +1044,7 @@ def is_base64_url(url: str) -> bool:
RERANKER_QUERY = RerankerAttributes.RERANKER_QUERY
RERANKER_TOP_K = RerankerAttributes.RERANKER_TOP_K
RETRIEVAL_DOCUMENTS = SpanAttributes.RETRIEVAL_DOCUMENTS
TOOL_CALL_ID = ToolCallAttributes.TOOL_CALL_ID
TOOL_CALL_FUNCTION_ARGUMENTS_JSON = ToolCallAttributes.TOOL_CALL_FUNCTION_ARGUMENTS_JSON
TOOL_CALL_FUNCTION_NAME = ToolCallAttributes.TOOL_CALL_FUNCTION_NAME
TOOL_DESCRIPTION = SpanAttributes.TOOL_DESCRIPTION
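
The dict branch added above is the heart of the fix: the Anthropic integration hands tool calls to the handler as plain dicts keyed by id, name, and input, while the OpenAI integration hands over objects exposing .id and .function. A minimal sketch of the two shapes, using values from the cassettes recorded in this commit (the SimpleNamespace stand-ins are illustrative, not the real client types):

from types import SimpleNamespace

# Anthropic-style tool call: a plain dict, matched by the new isinstance branch.
anthropic_tool_call = {
    "id": "toolu_01P7dMjNQjMNZK8BB8sKP25k",
    "name": "get_weather",
    "input": {"location": "San Francisco"},  # a dict, so it goes through safe_json_dumps
}

# OpenAI-style tool call: an object with .id and .function, matched by the
# pre-existing getattr branch (SimpleNamespace stands in for the client type).
openai_tool_call = SimpleNamespace(
    id="call_FjpIANozIfaXzuQQnmhK0yD3",
    function=SimpleNamespace(
        name="get_weather",
        arguments='{"location": "San Francisco"}',  # already a JSON string
    ),
)

# Either shape now yields the same attribute pairs from _get_tool_call:
# (TOOL_CALL_ID, ...), (TOOL_CALL_FUNCTION_NAME, "get_weather"),
# (TOOL_CALL_FUNCTION_ARGUMENTS_JSON, '{"location": "San Francisco"}')
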
@@ -0,0 +1,18 @@
interactions:
- request:
body: '{"max_tokens":512,"messages":[{"role":"user","content":[{"text":"what''s
the weather in San Francisco?","type":"text"}]}],"model":"claude-3-5-haiku-20241022","stream":false,"system":"","temperature":0.1,"tools":[{"name":"get_weather","description":"get_weather(location:
str) -> str\nUseful for getting the weather for a given location.","input_schema":{"properties":{"location":{"title":"Location","type":"string"}},"required":["location"],"type":"object"}}]}'
headers: {}
method: POST
uri: https://api.anthropic.com/v1/messages
response:
body:
string: '{"id":"msg_011UbtsepYnQzWFNg8fDmFZ2","type":"message","role":"assistant","model":"claude-3-5-haiku-20241022","content":[{"type":"text","text":"I''ll
help you check the weather in San Francisco right away."},{"type":"tool_use","id":"toolu_01P7dMjNQjMNZK8BB8sKP25k","name":"get_weather","input":{"location":"San
Francisco"}}],"stop_reason":"tool_use","stop_sequence":null,"usage":{"input_tokens":355,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"output_tokens":68}}'
headers: {}
status:
code: 200
message: OK
version: 1
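
The cassette above records the Anthropic Messages API response whose tool_use content block the instrumentation previously dropped. A quick sketch of where the tool call sits in that payload (assuming raw holds the recorded body string; plain json parsing, not the instrumentation code):

import json

payload = json.loads(raw)  # raw: the response body recorded in the cassette above
tool_use = next(block for block in payload["content"] if block["type"] == "tool_use")
assert tool_use["name"] == "get_weather"
assert tool_use["input"] == {"location": "San Francisco"}  # a dict, not a JSON string
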
@@ -0,0 +1,29 @@
interactions:
- request:
body: '{"messages":[{"role":"user","content":"what''s the weather in San Francisco?"}],"model":"gpt-4o-mini","stream":false,"temperature":0.1,"tool_choice":"auto","tools":[{"type":"function","function":{"name":"get_weather","description":"get_weather(location:
str) -> str\nUseful for getting the weather for a given location.","parameters":{"properties":{"location":{"title":"Location","type":"string"}},"required":["location"],"type":"object","additionalProperties":false},"strict":false}}]}'
headers: {}
method: POST
uri: https://api.openai.com/v1/chat/completions
response:
body:
string: "{\n \"id\": \"chatcmpl-AgcDNYhR5NPYhy2hmtnkm6CP8GFAN\",\n \"object\":
\"chat.completion\",\n \"created\": 1734720037,\n \"model\": \"gpt-4o-mini-2024-07-18\",\n
\ \"choices\": [\n {\n \"index\": 0,\n \"message\": {\n \"role\":
\"assistant\",\n \"content\": null,\n \"tool_calls\": [\n {\n
\ \"id\": \"call_FjpIANozIfaXzuQQnmhK0yD3\",\n \"type\":
\"function\",\n \"function\": {\n \"name\": \"get_weather\",\n
\ \"arguments\": \"{\\\"location\\\":\\\"San Francisco\\\"}\"\n
\ }\n }\n ],\n \"refusal\": null\n },\n
\ \"logprobs\": null,\n \"finish_reason\": \"tool_calls\"\n }\n
\ ],\n \"usage\": {\n \"prompt_tokens\": 68,\n \"completion_tokens\":
16,\n \"total_tokens\": 84,\n \"prompt_tokens_details\": {\n \"cached_tokens\":
0,\n \"audio_tokens\": 0\n },\n \"completion_tokens_details\":
{\n \"reasoning_tokens\": 0,\n \"audio_tokens\": 0,\n \"accepted_prediction_tokens\":
0,\n \"rejected_prediction_tokens\": 0\n }\n },\n \"system_fingerprint\":
\"fp_0aa8d3e20b\"\n}\n"
headers: {}
status:
code: 200
message: OK
version: 1
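
For contrast, the OpenAI completion above nests its tool call under choices[0].message.tool_calls and encodes arguments as a JSON string rather than a dict, which is why _get_tool_call keeps both the str and dict branches. Again assuming raw holds the recorded body:

import json

payload = json.loads(raw)
tool_call = payload["choices"][0]["message"]["tool_calls"][0]
assert tool_call["function"]["name"] == "get_weather"
# arguments arrives as a string here; decode to compare with Anthropic's dict form
assert json.loads(tool_call["function"]["arguments"]) == {"location": "San Francisco"}
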
@@ -28,7 +28,7 @@
from llama_index.core.base.response.schema import StreamingResponse
from llama_index.core.callbacks import CallbackManager
from llama_index.core.schema import TextNode
from llama_index.llms.openai import OpenAI # type: ignore
from llama_index.llms.openai import OpenAI
from opentelemetry import trace as trace_api
from opentelemetry.sdk import trace as trace_sdk
from opentelemetry.sdk.trace import ReadableSpan
@@ -30,7 +30,7 @@
from llama_index.core import Document, ListIndex, Settings
from llama_index.core.callbacks import CallbackManager
from llama_index.core.schema import TextNode
from llama_index.llms.openai import OpenAI # type: ignore
from llama_index.llms.openai import OpenAI
from opentelemetry import trace as trace_api
from opentelemetry.sdk import trace as trace_sdk
from opentelemetry.sdk.trace import ReadableSpan
@@ -23,7 +23,7 @@
from httpx import Response
from llama_index.core.base.llms.types import ChatMessage, MessageRole
from llama_index.core.multi_modal_llms.generic_utils import load_image_urls
from llama_index.llms.openai import OpenAI # type: ignore
from llama_index.llms.openai import OpenAI
from llama_index.multi_modal_llms.openai import OpenAIMultiModal # type: ignore
from llama_index.multi_modal_llms.openai import utils as openai_utils
from opentelemetry import trace as trace_api
@@ -30,7 +30,7 @@
from llama_index.core.base.llms.types import ChatMessage
from llama_index.core.multi_modal_llms.generic_utils import load_image_urls
from llama_index.core.schema import TextNode
from llama_index.llms.openai import OpenAI # type: ignore
from llama_index.llms.openai import OpenAI
from llama_index.multi_modal_llms.openai import OpenAIMultiModal # type: ignore
from llama_index.multi_modal_llms.openai import utils as openai_utils
from opentelemetry import trace as trace_api
@@ -0,0 +1,108 @@
from importlib.metadata import version
from json import loads
from typing import Iterator, Tuple, cast

import pytest
from llama_index.core.llms.function_calling import FunctionCallingLLM
from llama_index.core.tools import FunctionTool
from llama_index.llms.anthropic import Anthropic
from llama_index.llms.openai import OpenAI
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
from opentelemetry.trace import TracerProvider

from openinference.instrumentation.llama_index import LlamaIndexInstrumentor
from openinference.semconv.trace import MessageAttributes, SpanAttributes, ToolCallAttributes

LLAMA_INDEX_LLMS_OPENAI_VERSION = cast(
Tuple[int, int], tuple(map(int, version("llama_index.llms.openai").split(".")[:2]))
)
LLAMA_INDEX_LLMS_ANTHROPIC_VERSION = cast(
Tuple[int, int], tuple(map(int, version("llama_index.llms.anthropic").split(".")[:2]))
)


def get_weather(location: str) -> str:
"""Useful for getting the weather for a given location."""
raise NotImplementedError


TOOL = FunctionTool.from_defaults(get_weather)


class TestToolCallsInChatResponse:
@pytest.mark.skipif(
LLAMA_INDEX_LLMS_OPENAI_VERSION < (0, 3),
reason="ignore older versions to simplify test upkeep",
)
@pytest.mark.vcr(
decode_compressed_response=True,
before_record_request=lambda _: _.headers.clear() or _,
before_record_response=lambda _: {**_, "headers": {}},
)
async def test_openai(
self,
in_memory_span_exporter: InMemorySpanExporter,
) -> None:
llm = OpenAI(model="gpt-4o-mini", api_key="sk-")
await self._test(llm, in_memory_span_exporter)

@pytest.mark.skipif(
LLAMA_INDEX_LLMS_ANTHROPIC_VERSION < (0, 6),
reason="ignore older versions to simplify test upkeep",
)
@pytest.mark.vcr(
decode_compressed_response=True,
before_record_request=lambda _: _.headers.clear() or _,
before_record_response=lambda _: {**_, "headers": {}},
)
async def test_anthropic(
self,
in_memory_span_exporter: InMemorySpanExporter,
) -> None:
llm = Anthropic(model="claude-3-5-haiku-20241022", api_key="sk-")
await self._test(llm, in_memory_span_exporter)

@classmethod
async def _test(
cls,
llm: FunctionCallingLLM,
in_memory_span_exporter: InMemorySpanExporter,
) -> None:
await llm.achat(
**llm._prepare_chat_with_tools([TOOL], "what's the weather in San Francisco?"),
)
spans = in_memory_span_exporter.get_finished_spans()
span = spans[-1]
assert span.attributes
assert span.attributes.get(f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_TOOL_CALLS}.0.{TOOL_CALL_ID}")
assert (
span.attributes.get(
f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_TOOL_CALLS}.0.{TOOL_CALL_FUNCTION_NAME}"
)
== "get_weather"
)
assert isinstance(
arguments := span.attributes.get(
f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_TOOL_CALLS}.0.{TOOL_CALL_FUNCTION_ARGUMENTS_JSON}"
),
str,
)
assert loads(arguments) == {"location": "San Francisco"}


@pytest.fixture(autouse=True)
def instrument(
tracer_provider: TracerProvider,
in_memory_span_exporter: InMemorySpanExporter,
) -> Iterator[None]:
LlamaIndexInstrumentor().instrument(tracer_provider=tracer_provider)
yield
LlamaIndexInstrumentor().uninstrument()


LLM_OUTPUT_MESSAGES = SpanAttributes.LLM_OUTPUT_MESSAGES
MESSAGE_TOOL_CALLS = MessageAttributes.MESSAGE_TOOL_CALLS
MESSAGE_TOOL_CALL_ID = MessageAttributes.MESSAGE_TOOL_CALL_ID
TOOL_CALL_ID = ToolCallAttributes.TOOL_CALL_ID
TOOL_CALL_FUNCTION_ARGUMENTS_JSON = ToolCallAttributes.TOOL_CALL_FUNCTION_ARGUMENTS_JSON
TOOL_CALL_FUNCTION_NAME = ToolCallAttributes.TOOL_CALL_FUNCTION_NAME
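
The assertions in the test address span attributes through flattened, index-qualified keys built from the constants above. Assuming the standard OpenInference semantic-convention values, the key for the first tool call's function name expands like this (illustrative, not quoted from the commit):

key = f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_TOOL_CALLS}.0.{TOOL_CALL_FUNCTION_NAME}"
# -> "llm.output_messages.0.message.tool_calls.0.tool_call.function.name"
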
python/tox.ini (2 changes: 1 addition & 1 deletion)
@@ -58,7 +58,7 @@ commands_pre =
vertexai: uv pip install --reinstall {toxinidir}/instrumentation/openinference-instrumentation-vertexai[test]
vertexai-latest: uv pip install -U vertexai 'httpx<0.28'
llama_index: uv pip install --reinstall {toxinidir}/instrumentation/openinference-instrumentation-llama-index[test] 'httpx<0.28'
llama_index-latest: uv pip install -U llama-index llama-index-core 'httpx<0.28'
llama_index-latest: uv pip install -U llama-index llama-index-core llama-index-llms-openai openai llama-index-llms-anthropic anthropic 'httpx<0.28'
dspy: uv pip install --reinstall {toxinidir}/instrumentation/openinference-instrumentation-dspy[test]
dspy-latest: uv pip install -U dspy-ai 'httpx<0.28'
langchain: uv pip install --reinstall {toxinidir}/instrumentation/openinference-instrumentation-langchain[test]
