From cd71d18e4ea7da023a317ed86fc081445b0595c2 Mon Sep 17 00:00:00 2001 From: csgulati09 Date: Thu, 27 Jun 2024 20:59:22 +0530 Subject: [PATCH 1/8] feat: final span object ready before logging --- .../llama_index/portkey_llama_callback.py | 376 +++++++++++++++++- 1 file changed, 368 insertions(+), 8 deletions(-) diff --git a/portkey_ai/llms/llama_index/portkey_llama_callback.py b/portkey_ai/llms/llama_index/portkey_llama_callback.py index 0c353337..0ac6a8fc 100644 --- a/portkey_ai/llms/llama_index/portkey_llama_callback.py +++ b/portkey_ai/llms/llama_index/portkey_llama_callback.py @@ -1,7 +1,14 @@ +from enum import Enum +import json import time from typing import Any, Dict, List, Optional from portkey_ai.api_resources.apis.logger import Logger from datetime import datetime +from llama_index.core.callbacks.schema import (EventPayload, CBEvent, CBEventType, BASE_TRACE_EVENT) +from uuid import uuid4 +from llama_index.legacy.schema import ( + NodeRelationship +) try: from llama_index.core.callbacks.base_handler import ( @@ -19,14 +26,15 @@ class PortkeyLlamaindex(LlamaIndexBaseCallbackHandler): startTimestamp: int = 0 endTimestamp: float = 0 + global_trace_id: str = "qwaszx12" def __init__( self, api_key: str, ) -> None: super().__init__( - event_starts_to_ignore=[], - event_ends_to_ignore=[], + event_starts_to_ignore=[CBEventType.CHUNKING, CBEventType.NODE_PARSING, CBEventType.SYNTHESIZE, EventPayload.EXCEPTION], + event_ends_to_ignore=[CBEventType.CHUNKING, CBEventType.NODE_PARSING, CBEventType.SYNTHESIZE, EventPayload.EXCEPTION], ) self.api_key = api_key @@ -47,6 +55,10 @@ def __init__( self.responseTime: int = 0 self.streamingMode: bool = False + self.event_map: Any = {} + self.event_array: List = [] + self.main_span_id: str = '' + if not api_key: raise ValueError("Please provide an API key to use PortkeyCallbackHandler") @@ -60,8 +72,44 @@ def on_event_start( # type: ignore[return] ) -> str: """Run when an event starts and return id of event.""" + print("on_event_start: event_type: ", event_type) + print("on_event_start: payload: ", payload) + # print("on_event_start: event_id: ", event_id) + # print("on_event_start: parent_id: ", parent_id) + # print("on_event_start: kwargs: ", kwargs) + span_id = str(event_id) + parent_span_id = parent_id + span_name = event_type + + if parent_id == 'root': + parent_span_id = self.main_span_id + if event_type == "llm": - self.llm_event_start(payload) + request_payload = self.llm_event_start(payload) + elif event_type == "embedding": + request_payload = self.embedding_event_start(payload) + elif event_type == "agent_step": + request_payload = self.agent_step_event_start(payload) + elif event_type == "function_call": + request_payload = self.function_call_event_start(payload) + elif event_type == "query": + request_payload = self.query_event_start(payload) + elif event_type == "retrieve": + request_payload = self.retrieve_event_start(payload) + elif event_type == "templating": + request_payload = self.templating_event_start(payload) + else : + request_payload = payload + + # TODO: Change logic so that we make the information object for the needed events only. 
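+        # Assemble the span record for this event: it keeps its own span_id,
+        # its parent's span_id, and the shared trace_id so the flat event
+        # list can later be rebuilt into a call tree.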
+ start_event_information = { + "span_id": span_id, + "parent_span_id": parent_span_id, + "span_name": span_name, + "trace_id": self.global_trace_id, + "request": request_payload + } + self.event_map[span_id] = start_event_information def on_event_end( self, @@ -71,12 +119,42 @@ def on_event_end( **kwargs: Any, ) -> None: """Run when an event ends.""" + span_id = event_id + + # print("on_event_end: event_type: ", event_type) + # print("on_event_end: payload: ", payload) if event_type == "llm": - self.llm_event_stop(payload, event_id) + response_payload = self.llm_event_end(payload, event_id) + elif event_type == "embedding": + response_payload = self.embedding_event_end(payload, event_id) + elif event_type == "agent_step": + response_payload = self.agent_step_event_end(payload, event_id) + elif event_type == "function_call": + print("FUNCTION CALL: END: ", payload) + response_payload = self.function_call_event_end(payload) + elif event_type == "query": + response_payload = self.query_event_end(payload, event_id) + elif event_type == "retrieve": + response_payload = self.retrieve_event_end(payload, event_id) + elif event_type == "templating": + response_payload = self.templating_event_end(payload, event_id) + else: + response_payload = payload + + self.event_map[span_id]['response'] = response_payload + self.event_array.append(self.event_map[span_id]) def start_trace(self, trace_id: Optional[str] = None) -> None: """Run when an overall trace is launched.""" + print("start_trace: trace_id: ", trace_id) + + if trace_id == 'index_construction': + self.global_trace_id = str(uuid4()) + print("start_trace: global_trace_id: ", self.global_trace_id) + + self.main_span_id = str(uuid4()) + self.startTimestamp = int(datetime.now().timestamp()) def end_trace( @@ -85,8 +163,20 @@ def end_trace( trace_map: Optional[Dict[str, List[str]]] = None, ) -> None: """Run when an overall trace is exited.""" + print("end_trace: trace_id: ", trace_id) + print("end_trace: trace_map: ", trace_map) + # print("event_map: ", self.event_map) + print("event_array: ", self.event_array) + ''' + If this is multi part call then we will put a logic here. + log and empty all the variables here. + We can make a function here as well to do this bit. + ''' + +# ----------------- EVENT Handlers ----------------- # def llm_event_start(self, payload: Any) -> None: + self.request = {} if "messages" in payload: chunks = payload.get("messages", {}) self.prompt_tokens = self._token_counter.estimate_tokens_in_messages(chunks) @@ -108,11 +198,12 @@ def llm_event_start(self, payload: Any) -> None: {"temperature": payload.get("serialized", {}).get("temperature", "")} ) - return None + return self.request - def llm_event_stop(self, payload: Any, event_id) -> None: + def llm_event_end(self, payload: Any, event_id) -> None: self.endTimestamp = float(datetime.now().timestamp()) responseTime = self.endTimestamp - self.startTimestamp + self.response = {} data = payload.get("response", {}) @@ -155,6 +246,275 @@ def llm_event_stop(self, payload: Any, event_id) -> None: "response": self.response, } ) - self.portkey_logger.log(log_object=self.log_object) + # self.portkey_logger.log(log_object=self.log_object) + + return self.response +# ------------------------------------------------------ # + def embedding_event_start(self, payload: Any) -> Any: + + # print("EMBEDDING: START: ", payload) + # print("EMBEDDING: START: KWARGS: ", **kwargs) + + self.request = {} + + # input + # model + # encoding_format ?? 
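+        # Mirror the request shape used for LLM spans (method, url, provider,
+        # headers, body) so embedding calls are logged uniformly.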
+ + if "serialized" in payload: + self.request["method"] = "POST" + self.request["url"] = payload.get("serialized", {}).get("api_base", "embeddings") + self.request["provider"] = payload.get("serialized", {}).get("class_name", "") + self.request["headers"] = {} + self.request["body"] = {"model": payload.get("serialized", {}).get("model_name", "")} + + return self.request + + def embedding_event_end(self, payload: Any, event_id) -> Any: + # print("EMBEDDING: END: ", payload) + # print("EMBEDDING: END: KWARGS: ", **kwargs) + + if event_id in self.event_map: + event = self.event_map[event_id] + # event["request"]["body"]["input"] = payload.get("chunks", "") + event["request"]["body"]["input"] = "...INPUT..." + + self.response = {} + + self.endTimestamp = float(datetime.now().timestamp()) + responseTime = self.endTimestamp - self.startTimestamp + + chunk_str = str(payload.get("chunks", "")) + embd_str = str(payload.get("embeddings", "")) + + self.prompt_tokens = self._token_counter.get_string_tokens(chunk_str) + self.completion_tokens = self._token_counter.get_string_tokens(embd_str) + + self.response["status"] = 200 + self.response["body"] = { + "embedding": "...REDACTED...", + } + + self.response["body"].update({"created": int(time.time())}) + + self.response["body"].update( + { + "usage": { + "prompt_tokens": self.prompt_tokens, + "completion_tokens": self.completion_tokens, + "total_tokens": self.prompt_tokens + self.completion_tokens, + } + } + ) + self.response["headers"] = {} + self.response["time"] = int(responseTime * 1000) + # self.log_object.update( + # { + # "request": self.request, + # "response": self.response, + # } + # ) + + # self.portkey_logger.log(log_object=self.log_object) + + return self.response +# ------------------------------------------------------ # + def agent_step_event_start(self, payload: Any) -> Any: + data = json.dumps(self.serialize(payload)) + return data + + def agent_step_event_end(self, payload: Any, event_id) -> Any: + data = self.serialize(payload) + json_data = json.dumps(data) + result = self.transform_agent_step_end(data) + return result +# ------------------------------------------------------ # +# TODO: Transformation Function pending + def function_call_event_start(self, payload: Any) -> Any: + return payload + + def function_call_event_end(self, payload: Any) -> Any: + return payload +# ------------------------------------------------------ # + def query_event_start(self, payload: Any) -> Any: + data = json.dumps(self.serialize(payload)) + return data + + def query_event_end(self, payload: Any, event_id) -> Any: + data = self.serialize(payload) + json_data = json.dumps(data) + result = self.transform_query_end(data) + return result +# ------------------------------------------------------ # + def retrieve_event_start(self, payload: Any) -> Any: + data = json.dumps(self.serialize(payload)) + return data + + def retrieve_event_end(self, payload: Any, event_id) -> Any: + data = self.serialize(payload) + json_data = json.dumps(data) + result = self.transform_retrieve_end(data) + return result +# ------------------------------------------------------ # + def templating_event_start(self, payload: Any) -> Any: + data = self.serialize(payload) + json_data = json.dumps(data) + result = self.transform_templating_start(data) + return result + + def templating_event_end(self, payload: Any, event_id) -> Any: + result = self.transform_templating_end(event_id) + return result +# ------------------------------------------------------ # + + +# ----------------- 
EVENT Transformers ----------------- # + def transform_agent_step_end(self, data: Any) -> Any: + try: + output_data = { + "agent_chat_response": { + "response": data["response"]["response"], + "sources": [ + { + "content": source["content"], + "tool_name": source["tool_name"], + "raw_input": source["raw_input"], + "raw_output": { + "response": source["raw_output"]["response"], + "source_nodes": [ + { + "id": node["node"]["id_"], + "metadata": node["node"]["metadata"], + "text_excerpt": node["node"]["text"], + "score": node["score"] + } + for node in source["raw_output"]["source_nodes"] + ] + } + } + for source in data["response"]["sources"] + ] + } + } + return output_data + except Exception as e: + return data + + def transform_query_end(self, data: Any) -> Any: + try: + output_data = { + "response": { + "content": data["response"]["response"], + "source_nodes": [] + } + } + + for source_node in data["response"]["source_nodes"]: + node = source_node["node"] + metadata = node["metadata"] + text_excerpt = node["text"][node["start_char_idx"]:node["end_char_idx"]] + + output_data["response"]["source_nodes"].append({ + "id": node["id_"], + "metadata": { + "file_path": metadata["file_path"], + "file_name": metadata["file_name"], + "file_type": metadata["file_type"], + "file_size": metadata["file_size"], + "creation_date": metadata["creation_date"], + "last_modified_date": metadata["last_modified_date"] + }, + "text_excerpt": text_excerpt, + "score": source_node["score"] + }) + return output_data + except Exception as e: + return data + + def transform_retrieve_end(self, data: Any) -> Any: + try: + output_data = { + "nodes": [] + } + + for node_data in data["nodes"]: + node = node_data["node"] + metadata = node["metadata"] + text_excerpt = node["text"][node["start_char_idx"]:node["end_char_idx"]] + + relationships = {} + for relationship, details in node["relationships"].items(): + relationships[relationship.name] = { + "node_id": details["node_id"], + "node_type": "DOCUMENT" if relationship == NodeRelationship.SOURCE else "TEXT" + } + + output_data["nodes"].append({ + "id": node["id_"], + "metadata": { + "file_path": metadata["file_path"], + "file_name": metadata["file_name"], + "file_type": metadata["file_type"], + "file_size": metadata["file_size"], + "creation_date": metadata["creation_date"], + "last_modified_date": metadata["last_modified_date"] + }, + "relationships": relationships, + "text_excerpt": text_excerpt, + "score": node_data["score"] + }) + + return output_data + except Exception as e: + return data + + def transform_templating_start(self, data: Any) -> Any: + try: + output_data = { + "template": { + "system": data["template"].split('user:')[0].strip(), + "user": "Context information is below.\n---------------------\n{context_str}\n---------------------\nGiven the context information and not prior knowledge, answer the query.\nQuery: {query_str}\nAnswer: ", + "assistant": "" + }, + "template_vars": { + "context_str": data["template_vars"]["context_str"], + "query_str": data["template_vars"]["query_str"] + }, + "system_prompt": data["system_prompt"], + "query_wrapper_prompt": data["query_wrapper_prompt"], + } + return output_data + except Exception as e: + return data + + def transform_templating_end(self, event_id) -> Any: + try: + request_data = self.event_map[event_id]["request"] + context_str = request_data["template_vars"]["context_str"] + query_str = request_data["template_vars"]["query_str"] + replace_str = request_data["template"]["user"] + output_data = { + "template":{ 
+ "system": request_data["template"]["system"], + "user": replace_str.format(context_str=context_str, query_str=query_str), + "assistant": request_data["template"]["assistant"] + } + } + + return output_data + except Exception as e: + return "" + +# ----------------- HELPER FUNCTIONS ----------------- # - return None + def serialize(self, obj): + if isinstance(obj, Enum): + return obj.value + if hasattr(obj, '__dict__'): + return {key: self.serialize(value) for key, value in obj.__dict__.items()} + if isinstance(obj, list): + return [self.serialize(item) for item in obj] + if isinstance(obj, dict): + return {key: self.serialize(value) for key, value in obj.items()} + return obj + \ No newline at end of file From 060a4ab9d3454389dace7dcc88b808dd3004212c Mon Sep 17 00:00:00 2001 From: csgulati09 Date: Fri, 28 Jun 2024 15:19:58 +0530 Subject: [PATCH 2/8] feat: handling function_call req and res --- .../llama_index/portkey_llama_callback.py | 311 +++++++++--------- 1 file changed, 162 insertions(+), 149 deletions(-) diff --git a/portkey_ai/llms/llama_index/portkey_llama_callback.py b/portkey_ai/llms/llama_index/portkey_llama_callback.py index 0ac6a8fc..bef1bd04 100644 --- a/portkey_ai/llms/llama_index/portkey_llama_callback.py +++ b/portkey_ai/llms/llama_index/portkey_llama_callback.py @@ -4,11 +4,11 @@ from typing import Any, Dict, List, Optional from portkey_ai.api_resources.apis.logger import Logger from datetime import datetime -from llama_index.core.callbacks.schema import (EventPayload, CBEvent, CBEventType, BASE_TRACE_EVENT) -from uuid import uuid4 -from llama_index.legacy.schema import ( - NodeRelationship +from llama_index.core.callbacks.schema import ( + CBEventType, ) +from uuid import uuid4 +from llama_index.legacy.schema import NodeRelationship try: from llama_index.core.callbacks.base_handler import ( @@ -26,15 +26,24 @@ class PortkeyLlamaindex(LlamaIndexBaseCallbackHandler): startTimestamp: int = 0 endTimestamp: float = 0 - global_trace_id: str = "qwaszx12" def __init__( self, api_key: str, ) -> None: super().__init__( - event_starts_to_ignore=[CBEventType.CHUNKING, CBEventType.NODE_PARSING, CBEventType.SYNTHESIZE, EventPayload.EXCEPTION], - event_ends_to_ignore=[CBEventType.CHUNKING, CBEventType.NODE_PARSING, CBEventType.SYNTHESIZE, EventPayload.EXCEPTION], + event_starts_to_ignore=[ + CBEventType.CHUNKING, + CBEventType.NODE_PARSING, + CBEventType.SYNTHESIZE, + CBEventType.EXCEPTION, + ], + event_ends_to_ignore=[ + CBEventType.CHUNKING, + CBEventType.NODE_PARSING, + CBEventType.SYNTHESIZE, + CBEventType.EXCEPTION, + ], ) self.api_key = api_key @@ -57,12 +66,12 @@ def __init__( self.event_map: Any = {} self.event_array: List = [] - self.main_span_id: str = '' + self.main_span_id: str = "" if not api_key: raise ValueError("Please provide an API key to use PortkeyCallbackHandler") - def on_event_start( # type: ignore[return] + def on_event_start( # type: ignore self, event_type: Any, payload: Optional[Dict[str, Any]] = None, @@ -72,16 +81,11 @@ def on_event_start( # type: ignore[return] ) -> str: """Run when an event starts and return id of event.""" - print("on_event_start: event_type: ", event_type) - print("on_event_start: payload: ", payload) - # print("on_event_start: event_id: ", event_id) - # print("on_event_start: parent_id: ", parent_id) - # print("on_event_start: kwargs: ", kwargs) span_id = str(event_id) parent_span_id = parent_id span_name = event_type - if parent_id == 'root': + if parent_id == "root": parent_span_id = self.main_span_id if event_type == "llm": @@ 
-98,16 +102,15 @@ def on_event_start( # type: ignore[return] request_payload = self.retrieve_event_start(payload) elif event_type == "templating": request_payload = self.templating_event_start(payload) - else : + else: request_payload = payload - # TODO: Change logic so that we make the information object for the needed events only. start_event_information = { "span_id": span_id, "parent_span_id": parent_span_id, - "span_name": span_name, + "span_name": span_name.value, "trace_id": self.global_trace_id, - "request": request_payload + "request": request_payload, } self.event_map[span_id] = start_event_information @@ -121,9 +124,6 @@ def on_event_end( """Run when an event ends.""" span_id = event_id - # print("on_event_end: event_type: ", event_type) - # print("on_event_end: payload: ", payload) - if event_type == "llm": response_payload = self.llm_event_end(payload, event_id) elif event_type == "embedding": @@ -131,7 +131,6 @@ def on_event_end( elif event_type == "agent_step": response_payload = self.agent_step_event_end(payload, event_id) elif event_type == "function_call": - print("FUNCTION CALL: END: ", payload) response_payload = self.function_call_event_end(payload) elif event_type == "query": response_payload = self.query_event_end(payload, event_id) @@ -142,16 +141,14 @@ def on_event_end( else: response_payload = payload - self.event_map[span_id]['response'] = response_payload + self.event_map[span_id]["response"] = response_payload self.event_array.append(self.event_map[span_id]) def start_trace(self, trace_id: Optional[str] = None) -> None: """Run when an overall trace is launched.""" - print("start_trace: trace_id: ", trace_id) - if trace_id == 'index_construction': + if trace_id == "index_construction": self.global_trace_id = str(uuid4()) - print("start_trace: global_trace_id: ", self.global_trace_id) self.main_span_id = str(uuid4()) @@ -163,19 +160,14 @@ def end_trace( trace_map: Optional[Dict[str, List[str]]] = None, ) -> None: """Run when an overall trace is exited.""" - print("end_trace: trace_id: ", trace_id) - print("end_trace: trace_map: ", trace_map) - # print("event_map: ", self.event_map) - print("event_array: ", self.event_array) - ''' - If this is multi part call then we will put a logic here. - log and empty all the variables here. - We can make a function here as well to do this bit. 
- ''' - - -# ----------------- EVENT Handlers ----------------- # - def llm_event_start(self, payload: Any) -> None: + + self.log_object = {"data": self.event_array} + + self.portkey_logger.log(log_object=self.log_object) + self.event_array = [] + + # ------------------- EVENT Handlers ------------------- # + def llm_event_start(self, payload: Any) -> Any: self.request = {} if "messages" in payload: chunks = payload.get("messages", {}) @@ -200,7 +192,7 @@ def llm_event_start(self, payload: Any) -> None: return self.request - def llm_event_end(self, payload: Any, event_id) -> None: + def llm_event_end(self, payload: Any, event_id) -> Any: self.endTimestamp = float(datetime.now().timestamp()) responseTime = self.endTimestamp - self.startTimestamp self.response = {} @@ -240,44 +232,31 @@ def llm_event_end(self, payload: Any, event_id) -> None: self.response["headers"] = {} self.response["streamingMode"] = self.streamingMode - self.log_object.update( - { - "request": self.request, - "response": self.response, - } - ) - # self.portkey_logger.log(log_object=self.log_object) - return self.response -# ------------------------------------------------------ # - def embedding_event_start(self, payload: Any) -> Any: - # print("EMBEDDING: START: ", payload) - # print("EMBEDDING: START: KWARGS: ", **kwargs) - + # ------------------------------------------------------ # + def embedding_event_start(self, payload: Any) -> Any: self.request = {} - - # input - # model - # encoding_format ?? - if "serialized" in payload: self.request["method"] = "POST" - self.request["url"] = payload.get("serialized", {}).get("api_base", "embeddings") - self.request["provider"] = payload.get("serialized", {}).get("class_name", "") + self.request["url"] = payload.get("serialized", {}).get( + "api_base", "embeddings" + ) + self.request["provider"] = payload.get("serialized", {}).get( + "class_name", "" + ) self.request["headers"] = {} - self.request["body"] = {"model": payload.get("serialized", {}).get("model_name", "")} - + self.request["body"] = { + "model": payload.get("serialized", {}).get("model_name", "") + } + return self.request def embedding_event_end(self, payload: Any, event_id) -> Any: - # print("EMBEDDING: END: ", payload) - # print("EMBEDDING: END: KWARGS: ", **kwargs) - if event_id in self.event_map: event = self.event_map[event_id] - # event["request"]["body"]["input"] = payload.get("chunks", "") - event["request"]["body"]["input"] = "...INPUT..." + event["request"]["body"]["input"] = payload.get("chunks", "") + # event["request"]["body"]["input"] = "...INPUT..." 
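+            # Logging the raw chunks trades log size for observability; the
+            # commented line above keeps the redacted alternative.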
self.response = {} @@ -286,7 +265,7 @@ def embedding_event_end(self, payload: Any, event_id) -> Any: chunk_str = str(payload.get("chunks", "")) embd_str = str(payload.get("embeddings", "")) - + self.prompt_tokens = self._token_counter.get_string_tokens(chunk_str) self.completion_tokens = self._token_counter.get_string_tokens(embd_str) @@ -308,67 +287,67 @@ def embedding_event_end(self, payload: Any, event_id) -> Any: ) self.response["headers"] = {} self.response["time"] = int(responseTime * 1000) - # self.log_object.update( - # { - # "request": self.request, - # "response": self.response, - # } - # ) - - # self.portkey_logger.log(log_object=self.log_object) return self.response -# ------------------------------------------------------ # + + # ------------------------------------------------------ # def agent_step_event_start(self, payload: Any) -> Any: data = json.dumps(self.serialize(payload)) return data def agent_step_event_end(self, payload: Any, event_id) -> Any: data = self.serialize(payload) - json_data = json.dumps(data) + json.dumps(data) result = self.transform_agent_step_end(data) return result -# ------------------------------------------------------ # -# TODO: Transformation Function pending + + # ------------------------------------------------------ # def function_call_event_start(self, payload: Any) -> Any: - return payload - + result = self.transform_function_call_start(payload) + return result + def function_call_event_end(self, payload: Any) -> Any: - return payload -# ------------------------------------------------------ # + data = self.serialize(payload) + json.dumps(data) + result = self.transform_function_call_end(data) + return result + + # ------------------------------------------------------ # def query_event_start(self, payload: Any) -> Any: data = json.dumps(self.serialize(payload)) return data - + def query_event_end(self, payload: Any, event_id) -> Any: data = self.serialize(payload) - json_data = json.dumps(data) + json.dumps(data) result = self.transform_query_end(data) return result -# ------------------------------------------------------ # + + # ------------------------------------------------------ # def retrieve_event_start(self, payload: Any) -> Any: data = json.dumps(self.serialize(payload)) return data - + def retrieve_event_end(self, payload: Any, event_id) -> Any: data = self.serialize(payload) - json_data = json.dumps(data) + json.dumps(data) result = self.transform_retrieve_end(data) return result -# ------------------------------------------------------ # + + # ------------------------------------------------------ # def templating_event_start(self, payload: Any) -> Any: data = self.serialize(payload) - json_data = json.dumps(data) + json.dumps(data) result = self.transform_templating_start(data) return result - + def templating_event_end(self, payload: Any, event_id) -> Any: result = self.transform_templating_end(event_id) return result -# ------------------------------------------------------ # + # ------------------------------------------------------ # -# ----------------- EVENT Transformers ----------------- # + # ----------------- EVENT Transformers ----------------- # def transform_agent_step_end(self, data: Any) -> Any: try: output_data = { @@ -386,18 +365,18 @@ def transform_agent_step_end(self, data: Any) -> Any: "id": node["node"]["id_"], "metadata": node["node"]["metadata"], "text_excerpt": node["node"]["text"], - "score": node["score"] + "score": node["score"], } for node in source["raw_output"]["source_nodes"] - ] - } + ], + }, } 
for source in data["response"]["sources"] - ] + ], } } return output_data - except Exception as e: + except Exception: return data def transform_query_end(self, data: Any) -> Any: @@ -405,86 +384,94 @@ def transform_query_end(self, data: Any) -> Any: output_data = { "response": { "content": data["response"]["response"], - "source_nodes": [] + "source_nodes": [], } } for source_node in data["response"]["source_nodes"]: node = source_node["node"] metadata = node["metadata"] - text_excerpt = node["text"][node["start_char_idx"]:node["end_char_idx"]] - - output_data["response"]["source_nodes"].append({ - "id": node["id_"], - "metadata": { - "file_path": metadata["file_path"], - "file_name": metadata["file_name"], - "file_type": metadata["file_type"], - "file_size": metadata["file_size"], - "creation_date": metadata["creation_date"], - "last_modified_date": metadata["last_modified_date"] - }, - "text_excerpt": text_excerpt, - "score": source_node["score"] - }) + text_excerpt = node["text"][ + node["start_char_idx"] : node["end_char_idx"] + ] + + output_data["response"]["source_nodes"].append( + { + "id": node["id_"], + "metadata": { + "file_path": metadata["file_path"], + "file_name": metadata["file_name"], + "file_type": metadata["file_type"], + "file_size": metadata["file_size"], + "creation_date": metadata["creation_date"], + "last_modified_date": metadata["last_modified_date"], + }, + "text_excerpt": text_excerpt, + "score": source_node["score"], + } + ) return output_data - except Exception as e: + except Exception: return data def transform_retrieve_end(self, data: Any) -> Any: try: - output_data = { - "nodes": [] - } + output_data = {"nodes": []} # type: ignore[var-annotated] for node_data in data["nodes"]: node = node_data["node"] metadata = node["metadata"] - text_excerpt = node["text"][node["start_char_idx"]:node["end_char_idx"]] + text_excerpt = node["text"][ + node["start_char_idx"] : node["end_char_idx"] + ] relationships = {} for relationship, details in node["relationships"].items(): relationships[relationship.name] = { "node_id": details["node_id"], - "node_type": "DOCUMENT" if relationship == NodeRelationship.SOURCE else "TEXT" + "node_type": "DOCUMENT" + if relationship == NodeRelationship.SOURCE + else "TEXT", } - output_data["nodes"].append({ - "id": node["id_"], - "metadata": { - "file_path": metadata["file_path"], - "file_name": metadata["file_name"], - "file_type": metadata["file_type"], - "file_size": metadata["file_size"], - "creation_date": metadata["creation_date"], - "last_modified_date": metadata["last_modified_date"] - }, - "relationships": relationships, - "text_excerpt": text_excerpt, - "score": node_data["score"] - }) + output_data["nodes"].append( + { + "id": node["id_"], + "metadata": { + "file_path": metadata["file_path"], + "file_name": metadata["file_name"], + "file_type": metadata["file_type"], + "file_size": metadata["file_size"], + "creation_date": metadata["creation_date"], + "last_modified_date": metadata["last_modified_date"], + }, + "relationships": relationships, + "text_excerpt": text_excerpt, + "score": node_data["score"], + } + ) return output_data - except Exception as e: + except Exception: return data def transform_templating_start(self, data: Any) -> Any: try: output_data = { "template": { - "system": data["template"].split('user:')[0].strip(), - "user": "Context information is below.\n---------------------\n{context_str}\n---------------------\nGiven the context information and not prior knowledge, answer the query.\nQuery: 
{query_str}\nAnswer: ", - "assistant": "" + "system": data["template"].split("user:")[0].strip(), + "user": "Context information is below.\n---------------------\n{context_str}\n---------------------\nGiven the context information and not prior knowledge, answer the query.\nQuery: {query_str}\nAnswer: ", # noqa: E501 + "assistant": "", }, "template_vars": { "context_str": data["template_vars"]["context_str"], - "query_str": data["template_vars"]["query_str"] + "query_str": data["template_vars"]["query_str"], }, "system_prompt": data["system_prompt"], "query_wrapper_prompt": data["query_wrapper_prompt"], } return output_data - except Exception as e: + except Exception: return data def transform_templating_end(self, event_id) -> Any: @@ -494,27 +481,53 @@ def transform_templating_end(self, event_id) -> Any: query_str = request_data["template_vars"]["query_str"] replace_str = request_data["template"]["user"] output_data = { - "template":{ + "template": { "system": request_data["template"]["system"], - "user": replace_str.format(context_str=context_str, query_str=query_str), - "assistant": request_data["template"]["assistant"] + "user": replace_str.format( + context_str=context_str, query_str=query_str + ), + "assistant": request_data["template"]["assistant"], } } return output_data - except Exception as e: + except Exception: return "" -# ----------------- HELPER FUNCTIONS ----------------- # + def transform_function_call_start(self, data: Any) -> Any: + try: + tool_meta = data.get("tool") + output_data = { + "function_call": data.get("function_call", ""), + "tool": { + "description": tool_meta.description + if tool_meta.description + else "", + "name": tool_meta.name if tool_meta.name else "", + }, + } + return output_data + except Exception: + return data + + def transform_function_call_end(self, data: Any) -> Any: + try: + output_data = {"function_call_response": data["function_call_response"]} + return output_data + except Exception: + return data + + # ----------------- HELPER FUNCTIONS ------------------- # def serialize(self, obj): if isinstance(obj, Enum): return obj.value - if hasattr(obj, '__dict__'): + if hasattr(obj, "__dict__"): return {key: self.serialize(value) for key, value in obj.__dict__.items()} if isinstance(obj, list): return [self.serialize(item) for item in obj] if isinstance(obj, dict): return {key: self.serialize(value) for key, value in obj.items()} + if isinstance(obj, tuple): + return tuple(self.serialize(item) for item in obj) return obj - \ No newline at end of file From b9cd0787d4cc38252ccb473ea6c45176f67ae93b Mon Sep 17 00:00:00 2001 From: csgulati09 Date: Tue, 2 Jul 2024 19:37:29 +0530 Subject: [PATCH 3/8] feat: langchain callback handler for otel --- .../langchain/portkey_langchain_callback.py | 857 ++++++++++++++++-- 1 file changed, 775 insertions(+), 82 deletions(-) diff --git a/portkey_ai/llms/langchain/portkey_langchain_callback.py b/portkey_ai/llms/langchain/portkey_langchain_callback.py index ee9c7faf..b37f373e 100644 --- a/portkey_ai/llms/langchain/portkey_langchain_callback.py +++ b/portkey_ai/llms/langchain/portkey_langchain_callback.py @@ -1,10 +1,13 @@ -from datetime import datetime +from enum import Enum +import json import time from typing import Any, Dict, List, Optional +from uuid import UUID, uuid4 from portkey_ai.api_resources.apis.logger import Logger +import re try: - from langchain_core.callbacks import BaseCallbackHandler + from langchain_core.callbacks.base import BaseCallbackHandler except ImportError: raise ImportError("Please pip 
install langchain-core to use PortkeyLangchain") @@ -34,137 +37,827 @@ def __init__( self.streamingMode: bool = False + self.global_trace_id: str = "" + self.event_map: Any = {} + self.event_array: List = [] + self.main_span_id: str = "" + if not api_key: raise ValueError("Please provide an API key to use PortkeyCallbackHandler") def on_llm_start( - self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any + self, + serialized: Dict[str, Any], + prompts: List[str], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + **kwargs: Any, ) -> None: - for prompt in prompts: - messages = prompt.split("\n") - for message in messages: - role, content = message.split(":", 1) - self.prompt_records.append( - {"role": role.lower(), "content": content.strip()} - ) + """Run when LLM starts running. + + **ATTENTION**: This method is called for non-chat models (regular LLMs). If + you're implementing a handler for a chat model, + you should use on_chat_model_start instead. + """ + # print("on_llm_start") + # print("on_llm_start: run_id: ", run_id) + # print("on_llm_start: parent_run_id: ", parent_run_id) + + request_payload = self.on_llm_start_transformer( + serialized, prompts, kwargs=kwargs + ) + info_obj = self.start_event_information( + run_id, + parent_run_id, + "llm_start", + self.global_trace_id, + request_payload, + ) + self.event_map["llm_start_" + str(run_id)] = info_obj + pass + + def on_chat_model_start( + self, + serialized: Dict[str, Any], + messages: List[List[Any]], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> Any: + """Run when a chat model starts running. - self.startTimestamp = float(datetime.now().timestamp()) + **ATTENTION**: This method is called for chat models. If you're implementing + a handler for a non-chat model, you should use on_llm_start instead. 
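+
+        Spans recorded here reuse the handler's global trace id, so
+        chat-model runs stay grouped under the trace that started them.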
+ """ + # NotImplementedError is thrown intentionally + # Callback handler will fall back to on_llm_start if this is exception is thrown - self.streamingMode = kwargs.get("invocation_params", False).get("stream", False) + # print("on_chat_model_start") + # print("on_chat_model_start: serialized: ", serialized) + # print("on_chat_model_start: messages: ", messages) + # print("on_chat_model_start: kwargs: ", kwargs) + # print("on_chat_model_start: run_id: ", run_id) + # print("on_chat_model_start: parent_run_id: ", parent_run_id) - self.request["method"] = "POST" - self.request["url"] = serialized.get("kwargs", "").get( - "base_url", "chat/completions" + request_payload = self.on_chat_model_start_transformer( + serialized, messages, kwargs=kwargs ) - self.request["provider"] = serialized["id"][2] - self.request["headers"] = serialized.get("kwargs", {}).get( - "default_headers", {} + info_obj = self.start_event_information( + run_id, + parent_run_id, + "chat_model_start", + self.global_trace_id, + request_payload, ) - self.request["headers"].update({"provider": serialized["id"][2]}) - self.request["body"] = {"messages": self.prompt_records} - self.request["body"].update({**kwargs.get("invocation_params", {})}) + self.event_map["chat_model_start_" + str(run_id)] = info_obj + self.event_array.append(self.event_map["chat_model_start_" + str(run_id)]) + + raise NotImplementedError( + f"{self.__class__.__name__} does not implement `on_chat_model_start`" + ) + + def on_llm_end( + self, + response: Any, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + """Run when LLM ends running.""" + # print("on_llm_end: ") + # print(f"on_llm_end: run_id: {run_id}") + # print(f"on_llm_end: parent_run_id: {parent_run_id}") + + response_payload = self.on_llm_end_transformer(response, kwargs=kwargs) + self.event_map["llm_start_" + str(run_id)]["response"] = response_payload + self.event_array.append(self.event_map["llm_start_" + str(run_id)]) + pass def on_chain_start( self, serialized: Dict[str, Any], inputs: Dict[str, Any], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> None: """Run when chain starts running.""" - def on_llm_end(self, response: Any, **kwargs: Any) -> None: - self.endTimestamp = float(datetime.now().timestamp()) - responseTime = self.endTimestamp - self.startTimestamp + # print("on_chain_start: ") + # print(f"on_chain_start: run_id: {run_id}") + # print(f"on_chain_start: parent_run_id: {parent_run_id}") + # print("on_chain_start: serialized: ", serialized) + # print("on_chain_start: inputs: ", inputs) + # print("on_chain_start: kwargs: ", kwargs) - usage = (response.llm_output or {}).get("token_usage", "") # type: ignore[union-attr] + if parent_run_id is None: + self.global_trace_id = str(uuid4()) + self.main_span_id = str(uuid4()) - self.response["status"] = ( - 200 if self.responseStatus == 0 else self.responseStatus + parent_span_id = ( + self.main_span_id if parent_run_id is None else str(parent_run_id) ) - self.response["body"] = { - "choices": [ - { - "index": 0, - "message": { - "role": "assistant", - "content": response.generations[0][0].text, - }, - "logprobs": response.generations[0][0].generation_info.get("logprobs", ""), # type: ignore[union-attr] # noqa: E501 - "finish_reason": response.generations[0][0].generation_info.get("finish_reason", ""), # type: ignore[union-attr] # noqa: E501 - } - ] - 
} - self.response["body"].update({"usage": usage}) - self.response["body"].update({"id": str(kwargs.get("run_id", ""))}) - self.response["body"].update({"created": int(time.time())}) - self.response["body"].update({"model": (response.llm_output or {}).get("model_name", "")}) # type: ignore[union-attr] # noqa: E501 - self.response["body"].update({"system_fingerprint": (response.llm_output or {}).get("system_fingerprint", "")}) # type: ignore[union-attr] # noqa: E501 - self.response["time"] = int(responseTime * 1000) - self.response["headers"] = {} - self.response["streamingMode"] = self.streamingMode - - self.log_object.update( - { - "request": self.request, - "response": self.response, - } + + request_payload = self.on_chain_start_transformer( + serialized, inputs, kwargs=kwargs + ) + info_obj = self.start_event_information( + run_id, + parent_span_id, + "chain_start", + self.global_trace_id, + request_payload, ) - self.portkey_logger.log(log_object=self.log_object) + self.event_map["chain_start_" + str(run_id)] = info_obj + # print( + # "on_chain_start: eventMap: ", self.event_map["chain_start_" + str(run_id)] + # ) + pass def on_chain_end( self, outputs: Dict[str, Any], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, **kwargs: Any, ) -> None: """Run when chain ends running.""" - pass + # print("on_chain_end: ") + # print(f"on_chain_end: run_id: {run_id}") + # print(f"on_chain_end: parent_run_id: {parent_run_id}") - def on_chain_error(self, error: BaseException, **kwargs: Any) -> None: - self.responseBody = error - self.responseStatus = error.status_code # type: ignore[attr-defined] - """Do nothing.""" - pass + response_payload = self.on_chain_end_transformer(outputs) + + self.event_map["chain_start_" + str(run_id)]["response"] = response_payload + self.event_array.append(self.event_map["chain_start_" + str(run_id)]) + + if parent_run_id is None: + # print("END OF THE ENTIRE CHAIN") + print("FINAL EVENT ARRAY: ", self.event_array) + self.log_object = {"data": self.event_array} + self.portkey_logger.log(log_object=self.log_object) + + self.event_array = [] + self.event_map = {} - def on_llm_error(self, error: BaseException, **kwargs: Any) -> None: - self.responseBody = error - self.responseStatus = error.status_code # type: ignore[attr-defined] - """Do nothing.""" pass - def on_tool_error(self, error: BaseException, **kwargs: Any) -> None: - self.responseBody = error - self.responseStatus = error.status_code # type: ignore[attr-defined] + def on_tool_start( + self, + serialized: Dict[str, Any], + input_str: str, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + inputs: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> None: + """Run when tool starts running.""" + print("on_tool_start: ") + # print(f"on_tool_start: run_id: {run_id}") + # print(f"on_tool_start: parent_run_id: {parent_run_id}") + request_payload = self.on_tool_start_transformer(serialized, input_str, inputs) + info_obj = self.start_event_information( + run_id, + parent_run_id, + "tool_start", + self.global_trace_id, + request_payload, + ) + self.event_map["tool_start_" + str(run_id)] = info_obj pass - def on_text(self, text: str, **kwargs: Any) -> None: + def on_tool_end( + self, + output: Any, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + """Run when tool ends running.""" + # print("on_tool_end: 
") + # print(f"on_tool_end: run_id: {run_id}") + # print(f"on_tool_end: parent_run_id: {parent_run_id}") + response_payload = self.on_tool_end_transformer(output) + self.event_map["tool_start_" + str(run_id)]["response"] = response_payload + self.event_array.append(self.event_map["tool_start_" + str(run_id)]) pass - def on_agent_finish(self, finish: Any, **kwargs: Any) -> None: + def on_text( # Do we need to log this or not? This is just formatting of the text + self, + text: str, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + """Run on arbitrary text.""" + # print("on_text: ") + # print(f"on_text: run_id: {run_id}") + # print(f"on_text: parent_run_id: {parent_run_id}") + + parent_span_id = ( + self.main_span_id if parent_run_id is None else str(parent_run_id) + ) + request_payload = self.on_text_transformer(text) + info_obj = self.start_event_information( + run_id, parent_span_id, "text", self.global_trace_id, request_payload + ) + self.event_map["text_" + str(run_id)] = info_obj + self.event_array.append(self.event_map["text_" + str(run_id)]) pass - def on_llm_new_token(self, token: str, **kwargs: Any) -> None: - self.streamingMode = True - """Do nothing.""" + def on_agent_action( + self, + action: Any, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + """Run on agent action.""" + # print("on_agent_action: ") + # print(f"on_agent_action: run_id: {run_id}") + # print(f"on_agent_action: parent_run_id: {parent_run_id}") + + parent_span_id = ( + self.main_span_id if parent_run_id is None else str(parent_run_id) + ) + request_payload = self.on_agent_action_transformer(action) + info_obj = self.start_event_information( + run_id, + parent_span_id, + "agent_action", + self.global_trace_id, + request_payload, + ) + self.event_map["agent_action_" + str(run_id)] = info_obj pass - def on_tool_start( + def on_agent_finish( self, - serialized: Dict[str, Any], - input_str: str, + finish: Any, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, **kwargs: Any, ) -> None: + """Run on agent end.""" + # print("on_agent_finish: ") + # print(f"on_agent_finish: run_id: {run_id}") + # print(f"on_agent_finish: parent_run_id: {parent_run_id}") + + (self.main_span_id if parent_run_id is None else str(parent_run_id)) + response_payload = self.on_agent_finish_transformer(finish) + self.event_map["agent_action_" + str(run_id)]["response"] = response_payload + self.event_array.append(self.event_map["agent_action_" + str(run_id)]) pass - def on_agent_action(self, action: Any, **kwargs: Any) -> Any: - """Do nothing.""" + def on_retriever_start( + self, + serialized: Dict[str, Any], + query: str, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> None: + """Run on retriever start.""" + # print("on_retriever_start: ") + # print(f"on_retriever_start: run_id: {run_id}") + # print(f"on_retriever_start: parent_run_id: {parent_run_id}") pass - def on_tool_end( + def on_retriever_end( self, - output: Any, - observation_prefix: Optional[str] = None, - llm_prefix: Optional[str] = None, + documents: Any, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, **kwargs: Any, ) -> None: + """Run on retriever end.""" + # print("on_retriever_end: ") + # print(f"on_retriever_end: 
run_id: {run_id}") + # print(f"on_retriever_end: parent_run_id: {parent_run_id}") pass + + ''' + + # def on_llm_new_token( + # self, + # token: str, + # *, + # chunk: Optional[Union[Any, Any]] = None, + # run_id: UUID, + # parent_run_id: Optional[UUID] = None, + # tags: Optional[List[str]] = None, + # **kwargs: Any, + # ) -> None: + # """Run on new LLM token. Only available when streaming is enabled.""" + # print("on_llm_new_token: ") + # print(f"on_llm_new_token: run_id: {run_id}") + # print(f"on_llm_new_token: parent_run_id: {parent_run_id}") + # pass + + # def on_retriever_error( + # self, + # error: BaseException, + # *, + # run_id: UUID, + # parent_run_id: Optional[UUID] = None, + # tags: Optional[List[str]] = None, + # **kwargs: Any, + # ) -> None: + # """Run on retriever error.""" + # print("on_retriever_error: ") + # print(f"on_retriever_error: run_id: {run_id}") + # print(f"on_retriever_error: parent_run_id: {parent_run_id}") + # pass + + # def on_llm_error( + # self, + # error: BaseException, + # *, + # run_id: UUID, + # parent_run_id: Optional[UUID] = None, + # tags: Optional[List[str]] = None, + # **kwargs: Any, + # ) -> None: + # """Run when LLM errors. + + # Args: + # error: The error that occurred. + # kwargs (Any): Additional keyword arguments. + # - response (Any): The response which was generated before + # the error occurred. + # """ + # print("on_llm_error: ") + # print(f"on_llm_error: run_id: {run_id}") + # print(f"on_llm_error: parent_run_id: {parent_run_id}") + # pass + + # def on_retry( + # self, + # retry_state: Any, + # *, + # run_id: UUID, + # parent_run_id: Optional[UUID] = None, + # **kwargs: Any, + # ) -> Any: + # """Run on a retry event.""" + # print("on_retry: ") + # # print(f"on_retry: run_id: {run_id}") + # # print(f"on_retry: parent_run_id: {parent_run_id}") + # pass + + # def on_tool_error( + # self, + # error: BaseException, + # *, + # run_id: UUID, + # parent_run_id: Optional[UUID] = None, + # tags: Optional[List[str]] = None, + # **kwargs: Any, + # ) -> None: + # """Run when tool errors.""" + # print("on_tool_error: ") + # print(f"on_tool_error: run_id: {run_id}") + # print(f"on_tool_error: parent_run_id: {parent_run_id}") + # pass + + # def on_chain_error( + # self, + # error: BaseException, + # *, + # run_id: UUID, + # parent_run_id: Optional[UUID] = None, + # tags: Optional[List[str]] = None, + # **kwargs: Any, + # ) -> None: + # """Run when chain errors.""" + # print("on_chain_error: ") + # print(f"on_chain_error: run_id: {run_id}") + # print(f"on_chain_error: parent_run_id: {parent_run_id}") + # pass + + ''' + + # -------------- Helpers ------------------------------------------ + def start_event_information( + self, span_id, parent_span_id, span_name, trace_id, request_payload + ): + return { + "span_id": str(span_id), + "parent_span_id": str(parent_span_id), + "span_name": span_name, + "trace_id": trace_id, + "request": request_payload, + } + + def serialize(self, obj): + if isinstance(obj, Enum): + return obj.value + if hasattr(obj, "__dict__"): + return {key: self.serialize(value) for key, value in obj.__dict__.items()} + if isinstance(obj, list): + return [self.serialize(item) for item in obj] + if isinstance(obj, dict): + return {key: self.serialize(value) for key, value in obj.items()} + if isinstance(obj, tuple): + return tuple(self.serialize(item) for item in obj) + return obj + + def extract_tools(self, content: str) -> List[Dict[str, Any]]: + tools_pattern = re.compile(r"(\w+)\((.*?)\) -> (.*?) 
- (.+)") + tools_matches = tools_pattern.findall(content) + tools = [] + for match in tools_matches: + tool_name, params, return_type, description = match + param_pattern = re.compile(r"(\w+): (.+?)(?:,|$)") + param_matches = param_pattern.findall(params) + {param: param_type for param, param_type in param_matches} + tools.append( + { + "name": tool_name, + "description": description, + "return_type": return_type, + } + ) + return tools + + def extract_response_format(self, content: str) -> List[str]: + response_format_pattern = re.compile( + r"Use the following format:\s*(.*?)\s*Begin!", re.DOTALL + ) + response_format_match = response_format_pattern.search(content) + if response_format_match: + response_format_content = response_format_match.group(1) + response_format_lines = [ + line.strip() + for line in response_format_content.split("\n") + if line.strip() + ] + return response_format_lines + return [] + + # ---------------------------------------------------------------------------- + + # ------ Event Transformers ------ + + def on_llm_start_transformer(self, serialized, prompts, kwargs): + # print("on_llm_start_transformer: serialized: ", serialized) + # print("on_llm_start_transformer: prompts: ", prompts) + # print("on_llm_start_transformer: kwargs: ", kwargs) + + try: + result = {"messages": []} + for entry in prompts: + role, content = entry.split(": ", 1) + tools = self.extract_tools(content) + response_format = self.extract_response_format(content) + example_question_pattern = re.compile(r"Question: (.+?)\n") + example_question_match = example_question_pattern.search(content) + example_question = ( + example_question_match.group(1) if example_question_match else "" + ) + content_before_format = content.split("Use the following format:")[ + 0 + ].strip() + input_data = { + "role": role, + "content": content_before_format, + "tools": tools, + "response_format": {"structure": response_format}, + "example_question": example_question, + "example_thought": "", + } + result["messages"].append(input_data) + + request = {} + + # print("RESPONSE MESSAGES: ", result["messages"]) + + # startTimestamp = float(datetime.now().timestamp()) + + # streamingMode = kwargs.get("invocation_params", False).get("stream", False) + + request["method"] = "POST" + request["url"] = serialized.get("kwargs", "").get( + "base_url", "chat/completions" + ) + request["provider"] = serialized["id"][2] + request["headers"] = serialized.get("kwargs", {}).get("default_headers", {}) + request["headers"].update({"provider": serialized["id"][2]}) + request["body"] = {"messages": result["messages"]} + request["body"].update(kwargs.get("invocation_params", {})) + return request + except Exception as e: + print("on_llm_start_transformer: Error: ", e) + return { + "serialized": serialized, + "prompts": prompts, + "invocation_params": kwargs, + } + + def on_llm_end_transformer(self, response, kwargs): + try: + response_obj = {} + # self.endTimestamp = float(datetime.now().timestamp()) + # responseTime = self.endTimestamp - self.startTimestamp + usage = (response.llm_output or {}).get("token_usage", "") # type: ignore[union-attr] + + # self.response["status"] = ( + # 200 if self.responseStatus == 0 else self.responseStatus + # ) + + response_obj["body"] = { + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": response.generations[0][0].text, + }, + "logprobs": response.generations[0][0].generation_info.get("logprobs", ""), # type: ignore[union-attr] # noqa: E501 + "finish_reason": 
response.generations[0][0].generation_info.get("finish_reason", ""), # type: ignore[union-attr] # noqa: E501 + } + ] + } + response_obj["body"].update({"usage": usage}) + response_obj["body"].update({"id": str(kwargs.get("run_id", ""))}) + response_obj["body"].update({"created": int(time.time())}) + response_obj["body"].update({"model": (response.llm_output or {}).get("model_name", "")}) # type: ignore[union-attr] # noqa: E501 + response_obj["body"].update({"system_fingerprint": (response.llm_output or {}).get("system_fingerprint", "")}) # type: ignore[union-attr] # noqa: E501 + # response["time"] = int(responseTime * 1000) + response_obj["headers"] = {} + # response["streamingMode"] = streamingMode + return response_obj + except Exception as e: + print("on_llm_end_transformer: Error: ", e) + return {"response": response, "kwargs": kwargs} + + def on_chain_start_transformer(self, serialized, input, kwargs): + try: + # print("on_chain_start_transformer: serialized: ", serialized) + # print("on_chain_start_transformer: input: ", input) + # print("on_chain_start_transformer: kwargs: ", kwargs) + name = kwargs["name"] + return { + "name": name, + "input": json.dumps(input), + } + except Exception as e: + print("Error in on_chain_start_transformer:", str(e)) + return {"serialized": serialized, "input": input, "kwargs": kwargs} + + def on_chain_end_transformer(self, output): + try: + structured_data = {} + + # Define patterns and corresponding keys + patterns = { + r"Action:\s*(.*)": "action", + r"Action Input:\s*(.*)": "action_input", + r"Final Answer:\s*(.*)": "final_answer", + r"Answer:\s*(.*)": "answer", + } + + for key, value in output.items(): + for pattern, result_key in patterns.items(): + matches = re.findall(pattern, value) + if matches: + structured_data[result_key] = matches[0].strip() + else: + structured_data[key] = value + + return {"output": structured_data} + + except Exception as e: + print("Error in on_chain_end_transformer:", str(e)) + return {"output": output} + + def on_text_transformer(self, text): + try: + # print("on_text_transformer: text: ", text) + return {"text": text} + except Exception as e: + print("Error in on_text_transformer:", str(e)) + return {"text": text} + + def on_chat_model_start_transformer(self, serialized, messages, kwargs): + try: + model = serialized["id"][-1] + invocation_params = kwargs["invocation_params"] + input_data = self.serialize(messages) + message_obj = input_data[0][0] + return { + "model": model, + "invocation_params": invocation_params, + "messages": message_obj, + } + except Exception as e: + print("Error in on_chat_model_start_transformer:", str(e)) + return {"serialized": serialized, "messages": message_obj, "kwargs": kwargs} + + def on_agent_action_transformer(self, action): + try: + action = self.serialize(action) + tool = action["tool"] + tool_input = action["tool_input"] + log = action["log"] + return {"tool": tool, "tool_input": tool_input, "log": log} + except Exception as e: + print("Error in on_agent_action_transformer:", str(e)) + return {"action": action} + + def on_agent_finish_transformer(self, finish): + try: + # print("on_agent_finish_transformer: finish: ", finish) + finish = self.serialize(finish) + return_values = finish["return_values"] + log = finish["log"] + return {"return_values": return_values, "log": log} + except Exception as e: + print("Error in on_agent_finish_transformer:", str(e)) + return {"finish": finish} + + def on_tool_start_transformer(self, serialized, input_str, inputs): + try: + return 
{"serialized": serialized, "input_str": input_str, "inputs": inputs} + except Exception as e: + print("Error in on_tool_start_transformer: ", str(e)) + return {"serialized": serialized, "input_str": input_str, "inputs": inputs} + + def on_tool_end_transformer(self, output): + try: + return {"output": output} + except Exception as e: + print("Error in on_tool_end_transformer: ", str(e)) + return {"output": output} + + +# flake8: noqa: E501 +''' + + # def on_llm_start( + # self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any + # ) -> None: + # print("on_llm_start: ") + # # for prompt in prompts: + # # messages = prompt.split("\n") + # # for message in messages: + # # role, content = message.split(":", 1) + # # self.prompt_records.append( + # # {"role": role.lower(), "content": content.strip()} + # # ) + + # # self.startTimestamp = float(datetime.now().timestamp()) + + # # self.streamingMode = kwargs.get("invocation_params", False).get("stream", False) # noqa: E501 + + # # self.request["method"] = "POST" + # # self.request["url"] = serialized.get("kwargs", "").get( + # # "base_url", "chat/completions" + # # ) + # # self.request["provider"] = serialized["id"][2] + # # self.request["headers"] = serialized.get("kwargs", {}).get( + # # "default_headers", {} + # # ) + # # self.request["headers"].update({"provider": serialized["id"][2]}) + # # self.request["body"] = {"messages": self.prompt_records} + # # self.request["body"].update({**kwargs.get("invocation_params", {})}) + # pass + + # def on_chain_start( + # self, + # serialized: Dict[str, Any], + # inputs: Dict[str, Any], + # **kwargs: Any, + # ) -> None: + # """Run when chain starts running.""" + # pass + + # def on_llm_end(self, response: Any, **kwargs: Any) -> None: + # # self.endTimestamp = float(datetime.now().timestamp()) + # # responseTime = self.endTimestamp - self.startTimestamp + + # # usage = (response.llm_output or {}).get("token_usage", "") # type: ignore[union-attr] # noqa: E501 + + # # self.response["status"] = ( + # # 200 if self.responseStatus == 0 else self.responseStatus + # # ) + # # self.response["body"] = { + # # "choices": [ + # # { + # # "index": 0, + # # "message": { + # # "role": "assistant", + # # "content": response.generations[0][0].text, + # # }, + # # "logprobs": response.generations[0][0].generation_info.get("logprobs", ""), # type: ignore[union-attr] # noqa: E501 + # # "finish_reason": response.generations[0][0].generation_info.get("finish_reason", ""), # type: ignore[union-attr] # noqa: E501 + # # } + # # ] + # # } + # # self.response["body"].update({"usage": usage}) + # # self.response["body"].update({"id": str(kwargs.get("run_id", ""))}) + # # self.response["body"].update({"created": int(time.time())}) + # # self.response["body"].update({"model": (response.llm_output or {}).get("model_name", "")}) # type: ignore[union-attr] # noqa: E501 + # # self.response["body"].update({"system_fingerprint": (response.llm_output or {}).get("system_fingerprint", "")}) # type: ignore[union-attr] # noqa: E501 + # # self.response["time"] = int(responseTime * 1000) + # # self.response["headers"] = {} + # # self.response["streamingMode"] = self.streamingMode + + # # self.log_object.update( + # # { + # # "request": self.request, + # # "response": self.response, + # # } + # # ) + + # # self.portkey_logger.log(log_object=self.log_object) + # pass + + # def on_chain_end( + # self, + # outputs: Dict[str, Any], + # **kwargs: Any, + # ) -> None: + # """Run when chain ends running.""" + # pass + + # def 
on_chain_error(self, error: BaseException, **kwargs: Any) -> None: + # # self.responseBody = error + # # self.responseStatus = error.status_code # type: ignore[attr-defined] + # """Do nothing.""" + # pass + + # def on_llm_error(self, error: BaseException, **kwargs: Any) -> None: + # # self.responseBody = error + # # self.responseStatus = error.status_code # type: ignore[attr-defined] + # """Do nothing.""" + # pass + + # def on_tool_error(self, error: BaseException, **kwargs: Any) -> None: + # # self.responseBody = error + # # self.responseStatus = error.status_code # type: ignore[attr-defined] + # pass + + # def on_text(self, text: str, **kwargs: Any) -> None: + # pass + + # def on_agent_finish(self, finish: Any, **kwargs: Any) -> None: + # pass + + # def on_llm_new_token(self, token: str, **kwargs: Any) -> None: + # # self.streamingMode = True + # """Do nothing.""" + # pass + + # def on_tool_start( + # self, + # serialized: Dict[str, Any], + # input_str: str, + # **kwargs: Any, + # ) -> None: + # pass + + # def on_agent_action(self, action: Any, **kwargs: Any) -> Any: + # """Do nothing.""" + # pass + + # def on_tool_end( + # self, + # output: Any, + # observation_prefix: Optional[str] = None, + # llm_prefix: Optional[str] = None, + # **kwargs: Any, + # ) -> None: + # pass + +''' From 4208593f5ca0ca32a89199bc2d84c35616633c7e Mon Sep 17 00:00:00 2001 From: csgulati09 Date: Sat, 10 Aug 2024 16:11:41 +0530 Subject: [PATCH 4/8] feat: response time added for each span --- portkey_ai/api_resources/apis/logger.py | 2 +- .../llama_index/portkey_llama_callback.py | 83 ++++++++++++++----- 2 files changed, 64 insertions(+), 21 deletions(-) diff --git a/portkey_ai/api_resources/apis/logger.py b/portkey_ai/api_resources/apis/logger.py index 71283c7f..76c2ff62 100644 --- a/portkey_ai/api_resources/apis/logger.py +++ b/portkey_ai/api_resources/apis/logger.py @@ -24,7 +24,7 @@ def __init__( def log( self, - log_object: dict, + log_object, ): response = requests.post( url=self.url, data=json.dumps(log_object), headers=self.headers diff --git a/portkey_ai/llms/llama_index/portkey_llama_callback.py b/portkey_ai/llms/llama_index/portkey_llama_callback.py index bef1bd04..f5a47009 100644 --- a/portkey_ai/llms/llama_index/portkey_llama_callback.py +++ b/portkey_ai/llms/llama_index/portkey_llama_callback.py @@ -24,9 +24,6 @@ class PortkeyLlamaindex(LlamaIndexBaseCallbackHandler): - startTimestamp: int = 0 - endTimestamp: float = 0 - def __init__( self, api_key: str, @@ -55,13 +52,12 @@ def __init__( self.prompt_tokens = 0 self.token_llm = 0 - self.log_object: Dict[str, Any] = {} + self.log_object: Any = [] self.prompt_records: Any = [] self.request: Any = {} self.response: Any = {} - self.responseTime: int = 0 self.streamingMode: bool = False self.event_map: Any = {} @@ -84,6 +80,7 @@ def on_event_start( # type: ignore span_id = str(event_id) parent_span_id = parent_id span_name = event_type + start_time = int(datetime.now().timestamp()) if parent_id == "root": parent_span_id = self.main_span_id @@ -111,6 +108,7 @@ def on_event_start( # type: ignore "span_name": span_name.value, "trace_id": self.global_trace_id, "request": request_payload, + "start_time": start_time, } self.event_map[span_id] = start_event_information @@ -131,7 +129,7 @@ def on_event_end( elif event_type == "agent_step": response_payload = self.agent_step_event_end(payload, event_id) elif event_type == "function_call": - response_payload = self.function_call_event_end(payload) + response_payload = self.function_call_event_end(payload, event_id) 
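# Aside: a minimal, self-contained sketch of the span-timing pattern these
# hunks introduce. on_event_start stashes a start_time in event_map under the
# span id, and each *_event_end handler later looks it up to derive a
# per-span response_time in milliseconds. The SpanTimer class below is
# assumed for illustration; only the timestamp arithmetic mirrors the diff.
from datetime import datetime
from typing import Any, Dict


class SpanTimer:
    def __init__(self) -> None:
        self.event_map: Dict[str, Dict[str, Any]] = {}

    def on_event_start(self, span_id: str) -> None:
        # start_time is recorded once, when the span is opened
        self.event_map[span_id] = {"start_time": int(datetime.now().timestamp())}

    def on_event_end(self, span_id: str) -> int:
        # (end_time - start_time) * 1000 converts whole seconds to
        # milliseconds; int(...timestamp()) truncates to seconds, so spans
        # shorter than one second report a response_time of 0 under this
        # scheme.
        start_time = self.event_map[span_id]["start_time"]
        end_time = int(datetime.now().timestamp())
        return (end_time - start_time) * 1000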
elif event_type == "query": response_payload = self.query_event_end(payload, event_id) elif event_type == "retrieve": @@ -150,9 +148,7 @@ def start_trace(self, trace_id: Optional[str] = None) -> None: if trace_id == "index_construction": self.global_trace_id = str(uuid4()) - self.main_span_id = str(uuid4()) - - self.startTimestamp = int(datetime.now().timestamp()) + self.main_span_id = "" def end_trace( self, @@ -161,7 +157,7 @@ def end_trace( ) -> None: """Run when an overall trace is exited.""" - self.log_object = {"data": self.event_array} + self.log_object = self.event_array self.portkey_logger.log(log_object=self.log_object) self.event_array = [] @@ -193,8 +189,10 @@ def llm_event_start(self, payload: Any) -> Any: return self.request def llm_event_end(self, payload: Any, event_id) -> Any: - self.endTimestamp = float(datetime.now().timestamp()) - responseTime = self.endTimestamp - self.startTimestamp + if event_id in self.event_map: + event = self.event_map[event_id] + start_time = event["start_time"] + self.response = {} data = payload.get("response", {}) @@ -228,10 +226,14 @@ def llm_event_end(self, payload: Any, event_id) -> Any: self.response["body"].update({"id": event_id}) self.response["body"].update({"created": int(time.time())}) self.response["body"].update({"model": data.raw.get("model", "")}) - self.response["time"] = int(responseTime * 1000) self.response["headers"] = {} self.response["streamingMode"] = self.streamingMode + end_time = int(datetime.now().timestamp()) + total_time = (end_time - start_time) * 1000 + + self.response["response_time"] = total_time + return self.response # ------------------------------------------------------ # @@ -255,13 +257,13 @@ def embedding_event_start(self, payload: Any) -> Any: def embedding_event_end(self, payload: Any, event_id) -> Any: if event_id in self.event_map: event = self.event_map[event_id] - event["request"]["body"]["input"] = payload.get("chunks", "") - # event["request"]["body"]["input"] = "...INPUT..." + # event["request"]["body"]["input"] = payload.get("chunks", "") + # Setting as ...INPUT... to avoid logging the entire data input file + event["request"]["body"]["input"] = "...INPUT..." 
- self.response = {} + start_time = event["start_time"] - self.endTimestamp = float(datetime.now().timestamp()) - responseTime = self.endTimestamp - self.startTimestamp + self.response = {} chunk_str = str(payload.get("chunks", "")) embd_str = str(payload.get("embeddings", "")) @@ -286,7 +288,11 @@ def embedding_event_end(self, payload: Any, event_id) -> Any: } ) self.response["headers"] = {} - self.response["time"] = int(responseTime * 1000) + + end_time = int(datetime.now().timestamp()) + total_time = (end_time - start_time) * 1000 + + self.response["response_time"] = total_time return self.response @@ -296,9 +302,16 @@ def agent_step_event_start(self, payload: Any) -> Any: return data def agent_step_event_end(self, payload: Any, event_id) -> Any: + if event_id in self.event_map: + event = self.event_map[event_id] + start_time = event["start_time"] data = self.serialize(payload) json.dumps(data) result = self.transform_agent_step_end(data) + end_time = int(datetime.now().timestamp()) + total_time = (end_time - start_time) * 1000 + + result["response_time"] = total_time return result # ------------------------------------------------------ # @@ -306,10 +319,17 @@ def function_call_event_start(self, payload: Any) -> Any: result = self.transform_function_call_start(payload) return result - def function_call_event_end(self, payload: Any) -> Any: + def function_call_event_end(self, payload: Any, event_id) -> Any: + if event_id in self.event_map: + event = self.event_map[event_id] + start_time = event["start_time"] data = self.serialize(payload) json.dumps(data) result = self.transform_function_call_end(data) + end_time = int(datetime.now().timestamp()) + total_time = (end_time - start_time) * 1000 + + result["response_time"] = total_time return result # ------------------------------------------------------ # @@ -318,9 +338,16 @@ def query_event_start(self, payload: Any) -> Any: return data def query_event_end(self, payload: Any, event_id) -> Any: + if event_id in self.event_map: + event = self.event_map[event_id] + start_time = event["start_time"] data = self.serialize(payload) json.dumps(data) result = self.transform_query_end(data) + end_time = int(datetime.now().timestamp()) + total_time = (end_time - start_time) * 1000 + + result["response_time"] = total_time return result # ------------------------------------------------------ # @@ -329,9 +356,17 @@ def retrieve_event_start(self, payload: Any) -> Any: return data def retrieve_event_end(self, payload: Any, event_id) -> Any: + if event_id in self.event_map: + event = self.event_map[event_id] + start_time = event["start_time"] + data = self.serialize(payload) json.dumps(data) result = self.transform_retrieve_end(data) + end_time = int(datetime.now().timestamp()) + total_time = (end_time - start_time) * 1000 + + result["response_time"] = total_time return result # ------------------------------------------------------ # @@ -342,7 +377,15 @@ def templating_event_start(self, payload: Any) -> Any: return result def templating_event_end(self, payload: Any, event_id) -> Any: + if event_id in self.event_map: + event = self.event_map[event_id] + start_time = event["start_time"] result = self.transform_templating_end(event_id) + + end_time = int(datetime.now().timestamp()) + total_time = (end_time - start_time) * 1000 + + result["response_time"] = total_time return result # ------------------------------------------------------ # From 1f35eeaddf87996ac96dece58b107f78160d13ae Mon Sep 17 00:00:00 2001 From: csgulati09 Date: Sat, 10 Aug 2024 17:25:51 
+0530 Subject: [PATCH 5/8] feat: response time in spans + clean up --- .../langchain/portkey_langchain_callback.py | 391 +++--------------- 1 file changed, 48 insertions(+), 343 deletions(-) diff --git a/portkey_ai/llms/langchain/portkey_langchain_callback.py b/portkey_ai/llms/langchain/portkey_langchain_callback.py index b37f373e..2c521b7a 100644 --- a/portkey_ai/llms/langchain/portkey_langchain_callback.py +++ b/portkey_ai/llms/langchain/portkey_langchain_callback.py @@ -5,6 +5,7 @@ from uuid import UUID, uuid4 from portkey_ai.api_resources.apis.logger import Logger import re +from datetime import datetime try: from langchain_core.callbacks.base import BaseCallbackHandler @@ -18,21 +19,17 @@ def __init__( api_key: str, ) -> None: super().__init__() - self.startTimestamp: float = 0 - self.endTimestamp: float = 0 self.api_key = api_key self.portkey_logger = Logger(api_key=api_key) - self.log_object: Dict[str, Any] = {} + self.log_object: Any = [] self.prompt_records: Any = [] self.request: Any = {} self.response: Any = {} - # self.responseHeaders: Dict[str, Any] = {} - self.responseBody: Any = None self.responseStatus: int = 0 self.streamingMode: bool = False @@ -62,9 +59,6 @@ def on_llm_start( you're implementing a handler for a chat model, you should use on_chat_model_start instead. """ - # print("on_llm_start") - # print("on_llm_start: run_id: ", run_id) - # print("on_llm_start: parent_run_id: ", parent_run_id) request_payload = self.on_llm_start_transformer( serialized, prompts, kwargs=kwargs @@ -98,13 +92,6 @@ def on_chat_model_start( # NotImplementedError is thrown intentionally # Callback handler will fall back to on_llm_start if this is exception is thrown - # print("on_chat_model_start") - # print("on_chat_model_start: serialized: ", serialized) - # print("on_chat_model_start: messages: ", messages) - # print("on_chat_model_start: kwargs: ", kwargs) - # print("on_chat_model_start: run_id: ", run_id) - # print("on_chat_model_start: parent_run_id: ", parent_run_id) - request_payload = self.on_chat_model_start_transformer( serialized, messages, kwargs=kwargs ) @@ -132,12 +119,17 @@ def on_llm_end( **kwargs: Any, ) -> None: """Run when LLM ends running.""" - # print("on_llm_end: ") - # print(f"on_llm_end: run_id: {run_id}") - # print(f"on_llm_end: parent_run_id: {parent_run_id}") + + start_time = self.event_map["llm_start_" + str(run_id)]["start_time"] + end_time = int(datetime.now().timestamp()) + total_time = (end_time - start_time) * 1000 response_payload = self.on_llm_end_transformer(response, kwargs=kwargs) self.event_map["llm_start_" + str(run_id)]["response"] = response_payload + self.event_map["llm_start_" + str(run_id)]["response"][ + "response_time" + ] = total_time + self.event_array.append(self.event_map["llm_start_" + str(run_id)]) pass @@ -154,16 +146,9 @@ def on_chain_start( ) -> None: """Run when chain starts running.""" - # print("on_chain_start: ") - # print(f"on_chain_start: run_id: {run_id}") - # print(f"on_chain_start: parent_run_id: {parent_run_id}") - # print("on_chain_start: serialized: ", serialized) - # print("on_chain_start: inputs: ", inputs) - # print("on_chain_start: kwargs: ", kwargs) - if parent_run_id is None: self.global_trace_id = str(uuid4()) - self.main_span_id = str(uuid4()) + self.main_span_id = "" parent_span_id = ( self.main_span_id if parent_run_id is None else str(parent_run_id) @@ -181,9 +166,6 @@ def on_chain_start( ) self.event_map["chain_start_" + str(run_id)] = info_obj - # print( - # "on_chain_start: eventMap: ", 
self.event_map["chain_start_" + str(run_id)] - # ) pass def on_chain_end( @@ -196,19 +178,22 @@ def on_chain_end( **kwargs: Any, ) -> None: """Run when chain ends running.""" - # print("on_chain_end: ") - # print(f"on_chain_end: run_id: {run_id}") - # print(f"on_chain_end: parent_run_id: {parent_run_id}") + + start_time = self.event_map["chain_start_" + str(run_id)]["start_time"] + end_time = int(datetime.now().timestamp()) + total_time = (end_time - start_time) * 1000 response_payload = self.on_chain_end_transformer(outputs) self.event_map["chain_start_" + str(run_id)]["response"] = response_payload + self.event_map["chain_start_" + str(run_id)]["response"][ + "response_time" + ] = total_time + self.event_array.append(self.event_map["chain_start_" + str(run_id)]) if parent_run_id is None: - # print("END OF THE ENTIRE CHAIN") - print("FINAL EVENT ARRAY: ", self.event_array) - self.log_object = {"data": self.event_array} + self.log_object = self.event_array self.portkey_logger.log(log_object=self.log_object) self.event_array = [] @@ -229,9 +214,6 @@ def on_tool_start( **kwargs: Any, ) -> None: """Run when tool starts running.""" - print("on_tool_start: ") - # print(f"on_tool_start: run_id: {run_id}") - # print(f"on_tool_start: parent_run_id: {parent_run_id}") request_payload = self.on_tool_start_transformer(serialized, input_str, inputs) info_obj = self.start_event_information( run_id, @@ -253,11 +235,16 @@ def on_tool_end( **kwargs: Any, ) -> None: """Run when tool ends running.""" - # print("on_tool_end: ") - # print(f"on_tool_end: run_id: {run_id}") - # print(f"on_tool_end: parent_run_id: {parent_run_id}") + + start_time = self.event_map["tool_start_" + str(run_id)]["start_time"] + end_time = int(datetime.now().timestamp()) + total_time = (end_time - start_time) * 1000 + response_payload = self.on_tool_end_transformer(output) self.event_map["tool_start_" + str(run_id)]["response"] = response_payload + self.event_map["tool_start_" + str(run_id)]["response"][ + "response_time" + ] = total_time self.event_array.append(self.event_map["tool_start_" + str(run_id)]) pass @@ -271,9 +258,6 @@ def on_text( # Do we need to log this or not? 
This is just formatting of the te **kwargs: Any, ) -> None: """Run on arbitrary text.""" - # print("on_text: ") - # print(f"on_text: run_id: {run_id}") - # print(f"on_text: parent_run_id: {parent_run_id}") parent_span_id = ( self.main_span_id if parent_run_id is None else str(parent_run_id) @@ -296,9 +280,6 @@ def on_agent_action( **kwargs: Any, ) -> None: """Run on agent action.""" - # print("on_agent_action: ") - # print(f"on_agent_action: run_id: {run_id}") - # print(f"on_agent_action: parent_run_id: {parent_run_id}") parent_span_id = ( self.main_span_id if parent_run_id is None else str(parent_run_id) @@ -324,13 +305,17 @@ def on_agent_finish( **kwargs: Any, ) -> None: """Run on agent end.""" - # print("on_agent_finish: ") - # print(f"on_agent_finish: run_id: {run_id}") - # print(f"on_agent_finish: parent_run_id: {parent_run_id}") + + start_time = self.event_map["agent_action_" + str(run_id)]["start_time"] + end_time = int(datetime.now().timestamp()) + total_time = (end_time - start_time) * 1000 (self.main_span_id if parent_run_id is None else str(parent_run_id)) response_payload = self.on_agent_finish_transformer(finish) self.event_map["agent_action_" + str(run_id)]["response"] = response_payload + self.event_map["agent_action_" + str(run_id)]["response"][ + "response_time" + ] = total_time self.event_array.append(self.event_map["agent_action_" + str(run_id)]) pass @@ -346,9 +331,6 @@ def on_retriever_start( **kwargs: Any, ) -> None: """Run on retriever start.""" - # print("on_retriever_start: ") - # print(f"on_retriever_start: run_id: {run_id}") - # print(f"on_retriever_start: parent_run_id: {parent_run_id}") pass def on_retriever_end( @@ -361,122 +343,20 @@ def on_retriever_end( **kwargs: Any, ) -> None: """Run on retriever end.""" - # print("on_retriever_end: ") - # print(f"on_retriever_end: run_id: {run_id}") - # print(f"on_retriever_end: parent_run_id: {parent_run_id}") pass - ''' - - # def on_llm_new_token( - # self, - # token: str, - # *, - # chunk: Optional[Union[Any, Any]] = None, - # run_id: UUID, - # parent_run_id: Optional[UUID] = None, - # tags: Optional[List[str]] = None, - # **kwargs: Any, - # ) -> None: - # """Run on new LLM token. Only available when streaming is enabled.""" - # print("on_llm_new_token: ") - # print(f"on_llm_new_token: run_id: {run_id}") - # print(f"on_llm_new_token: parent_run_id: {parent_run_id}") - # pass - - # def on_retriever_error( - # self, - # error: BaseException, - # *, - # run_id: UUID, - # parent_run_id: Optional[UUID] = None, - # tags: Optional[List[str]] = None, - # **kwargs: Any, - # ) -> None: - # """Run on retriever error.""" - # print("on_retriever_error: ") - # print(f"on_retriever_error: run_id: {run_id}") - # print(f"on_retriever_error: parent_run_id: {parent_run_id}") - # pass - - # def on_llm_error( - # self, - # error: BaseException, - # *, - # run_id: UUID, - # parent_run_id: Optional[UUID] = None, - # tags: Optional[List[str]] = None, - # **kwargs: Any, - # ) -> None: - # """Run when LLM errors. - - # Args: - # error: The error that occurred. - # kwargs (Any): Additional keyword arguments. - # - response (Any): The response which was generated before - # the error occurred. 
- # """ - # print("on_llm_error: ") - # print(f"on_llm_error: run_id: {run_id}") - # print(f"on_llm_error: parent_run_id: {parent_run_id}") - # pass - - # def on_retry( - # self, - # retry_state: Any, - # *, - # run_id: UUID, - # parent_run_id: Optional[UUID] = None, - # **kwargs: Any, - # ) -> Any: - # """Run on a retry event.""" - # print("on_retry: ") - # # print(f"on_retry: run_id: {run_id}") - # # print(f"on_retry: parent_run_id: {parent_run_id}") - # pass - - # def on_tool_error( - # self, - # error: BaseException, - # *, - # run_id: UUID, - # parent_run_id: Optional[UUID] = None, - # tags: Optional[List[str]] = None, - # **kwargs: Any, - # ) -> None: - # """Run when tool errors.""" - # print("on_tool_error: ") - # print(f"on_tool_error: run_id: {run_id}") - # print(f"on_tool_error: parent_run_id: {parent_run_id}") - # pass - - # def on_chain_error( - # self, - # error: BaseException, - # *, - # run_id: UUID, - # parent_run_id: Optional[UUID] = None, - # tags: Optional[List[str]] = None, - # **kwargs: Any, - # ) -> None: - # """Run when chain errors.""" - # print("on_chain_error: ") - # print(f"on_chain_error: run_id: {run_id}") - # print(f"on_chain_error: parent_run_id: {parent_run_id}") - # pass - - ''' - # -------------- Helpers ------------------------------------------ def start_event_information( self, span_id, parent_span_id, span_name, trace_id, request_payload ): + start_time = int(datetime.now().timestamp()) return { "span_id": str(span_id), "parent_span_id": str(parent_span_id), "span_name": span_name, "trace_id": trace_id, "request": request_payload, + "start_time": start_time, } def serialize(self, obj): @@ -525,15 +405,11 @@ def extract_response_format(self, content: str) -> List[str]: return response_format_lines return [] - # ---------------------------------------------------------------------------- + # ----------------------------------------------------------------- # ------ Event Transformers ------ def on_llm_start_transformer(self, serialized, prompts, kwargs): - # print("on_llm_start_transformer: serialized: ", serialized) - # print("on_llm_start_transformer: prompts: ", prompts) - # print("on_llm_start_transformer: kwargs: ", kwargs) - try: result = {"messages": []} for entry in prompts: @@ -560,12 +436,6 @@ def on_llm_start_transformer(self, serialized, prompts, kwargs): request = {} - # print("RESPONSE MESSAGES: ", result["messages"]) - - # startTimestamp = float(datetime.now().timestamp()) - - # streamingMode = kwargs.get("invocation_params", False).get("stream", False) - request["method"] = "POST" request["url"] = serialized.get("kwargs", "").get( "base_url", "chat/completions" @@ -576,8 +446,7 @@ def on_llm_start_transformer(self, serialized, prompts, kwargs): request["body"] = {"messages": result["messages"]} request["body"].update(kwargs.get("invocation_params", {})) return request - except Exception as e: - print("on_llm_start_transformer: Error: ", e) + except Exception: return { "serialized": serialized, "prompts": prompts, @@ -587,14 +456,8 @@ def on_llm_start_transformer(self, serialized, prompts, kwargs): def on_llm_end_transformer(self, response, kwargs): try: response_obj = {} - # self.endTimestamp = float(datetime.now().timestamp()) - # responseTime = self.endTimestamp - self.startTimestamp usage = (response.llm_output or {}).get("token_usage", "") # type: ignore[union-attr] - # self.response["status"] = ( - # 200 if self.responseStatus == 0 else self.responseStatus - # ) - response_obj["body"] = { "choices": [ { @@ -613,26 +476,19 @@ def 
on_llm_end_transformer(self, response, kwargs): response_obj["body"].update({"created": int(time.time())}) response_obj["body"].update({"model": (response.llm_output or {}).get("model_name", "")}) # type: ignore[union-attr] # noqa: E501 response_obj["body"].update({"system_fingerprint": (response.llm_output or {}).get("system_fingerprint", "")}) # type: ignore[union-attr] # noqa: E501 - # response["time"] = int(responseTime * 1000) response_obj["headers"] = {} - # response["streamingMode"] = streamingMode return response_obj - except Exception as e: - print("on_llm_end_transformer: Error: ", e) + except Exception: return {"response": response, "kwargs": kwargs} def on_chain_start_transformer(self, serialized, input, kwargs): try: - # print("on_chain_start_transformer: serialized: ", serialized) - # print("on_chain_start_transformer: input: ", input) - # print("on_chain_start_transformer: kwargs: ", kwargs) name = kwargs["name"] return { "name": name, "input": json.dumps(input), } - except Exception as e: - print("Error in on_chain_start_transformer:", str(e)) + except Exception: return {"serialized": serialized, "input": input, "kwargs": kwargs} def on_chain_end_transformer(self, output): @@ -657,16 +513,13 @@ def on_chain_end_transformer(self, output): return {"output": structured_data} - except Exception as e: - print("Error in on_chain_end_transformer:", str(e)) + except Exception: return {"output": output} def on_text_transformer(self, text): try: - # print("on_text_transformer: text: ", text) return {"text": text} - except Exception as e: - print("Error in on_text_transformer:", str(e)) + except Exception: return {"text": text} def on_chat_model_start_transformer(self, serialized, messages, kwargs): @@ -680,8 +533,7 @@ def on_chat_model_start_transformer(self, serialized, messages, kwargs): "invocation_params": invocation_params, "messages": message_obj, } - except Exception as e: - print("Error in on_chat_model_start_transformer:", str(e)) + except Exception: return {"serialized": serialized, "messages": message_obj, "kwargs": kwargs} def on_agent_action_transformer(self, action): @@ -691,173 +543,26 @@ def on_agent_action_transformer(self, action): tool_input = action["tool_input"] log = action["log"] return {"tool": tool, "tool_input": tool_input, "log": log} - except Exception as e: - print("Error in on_agent_action_transformer:", str(e)) + except Exception: return {"action": action} def on_agent_finish_transformer(self, finish): try: - # print("on_agent_finish_transformer: finish: ", finish) finish = self.serialize(finish) return_values = finish["return_values"] log = finish["log"] return {"return_values": return_values, "log": log} - except Exception as e: - print("Error in on_agent_finish_transformer:", str(e)) + except Exception: return {"finish": finish} def on_tool_start_transformer(self, serialized, input_str, inputs): try: return {"serialized": serialized, "input_str": input_str, "inputs": inputs} - except Exception as e: - print("Error in on_tool_start_transformer: ", str(e)) + except Exception: return {"serialized": serialized, "input_str": input_str, "inputs": inputs} def on_tool_end_transformer(self, output): try: return {"output": output} - except Exception as e: - print("Error in on_tool_end_transformer: ", str(e)) + except Exception: return {"output": output} - - -# flake8: noqa: E501 -''' - - # def on_llm_start( - # self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any - # ) -> None: - # print("on_llm_start: ") - # # for prompt in prompts: - # # 
messages = prompt.split("\n") - # # for message in messages: - # # role, content = message.split(":", 1) - # # self.prompt_records.append( - # # {"role": role.lower(), "content": content.strip()} - # # ) - - # # self.startTimestamp = float(datetime.now().timestamp()) - - # # self.streamingMode = kwargs.get("invocation_params", False).get("stream", False) # noqa: E501 - - # # self.request["method"] = "POST" - # # self.request["url"] = serialized.get("kwargs", "").get( - # # "base_url", "chat/completions" - # # ) - # # self.request["provider"] = serialized["id"][2] - # # self.request["headers"] = serialized.get("kwargs", {}).get( - # # "default_headers", {} - # # ) - # # self.request["headers"].update({"provider": serialized["id"][2]}) - # # self.request["body"] = {"messages": self.prompt_records} - # # self.request["body"].update({**kwargs.get("invocation_params", {})}) - # pass - - # def on_chain_start( - # self, - # serialized: Dict[str, Any], - # inputs: Dict[str, Any], - # **kwargs: Any, - # ) -> None: - # """Run when chain starts running.""" - # pass - - # def on_llm_end(self, response: Any, **kwargs: Any) -> None: - # # self.endTimestamp = float(datetime.now().timestamp()) - # # responseTime = self.endTimestamp - self.startTimestamp - - # # usage = (response.llm_output or {}).get("token_usage", "") # type: ignore[union-attr] # noqa: E501 - - # # self.response["status"] = ( - # # 200 if self.responseStatus == 0 else self.responseStatus - # # ) - # # self.response["body"] = { - # # "choices": [ - # # { - # # "index": 0, - # # "message": { - # # "role": "assistant", - # # "content": response.generations[0][0].text, - # # }, - # # "logprobs": response.generations[0][0].generation_info.get("logprobs", ""), # type: ignore[union-attr] # noqa: E501 - # # "finish_reason": response.generations[0][0].generation_info.get("finish_reason", ""), # type: ignore[union-attr] # noqa: E501 - # # } - # # ] - # # } - # # self.response["body"].update({"usage": usage}) - # # self.response["body"].update({"id": str(kwargs.get("run_id", ""))}) - # # self.response["body"].update({"created": int(time.time())}) - # # self.response["body"].update({"model": (response.llm_output or {}).get("model_name", "")}) # type: ignore[union-attr] # noqa: E501 - # # self.response["body"].update({"system_fingerprint": (response.llm_output or {}).get("system_fingerprint", "")}) # type: ignore[union-attr] # noqa: E501 - # # self.response["time"] = int(responseTime * 1000) - # # self.response["headers"] = {} - # # self.response["streamingMode"] = self.streamingMode - - # # self.log_object.update( - # # { - # # "request": self.request, - # # "response": self.response, - # # } - # # ) - - # # self.portkey_logger.log(log_object=self.log_object) - # pass - - # def on_chain_end( - # self, - # outputs: Dict[str, Any], - # **kwargs: Any, - # ) -> None: - # """Run when chain ends running.""" - # pass - - # def on_chain_error(self, error: BaseException, **kwargs: Any) -> None: - # # self.responseBody = error - # # self.responseStatus = error.status_code # type: ignore[attr-defined] - # """Do nothing.""" - # pass - - # def on_llm_error(self, error: BaseException, **kwargs: Any) -> None: - # # self.responseBody = error - # # self.responseStatus = error.status_code # type: ignore[attr-defined] - # """Do nothing.""" - # pass - - # def on_tool_error(self, error: BaseException, **kwargs: Any) -> None: - # # self.responseBody = error - # # self.responseStatus = error.status_code # type: ignore[attr-defined] - # pass - - # def on_text(self, text: 
str, **kwargs: Any) -> None:
-    #     pass
-
-    # def on_agent_finish(self, finish: Any, **kwargs: Any) -> None:
-    #     pass
-
-    # def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
-    #     # self.streamingMode = True
-    #     """Do nothing."""
-    #     pass
-
-    # def on_tool_start(
-    #     self,
-    #     serialized: Dict[str, Any],
-    #     input_str: str,
-    #     **kwargs: Any,
-    # ) -> None:
-    #     pass
-
-    # def on_agent_action(self, action: Any, **kwargs: Any) -> Any:
-    #     """Do nothing."""
-    #     pass
-
-    # def on_tool_end(
-    #     self,
-    #     output: Any,
-    #     observation_prefix: Optional[str] = None,
-    #     llm_prefix: Optional[str] = None,
-    #     **kwargs: Any,
-    # ) -> None:
-    #     pass
-
-'''

From 550538502bab6fed0bf018bd9827ffd0b4c7ed40 Mon Sep 17 00:00:00 2001
From: csgulati09
Date: Mon, 12 Aug 2024 11:32:57 +0530
Subject: [PATCH 6/8] fix: add global trace id for llamaindex

---
 portkey_ai/llms/llama_index/portkey_llama_callback.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/portkey_ai/llms/llama_index/portkey_llama_callback.py b/portkey_ai/llms/llama_index/portkey_llama_callback.py
index f5a47009..dada301e 100644
--- a/portkey_ai/llms/llama_index/portkey_llama_callback.py
+++ b/portkey_ai/llms/llama_index/portkey_llama_callback.py
@@ -58,6 +58,7 @@ def __init__(
         self.request: Any = {}
         self.response: Any = {}
 
+        self.global_trace_id: str = ""
         self.streamingMode: bool = False
 
         self.event_map: Any = {}

From 6693afed4863b4c35e944d553da01f9ed7ea3184 Mon Sep 17 00:00:00 2001
From: csgulati09
Date: Mon, 12 Aug 2024 13:35:26 +0530
Subject: [PATCH 7/8] feat: metadata accepted for filtering

---
 .../langchain/portkey_langchain_callback.py   | 26 ++++++++++++++++---
 .../llama_index/portkey_llama_callback.py     |  6 ++++-
 2 files changed, 27 insertions(+), 5 deletions(-)

diff --git a/portkey_ai/llms/langchain/portkey_langchain_callback.py b/portkey_ai/llms/langchain/portkey_langchain_callback.py
index 2c521b7a..6b375cf1 100644
--- a/portkey_ai/llms/langchain/portkey_langchain_callback.py
+++ b/portkey_ai/llms/langchain/portkey_langchain_callback.py
@@ -17,10 +17,12 @@ class PortkeyLangchain(BaseCallbackHandler):
     def __init__(
         self,
         api_key: str,
+        metadata: Optional[Dict[str, Any]] = {},
     ) -> None:
         super().__init__()
 
         self.api_key = api_key
+        self.metadata = metadata
 
         self.portkey_logger = Logger(api_key=api_key)
 
@@ -69,6 +71,7 @@ def on_llm_start(
             "llm_start",
             self.global_trace_id,
             request_payload,
+            self.metadata,
         )
         self.event_map["llm_start_" + str(run_id)] = info_obj
         pass
@@ -101,6 +104,7 @@ def on_chat_model_start(
             "chat_model_start",
             self.global_trace_id,
             request_payload,
+            self.metadata,
         )
         self.event_map["chat_model_start_" + str(run_id)] = info_obj
         self.event_array.append(self.event_map["chat_model_start_" + str(run_id)])
@@ -147,9 +151,8 @@ def on_chain_start(
         """Run when chain starts running."""
 
         if parent_run_id is None:
-            self.global_trace_id = str(uuid4())
+            self.global_trace_id = self.metadata.get("traceId", str(uuid4()))  # type: ignore [union-attr]
             self.main_span_id = ""
-
         parent_span_id = (
             self.main_span_id if parent_run_id is None else str(parent_run_id)
         )
@@ -163,6 +166,7 @@ def on_chain_start(
             "chain_start",
             self.global_trace_id,
             request_payload,
+            self.metadata,
         )
         self.event_map["chain_start_" + str(run_id)] = info_obj
 
@@ -221,6 +225,7 @@ def on_tool_start(
             "tool_start",
             self.global_trace_id,
             request_payload,
+            self.metadata,
         )
         self.event_map["tool_start_" + str(run_id)] = info_obj
         pass
@@ -264,7 +269,12 @@ def on_text(  # Do we need to log this or not? 
This is just formatting of the te ) request_payload = self.on_text_transformer(text) info_obj = self.start_event_information( - run_id, parent_span_id, "text", self.global_trace_id, request_payload + run_id, + parent_span_id, + "text", + self.global_trace_id, + request_payload, + self.metadata, ) self.event_map["text_" + str(run_id)] = info_obj self.event_array.append(self.event_map["text_" + str(run_id)]) @@ -291,6 +301,7 @@ def on_agent_action( "agent_action", self.global_trace_id, request_payload, + self.metadata, ) self.event_map["agent_action_" + str(run_id)] = info_obj pass @@ -347,7 +358,13 @@ def on_retriever_end( # -------------- Helpers ------------------------------------------ def start_event_information( - self, span_id, parent_span_id, span_name, trace_id, request_payload + self, + span_id, + parent_span_id, + span_name, + trace_id, + request_payload, + metadata=None, ): start_time = int(datetime.now().timestamp()) return { @@ -357,6 +374,7 @@ def start_event_information( "trace_id": trace_id, "request": request_payload, "start_time": start_time, + "metadata": metadata, } def serialize(self, obj): diff --git a/portkey_ai/llms/llama_index/portkey_llama_callback.py b/portkey_ai/llms/llama_index/portkey_llama_callback.py index dada301e..0d5c77a0 100644 --- a/portkey_ai/llms/llama_index/portkey_llama_callback.py +++ b/portkey_ai/llms/llama_index/portkey_llama_callback.py @@ -27,6 +27,7 @@ class PortkeyLlamaindex(LlamaIndexBaseCallbackHandler): def __init__( self, api_key: str, + metadata: Optional[Dict[str, Any]] = {}, ) -> None: super().__init__( event_starts_to_ignore=[ @@ -44,6 +45,7 @@ def __init__( ) self.api_key = api_key + self.metadata = metadata self.portkey_logger = Logger(api_key=api_key) @@ -110,6 +112,7 @@ def on_event_start( # type: ignore "trace_id": self.global_trace_id, "request": request_payload, "start_time": start_time, + "metadata": self.metadata, } self.event_map[span_id] = start_event_information @@ -141,13 +144,14 @@ def on_event_end( response_payload = payload self.event_map[span_id]["response"] = response_payload + self.event_array.append(self.event_map[span_id]) def start_trace(self, trace_id: Optional[str] = None) -> None: """Run when an overall trace is launched.""" if trace_id == "index_construction": - self.global_trace_id = str(uuid4()) + self.global_trace_id = self.metadata.get("traceId", str(uuid4())) # type: ignore [union-attr] self.main_span_id = "" From fefdf04fb037444fde449c06babf6eba8486e2eb Mon Sep 17 00:00:00 2001 From: csgulati09 Date: Mon, 12 Aug 2024 13:53:28 +0530 Subject: [PATCH 8/8] fix: clean up --- portkey_ai/llms/langchain/portkey_langchain_callback.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/portkey_ai/llms/langchain/portkey_langchain_callback.py b/portkey_ai/llms/langchain/portkey_langchain_callback.py index 6b375cf1..c45864f9 100644 --- a/portkey_ai/llms/langchain/portkey_langchain_callback.py +++ b/portkey_ai/llms/langchain/portkey_langchain_callback.py @@ -253,7 +253,7 @@ def on_tool_end( self.event_array.append(self.event_map["tool_start_" + str(run_id)]) pass - def on_text( # Do we need to log this or not? 
This is just formatting of the text + def on_text( self, text: str, *, @@ -321,7 +321,6 @@ def on_agent_finish( end_time = int(datetime.now().timestamp()) total_time = (end_time - start_time) * 1000 - (self.main_span_id if parent_run_id is None else str(parent_run_id)) response_payload = self.on_agent_finish_transformer(finish) self.event_map["agent_action_" + str(run_id)]["response"] = response_payload self.event_map["agent_action_" + str(run_id)]["response"][