Skip to content

Commit

Permalink
update attr extraction, remove NOT_GIVEN values, fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
cjunkin committed Dec 20, 2024
1 parent d580c58 commit 8ae6822
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 105 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

from groq import AsyncGroq, Groq
from groq.types.chat import ChatCompletionToolMessageParam
from phoenix.otel import register
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk import trace as trace_sdk
from opentelemetry.sdk.trace.export import SimpleSpanProcessor

from openinference.instrumentation.groq import GroqInstrumentor

Expand Down Expand Up @@ -107,7 +109,10 @@ async def async_test():


if __name__ == "__main__":
tracer_provider = register(project_name="groq_debug")
endpoint = "http://0.0.0.0:6006/v1/traces"
tracer_provider = trace_sdk.TracerProvider()
tracer_provider.add_span_processor(SimpleSpanProcessor(OTLPSpanExporter(endpoint)))

GroqInstrumentor().instrument(tracer_provider=tracer_provider)

response = test()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
import logging
from enum import Enum
from typing import Any, Dict, Iterable, Iterator, List, Mapping, Tuple, TypeVar
from typing import Any, Iterable, Iterator, Mapping, Tuple, TypeVar

from opentelemetry.util.types import AttributeValue

from groq.types.chat import ChatCompletionMessage, ChatCompletionMessageToolCall
from groq.types.chat.chat_completion_message_tool_call import Function
from openinference.instrumentation import safe_json_dumps
from openinference.instrumentation.groq._utils import _as_input_attributes, _io_value_and_type
from openinference.semconv.trace import (
Expand Down Expand Up @@ -69,96 +67,65 @@ def _get_attributes_from_message_param(
self,
message: Mapping[str, Any],
) -> Iterator[Tuple[str, AttributeValue]]:
if not hasattr(message, "get"):
if isinstance(message, ChatCompletionMessage):
message = self._cast_chat_completion_to_mapping(message)
else:
return
if role := message.get("role"):
if role := get_attribute(message, "role"):
yield (
MessageAttributes.MESSAGE_ROLE,
role.value if isinstance(role, Enum) else role,
)

if content := message.get("content"):
if content := get_attribute(message, "content"):
yield (
MessageAttributes.MESSAGE_CONTENT,
content,
)

if name := message.get("name"):
if name := get_attribute(message, "name"):
yield MessageAttributes.MESSAGE_NAME, name

if tool_call_id := message.get("tool_call_id"):
if tool_call_id := get_attribute(message, "tool_call_id"):
yield MessageAttributes.MESSAGE_TOOL_CALL_ID, tool_call_id

# Deprecated by Groq
if (function_call := message.get("function_call")) and hasattr(function_call, "get"):
if function_name := function_call.get("name"):
if function_call := get_attribute(message, "function_call"):
if function_name := get_attribute(function_call, "name"):
yield MessageAttributes.MESSAGE_FUNCTION_CALL_NAME, function_name
if function_arguments := function_call.get("arguments"):
if function_arguments := get_attribute(function_call, "arguments"):
yield (
MessageAttributes.MESSAGE_FUNCTION_CALL_ARGUMENTS_JSON,
function_arguments,
)

if (tool_calls := message.get("tool_calls"),) and isinstance(tool_calls, Iterable):
if (tool_calls := get_attribute(message, "tool_calls")) and isinstance(
tool_calls, Iterable
):
for index, tool_call in enumerate(tool_calls):
if not hasattr(tool_call, "get"):
continue
if (tool_call_id := tool_call.get("id")) is not None:
if (tool_call_id := get_attribute(tool_call, "id")) is not None:
yield (
f"{MessageAttributes.MESSAGE_TOOL_CALLS}.{index}."
f"{ToolCallAttributes.TOOL_CALL_ID}",
tool_call_id,
)
if (function := tool_call.get("function")) and hasattr(function, "get"):
if name := function.get("name"):
if function := get_attribute(tool_call, "function"):
if name := get_attribute(function, "name"):
yield (
f"{MessageAttributes.MESSAGE_TOOL_CALLS}.{index}."
f"{ToolCallAttributes.TOOL_CALL_FUNCTION_NAME}",
name,
)
if arguments := function.get("arguments"):
if arguments := get_attribute(function, "arguments"):
yield (
f"{MessageAttributes.MESSAGE_TOOL_CALLS}.{index}."
f"{ToolCallAttributes.TOOL_CALL_FUNCTION_ARGUMENTS_JSON}",
arguments,
)

def _cast_chat_completion_to_mapping(self, message: ChatCompletionMessage) -> Mapping[str, Any]:
try:
casted_message = dict(message)
if (tool_calls := casted_message.get("tool_calls")) and isinstance(
tool_calls, Iterable
):
casted_tool_calls: List[Dict[str, Any]] = []
for tool_call in tool_calls:
if isinstance(tool_call, ChatCompletionMessageToolCall):
tool_call_dict = dict(tool_call)

if (function := tool_call_dict.get("function")) and isinstance(
function, Function
):
tool_call_dict["function"] = dict(function)

casted_tool_calls.append(tool_call_dict)
else:
logger.debug(f"Skipping tool_call of unexpected type: {type(tool_call)}")

casted_message["tool_calls"] = casted_tool_calls

return casted_message

except Exception as e:
logger.exception(
f"Failed to convert ChatCompletionMessage to mapping for {message}: {e}"
)
return {}


T = TypeVar("T", bound=type)


def is_iterable_of(lst: Iterable[object], tp: T) -> bool:
return isinstance(lst, Iterable) and all(isinstance(x, tp) for x in lst)


def get_attribute(obj: Any, attr_name: str, default: Any = None) -> Any:
if isinstance(obj, dict):
return obj.get(attr_name, default)
return getattr(obj, attr_name, default)
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@


class _ResponseAttributesExtractor:
__slots__ = ()

def get_attributes(self, response: Any) -> Iterator[Tuple[str, AttributeValue]]:
yield from _as_output_attributes(
_io_value_and_type(response),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from opentelemetry.trace import INVALID_SPAN
from opentelemetry.util.types import AttributeValue

from groq import NOT_GIVEN
from openinference.instrumentation import get_attributes_from_context, safe_json_dumps
from openinference.instrumentation.groq._request_attributes_extractor import (
_RequestAttributesExtractor,
Expand Down Expand Up @@ -93,11 +94,11 @@ def _parse_args(
) -> Dict[str, Any]:
bound_signature = signature.bind(*args, **kwargs)
bound_signature.apply_defaults()
bound_arguments = bound_signature.arguments
bound_arguments = bound_signature.arguments # Defaults empty to NOT_GIVEN
request_data: Dict[str, Any] = {}
for key, value in bound_arguments.items():
try:
if value is not None:
if value is not None and value is not NOT_GIVEN:
try:
# ensure the value is JSON-serializable
safe_json_dumps(value)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter

from openinference.instrumentation import using_attributes


def _mock_post(
self: Any,
Expand Down Expand Up @@ -184,51 +182,51 @@ def test_tool_calls(
),
]

client.chat.completions.create(
model="test_groq_model",
tools=input_tools,
messages=[
{
"role": "assistant",
"tool_calls": [
{
"id": "call_62136355",
"type": "function",
"function": {
"name": "get_weather",
"arguments": '{"city": "New York"}',
},
client.chat.completions.create(
model="test_groq_model",
tools=input_tools,
messages=[
{
"role": "assistant",
"tool_calls": [
{
"id": "call_62136355",
"type": "function",
"function": {
"name": "get_weather",
"arguments": '{"city": "New York"}',
},
{
"id": "call_62136356",
"type": "function",
"function": {
"name": "get_population",
"arguments": '{"city": "New York"}',
},
},
{
"id": "call_62136356",
"type": "function",
"function": {
"name": "get_population",
"arguments": '{"city": "New York"}',
},
],
},
{
"role": "tool",
"tool_call_id": "call_62136355",
"content": '{"city": "New York", "weather": "fine"}',
},
{
"role": "tool",
"tool_call_id": "call_62136356",
"content": '{"city": "New York", "weather": "large"}',
},
{
"role": "assistant",
"content": "In New York the weather is fine and the population is large.",
},
{
"role": "user",
"content": "What's the weather and population in San Francisco?",
},
],
)
},
],
},
{
"role": "tool",
"tool_call_id": "call_62136355",
"content": '{"city": "New York", "weather": "fine"}',
},
{
"role": "tool",
"tool_call_id": "call_62136356",
"content": '{"city": "New York", "weather": "large"}',
},
{
"role": "assistant",
"content": "In New York the weather is fine and the population is large.",
},
{
"role": "user",
"content": "What's the weather and population in San Francisco?",
},
],
)
spans = in_memory_span_exporter.get_finished_spans()
assert len(spans) == 1
span = spans[0]
Expand Down

0 comments on commit 8ae6822

Please sign in to comment.