Skip to content

Commit

Permalink
feat: metadata accpeted for filtering
Browse files Browse the repository at this point in the history
  • Loading branch information
csgulati09 committed Aug 12, 2024
1 parent 5505385 commit 6693afe
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 5 deletions.
26 changes: 22 additions & 4 deletions portkey_ai/llms/langchain/portkey_langchain_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@ class PortkeyLangchain(BaseCallbackHandler):
def __init__(
self,
api_key: str,
metadata: Optional[Dict[str, Any]] = {},
) -> None:
super().__init__()

self.api_key = api_key
self.metadata = metadata

self.portkey_logger = Logger(api_key=api_key)

Expand Down Expand Up @@ -69,6 +71,7 @@ def on_llm_start(
"llm_start",
self.global_trace_id,
request_payload,
self.metadata,
)
self.event_map["llm_start_" + str(run_id)] = info_obj
pass
Expand Down Expand Up @@ -101,6 +104,7 @@ def on_chat_model_start(
"chat_model_start",
self.global_trace_id,
request_payload,
self.metadata,
)
self.event_map["chat_model_start_" + str(run_id)] = info_obj
self.event_array.append(self.event_map["chat_model_start_" + str(run_id)])
Expand Down Expand Up @@ -147,9 +151,8 @@ def on_chain_start(
"""Run when chain starts running."""

if parent_run_id is None:
self.global_trace_id = str(uuid4())
self.global_trace_id = self.metadata.get("traceId", str(uuid4())) # type: ignore [union-attr]
self.main_span_id = ""

parent_span_id = (
self.main_span_id if parent_run_id is None else str(parent_run_id)
)
Expand All @@ -163,6 +166,7 @@ def on_chain_start(
"chain_start",
self.global_trace_id,
request_payload,
self.metadata,
)

self.event_map["chain_start_" + str(run_id)] = info_obj
Expand Down Expand Up @@ -221,6 +225,7 @@ def on_tool_start(
"tool_start",
self.global_trace_id,
request_payload,
self.metadata,
)
self.event_map["tool_start_" + str(run_id)] = info_obj
pass
Expand Down Expand Up @@ -264,7 +269,12 @@ def on_text( # Do we need to log this or not? This is just formatting of the te
)
request_payload = self.on_text_transformer(text)
info_obj = self.start_event_information(
run_id, parent_span_id, "text", self.global_trace_id, request_payload
run_id,
parent_span_id,
"text",
self.global_trace_id,
request_payload,
self.metadata,
)
self.event_map["text_" + str(run_id)] = info_obj
self.event_array.append(self.event_map["text_" + str(run_id)])
Expand All @@ -291,6 +301,7 @@ def on_agent_action(
"agent_action",
self.global_trace_id,
request_payload,
self.metadata,
)
self.event_map["agent_action_" + str(run_id)] = info_obj
pass
Expand Down Expand Up @@ -347,7 +358,13 @@ def on_retriever_end(

# -------------- Helpers ------------------------------------------
def start_event_information(
self, span_id, parent_span_id, span_name, trace_id, request_payload
self,
span_id,
parent_span_id,
span_name,
trace_id,
request_payload,
metadata=None,
):
start_time = int(datetime.now().timestamp())
return {
Expand All @@ -357,6 +374,7 @@ def start_event_information(
"trace_id": trace_id,
"request": request_payload,
"start_time": start_time,
"metadata": metadata,
}

def serialize(self, obj):
Expand Down
6 changes: 5 additions & 1 deletion portkey_ai/llms/llama_index/portkey_llama_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class PortkeyLlamaindex(LlamaIndexBaseCallbackHandler):
def __init__(
self,
api_key: str,
metadata: Optional[Dict[str, Any]] = {},
) -> None:
super().__init__(
event_starts_to_ignore=[
Expand All @@ -44,6 +45,7 @@ def __init__(
)

self.api_key = api_key
self.metadata = metadata

self.portkey_logger = Logger(api_key=api_key)

Expand Down Expand Up @@ -110,6 +112,7 @@ def on_event_start( # type: ignore
"trace_id": self.global_trace_id,
"request": request_payload,
"start_time": start_time,
"metadata": self.metadata,
}
self.event_map[span_id] = start_event_information

Expand Down Expand Up @@ -141,13 +144,14 @@ def on_event_end(
response_payload = payload

self.event_map[span_id]["response"] = response_payload

self.event_array.append(self.event_map[span_id])

def start_trace(self, trace_id: Optional[str] = None) -> None:
"""Run when an overall trace is launched."""

if trace_id == "index_construction":
self.global_trace_id = str(uuid4())
self.global_trace_id = self.metadata.get("traceId", str(uuid4())) # type: ignore [union-attr]

self.main_span_id = ""

Expand Down

0 comments on commit 6693afe

Please sign in to comment.