Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Llama payload #211

Merged
merged 6 commits into from
Aug 28, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 103 additions & 61 deletions portkey_ai/llamaindex/portkey_llama_callback_handler.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
from enum import Enum
from enum import Enum, auto
import json
import time
from typing import Any, Dict, List, Optional
from portkey_ai.api_resources.apis.logger import Logger
from datetime import datetime
from llama_index.core.callbacks.schema import (
CBEventType,
)
from uuid import uuid4
from llama_index.legacy.schema import NodeRelationship

try:
from llama_index.core.callbacks.base_handler import (
Expand All @@ -27,25 +25,30 @@ class LlamaIndexCallbackHandler(LlamaIndexBaseCallbackHandler):
def __init__(
self,
api_key: str,
metadata: Optional[Dict[str, Any]] = {},
metadata: Optional[Dict[str, Any]] = None,
) -> None:
super().__init__(
event_starts_to_ignore=[
CBEventType.CHUNKING,
CBEventType.NODE_PARSING,
CBEventType.SYNTHESIZE,
CBEventType.EXCEPTION,
CBEventType.TREE,
CBEventType.RERANKING,
],
event_ends_to_ignore=[
CBEventType.CHUNKING,
CBEventType.NODE_PARSING,
CBEventType.SYNTHESIZE,
CBEventType.EXCEPTION,
CBEventType.TREE,
CBEventType.RERANKING,
],
)

self.api_key = api_key
self.metadata = metadata
self.metadata: Dict[str, Any] = metadata or {}
self.metadata.update({"_source": "LlamaIndex", "_source_type": "Agent"})

self.portkey_logger = Logger(api_key=api_key)

Expand Down Expand Up @@ -83,7 +86,7 @@ def on_event_start( # type: ignore
span_id = str(event_id)
parent_span_id = parent_id
span_name = event_type
start_time = int(datetime.now().timestamp())
start_time = time.time()

if parent_id == "root":
parent_span_id = self.main_span_id
Expand All @@ -102,15 +105,18 @@ def on_event_start( # type: ignore
request_payload = self.retrieve_event_start(payload)
elif event_type == "templating":
request_payload = self.templating_event_start(payload)
elif event_type == "sub_question":
request_payload = self.sub_question_event_start(payload)
else:
request_payload = payload
return ""

start_event_information = {
"span_id": span_id,
"parent_span_id": parent_span_id,
"span_name": span_name.value,
"trace_id": self.global_trace_id,
"request": request_payload,
"event_type": event_type,
"start_time": start_time,
"metadata": self.metadata,
}
Expand All @@ -131,8 +137,8 @@ def on_event_end(
if span_id in self.event_map:
event = self.event_map[event_id]
start_time = event["start_time"]
end_time = int(datetime.now().timestamp())
total_time = (end_time - start_time) * 1000
end_time = time.time()
total_time = f"{((end_time - start_time) * 1000):04.0f}"
response_payload["response_time"] = total_time
else:
if event_type == "llm":
Expand All @@ -149,8 +155,10 @@ def on_event_end(
response_payload = self.retrieve_event_end(payload, event_id)
elif event_type == "templating":
response_payload = self.templating_event_end(payload, event_id)
elif event_type == "sub_question":
response_payload = self.sub_question_event_end(payload, event_id)
else:
response_payload = payload
return

self.event_map[span_id]["response"] = response_payload

Expand Down Expand Up @@ -202,32 +210,26 @@ def llm_event_start(self, payload: Any) -> Any:
return self.request

def llm_event_end(self, payload: Any, event_id) -> Any:
result: Dict[str, Any] = {}
result["body"] = {}

try:
data = self.serialize(payload)
except Exception:
data = payload.__dict__

if event_id in self.event_map:
event = self.event_map[event_id]
start_time = event["start_time"]

self.response = {}

data = payload.get("response", {})
end_time = time.time()
total_time = f"{((end_time - start_time) * 1000):04.0f}"

chunks = payload.get("messages", {})
self.completion_tokens = self._token_counter.estimate_tokens_in_messages(chunks)
self.token_llm = self.prompt_tokens + self.completion_tokens
self.response["status"] = 200
self.response["body"] = {
"choices": [
{
"index": 0,
"message": {
"role": data.message.role.value,
"content": data.message.content,
},
"logprobs": data.logprobs,
"finish_reason": "done",
}
]
}
self.response["body"].update(

result["body"] = data["response"]
result["body"].update(
{
"usage": {
"prompt_tokens": self.prompt_tokens,
Expand All @@ -236,18 +238,16 @@ def llm_event_end(self, payload: Any, event_id) -> Any:
}
}
)
self.response["body"].update({"id": event_id})
self.response["body"].update({"created": int(time.time())})
self.response["body"].update({"model": getattr(data, "model", "")})
self.response["headers"] = {}
self.response["streamingMode"] = self.streamingMode

end_time = int(datetime.now().timestamp())
total_time = (end_time - start_time) * 1000
result["body"].update({"id": event_id})
result["body"].update({"created": int(time.time())})
result["body"].update({"model": getattr(data, "model", "")})
result["streamingMode"] = self.streamingMode

self.response["response_time"] = total_time
result["status"] = 200
result["headers"] = {}
result["response_time"] = total_time

return self.response
return result

# ------------------------------------------------------ #
def embedding_event_start(self, payload: Any) -> Any:
Expand All @@ -270,9 +270,9 @@ def embedding_event_start(self, payload: Any) -> Any:
def embedding_event_end(self, payload: Any, event_id) -> Any:
if event_id in self.event_map:
event = self.event_map[event_id]
# event["request"]["body"]["input"] = payload.get("chunks", "")
event["request"]["body"]["input"] = payload.get("chunks", "")
# Setting as ...INPUT... to avoid logging the entire data input file
event["request"]["body"]["input"] = "...INPUT..."
# event["request"]["body"]["input"] = "...INPUT..."

start_time = event["start_time"]

Expand Down Expand Up @@ -302,27 +302,29 @@ def embedding_event_end(self, payload: Any, event_id) -> Any:
)
self.response["headers"] = {}

end_time = int(datetime.now().timestamp())
total_time = (end_time - start_time) * 1000
end_time = time.time()
total_time = f"{((end_time - start_time) * 1000):04.0f}"

self.response["response_time"] = total_time

return self.response

# ------------------------------------------------------ #
def agent_step_event_start(self, payload: Any) -> Any:
data = json.dumps(self.serialize(payload))
try:
data = json.dumps(self.serialize(payload))
except Exception:
data = json.dumps(payload.__dict__)
return data

def agent_step_event_end(self, payload: Any, event_id) -> Any:
if event_id in self.event_map:
event = self.event_map[event_id]
start_time = event["start_time"]
data = self.serialize(payload)
json.dumps(data)
result = self.transform_agent_step_end(data)
end_time = int(datetime.now().timestamp())
total_time = (end_time - start_time) * 1000
end_time = time.time()
total_time = f"{((end_time - start_time) * 1000):04.0f}"

result["response_time"] = total_time
return result
Expand All @@ -337,35 +339,39 @@ def function_call_event_end(self, payload: Any, event_id) -> Any:
event = self.event_map[event_id]
start_time = event["start_time"]
data = self.serialize(payload)
json.dumps(data)
result = self.transform_function_call_end(data)
end_time = int(datetime.now().timestamp())
total_time = (end_time - start_time) * 1000
end_time = time.time()
total_time = f"{((end_time - start_time) * 1000):04.0f}"

result["response_time"] = total_time
return result

# ------------------------------------------------------ #
def query_event_start(self, payload: Any) -> Any:
data = json.dumps(self.serialize(payload))
try:
data = json.dumps(self.serialize(payload))
except Exception:
data = json.dumps(payload.__dict__)
return data

def query_event_end(self, payload: Any, event_id) -> Any:
if event_id in self.event_map:
event = self.event_map[event_id]
start_time = event["start_time"]
data = self.serialize(payload)
json.dumps(data)
result = self.transform_query_end(data)
end_time = int(datetime.now().timestamp())
total_time = (end_time - start_time) * 1000
end_time = time.time()
total_time = f"{((end_time - start_time) * 1000):04.0f}"

result["response_time"] = total_time
return result

# ------------------------------------------------------ #
def retrieve_event_start(self, payload: Any) -> Any:
data = json.dumps(self.serialize(payload))
try:
data = json.dumps(self.serialize(payload))
except Exception:
data = json.dumps(payload.__dict__)
return data

def retrieve_event_end(self, payload: Any, event_id) -> Any:
Expand All @@ -374,18 +380,16 @@ def retrieve_event_end(self, payload: Any, event_id) -> Any:
start_time = event["start_time"]

data = self.serialize(payload)
json.dumps(data)
result = self.transform_retrieve_end(data)
end_time = int(datetime.now().timestamp())
total_time = (end_time - start_time) * 1000
end_time = time.time()
total_time = f"{((end_time - start_time) * 1000):04.0f}"

result["response_time"] = total_time
return result

# ------------------------------------------------------ #
def templating_event_start(self, payload: Any) -> Any:
data = self.serialize(payload)
json.dumps(data)
result = self.transform_templating_start(data)
return result

Expand All @@ -395,14 +399,44 @@ def templating_event_end(self, payload: Any, event_id) -> Any:
start_time = event["start_time"]
result = self.transform_templating_end(event_id)

end_time = int(datetime.now().timestamp())
total_time = (end_time - start_time) * 1000
end_time = time.time()
total_time = f"{((end_time - start_time) * 1000):04.0f}"

result["response_time"] = total_time
return result

# ------------------------------------------------------ #

def sub_question_event_start(self, payload: Any) -> Any:
try:
data = json.dumps(self.serialize(payload))
except Exception:
data = json.dumps(payload.__dict__)

return data

def sub_question_event_end(self, payload: Any, event_id) -> Any:
result: Dict[str, Any] = {}
result["body"] = {}
if event_id in self.event_map:
event = self.event_map[event_id]
start_time = event["start_time"]

try:
data = self.serialize(payload)
except Exception:
data = payload.__dict__

end_time = time.time()
total_time = f"{((end_time - start_time) * 1000):04.0f}"

result["body"] = data
result["body"]["response_time"] = total_time

return result

# ------------------------------------------------------ #

# ----------------- EVENT Transformers ----------------- #
def transform_agent_step_end(self, data: Any) -> Any:
try:
Expand Down Expand Up @@ -587,3 +621,11 @@ def serialize(self, obj):
if isinstance(obj, tuple):
return tuple(self.serialize(item) for item in obj)
return obj


class NodeRelationship(str, Enum):
SOURCE = auto()
PREVIOUS = auto()
NEXT = auto()
PARENT = auto()
CHILD = auto()
Loading