langchain llamaindex callback handler #127
Merged
33 commits
ae286fb
feat: logger function for langchain
csgulati09 25dcec9
feat: basic llama index callbackHandler template
csgulati09 2d83b4f
feat: basic llama index callbackHandler template
csgulati09 cf33db1
feat: updated req res object for logging
csgulati09 cd543c3
Merge branch 'main' into feat/langchainCallbackHandler
csgulati09 4fe6eaf
feat: clean up for langchain and logger
csgulati09 77f0340
feat: llama index callback handler
csgulati09 c361a54
Merge branch 'main' into feat/langchainCallbackHandler
csgulati09 e2d39ea
feat: update langchain and llamaindex handlers + logger file
csgulati09 2a8256a
fix: linting issues + code clean up
csgulati09 fbbfdac
fix: llamaindex init file
csgulati09 bab40e3
fix: base url for langchain
csgulati09 1b3f158
Merge branch 'main' into feat/langchainCallbackHandler
csgulati09 2904274
fix:linting issues
csgulati09 dcc6b6e
Merge branch 'main' into feat/langchainCallbackHandler
csgulati09 627e754
feat: logger for langchain and llamaindex
csgulati09 4ff1b08
fix: linitng issues
csgulati09 ded6324
fix: file structure for callbackhanders
csgulati09 deadd8b
feat: test cases for langchain and llamaindex
csgulati09 7a35942
fix: linting issues
csgulati09 046b69a
Merge branch 'main' into feat/langchainCallbackHandler
csgulati09 57e3755
fix: models.json for tests and base url to prod
csgulati09 a11eced
fix: linting issues + conditional import
csgulati09 983b646
fix: extra dependency for conditional import
csgulati09 614cd58
fix: tested conditional import + init files fixed
csgulati09 8864bac
fix: token count for llamaindex
csgulati09 57da8e2
fix: restructuring setup.cfg
csgulati09 3940880
Merge branch 'main' into feat/langchainCallbackHandler
csgulati09 5d8552e
fix: import statement for llm test cases
csgulati09 244b7fe
feat: prompt tokens for llamaindex
csgulati09 7ebd68b
Merge branch 'main' into feat/langchainCallbackHandler
csgulati09 84483b1
fix: type + make file command
csgulati09 9d973b7
Merge branch 'main' into feat/langchainCallbackHandler
csgulati09
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
import json | ||
import os | ||
from typing import Optional | ||
import requests | ||
|
||
from portkey_ai.api_resources.global_constants import PORTKEY_BASE_URL | ||
|
||
|
||
class Logger: | ||
def __init__( | ||
self, | ||
api_key: Optional[str] = None, | ||
) -> None: | ||
api_key = api_key or os.getenv("PORTKEY_API_KEY") | ||
if api_key is None: | ||
raise ValueError("API key is required to use the Logger API") | ||
|
||
self.headers = { | ||
"Content-Type": "application/json", | ||
"x-portkey-api-key": api_key, | ||
} | ||
|
||
self.url = PORTKEY_BASE_URL + "/logs" | ||
|
||
def log( | ||
self, | ||
log_object: dict, | ||
): | ||
response = requests.post( | ||
url=self.url, data=json.dumps(log_object), headers=self.headers | ||
) | ||
|
||
return response |
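To make the shape of the logging request concrete, here is a minimal sketch that assembles the same URL, headers, and JSON body as `Logger.log` without touching the network. `build_log_request` is a hypothetical helper, and the `base_url` default is a placeholder: the real value comes from `PORTKEY_BASE_URL`, which this diff does not show.

```python
import json
import os
from typing import Any, Dict, Optional, Tuple


def build_log_request(
    log_object: Dict[str, Any],
    api_key: Optional[str] = None,
    # Placeholder only; the real base URL is PORTKEY_BASE_URL (not shown in the diff).
    base_url: str = "https://example.invalid/v1",
) -> Tuple[str, Dict[str, str], str]:
    """Return (url, headers, body) mirroring what Logger.log would POST."""
    # Same fallback as Logger.__init__: explicit key first, then the environment.
    api_key = api_key or os.getenv("PORTKEY_API_KEY")
    if api_key is None:
        raise ValueError("API key is required to use the Logger API")
    headers = {
        "Content-Type": "application/json",
        "x-portkey-api-key": api_key,
    }
    return base_url + "/logs", headers, json.dumps(log_object)


url, headers, body = build_log_request(
    {"request": {"method": "POST"}, "response": {"status": 200}},
    api_key="test-key",
)
print(url)
```

Separating payload construction from the HTTP call like this would also make the logger easy to unit test, though the merged code keeps both in `log`.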
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
from .chat import ChatPortkey | ||
from .completion import PortkeyLLM | ||
from .portkey_langchain_callback import PortkeyLangchain | ||
|
||
__all__ = ["ChatPortkey", "PortkeyLLM"] | ||
__all__ = ["ChatPortkey", "PortkeyLLM", "PortkeyLangchain"] |
170 changes: 170 additions & 0 deletions
portkey_ai/llms/langchain/portkey_langchain_callback.py
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,170 @@ | ||
from datetime import datetime | ||
import time | ||
from typing import Any, Dict, List, Optional | ||
from portkey_ai.api_resources.apis.logger import Logger | ||
|
||
try: | ||
from langchain_core.callbacks import BaseCallbackHandler | ||
except ImportError: | ||
raise ImportError("Please pip install langchain-core to use PortkeyLangchain") | ||
|
||
|
||
class PortkeyLangchain(BaseCallbackHandler): | ||
def __init__( | ||
self, | ||
api_key: str, | ||
) -> None: | ||
super().__init__() | ||
self.startTimestamp: float = 0 | ||
self.endTimestamp: float = 0 | ||
|
||
self.api_key = api_key | ||
|
||
self.portkey_logger = Logger(api_key=api_key) | ||
|
||
self.log_object: Dict[str, Any] = {} | ||
self.prompt_records: Any = [] | ||
|
||
self.request: Any = {} | ||
self.response: Any = {} | ||
|
||
# self.responseHeaders: Dict[str, Any] = {} | ||
self.responseBody: Any = None | ||
self.responseStatus: int = 0 | ||
|
||
self.streamingMode: bool = False | ||
|
||
if not api_key: | ||
raise ValueError("Please provide an API key to use PortkeyCallbackHandler") | ||
|
||
def on_llm_start( | ||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any | ||
) -> None: | ||
for prompt in prompts: | ||
messages = prompt.split("\n") | ||
for message in messages: | ||
role, content = message.split(":", 1) | ||
self.prompt_records.append( | ||
{"role": role.lower(), "content": content.strip()} | ||
) | ||
|
||
self.startTimestamp = float(datetime.now().timestamp()) | ||
|
||
self.streamingMode = kwargs.get("invocation_params", {}).get("stream", False) | ||
|
||
self.request["method"] = "POST" | ||
self.request["url"] = serialized.get("kwargs", {}).get( | ||
"base_url", "chat/completions" | ||
) | ||
self.request["provider"] = serialized["id"][2] | ||
self.request["headers"] = serialized.get("kwargs", {}).get( | ||
"default_headers", {} | ||
) | ||
self.request["headers"].update({"provider": serialized["id"][2]}) | ||
self.request["body"] = {"messages": self.prompt_records} | ||
self.request["body"].update({**kwargs.get("invocation_params", {})}) | ||
|
||
def on_chain_start( | ||
self, | ||
serialized: Dict[str, Any], | ||
inputs: Dict[str, Any], | ||
**kwargs: Any, | ||
) -> None: | ||
"""Run when chain starts running.""" | ||
|
||
def on_llm_end(self, response: Any, **kwargs: Any) -> None: | ||
self.endTimestamp = float(datetime.now().timestamp()) | ||
responseTime = self.endTimestamp - self.startTimestamp | ||
|
||
usage = (response.llm_output or {}).get("token_usage", "") # type: ignore[union-attr] | ||
|
||
self.response["status"] = ( | ||
200 if self.responseStatus == 0 else self.responseStatus | ||
) | ||
self.response["body"] = { | ||
"choices": [ | ||
{ | ||
"index": 0, | ||
"message": { | ||
"role": "assistant", | ||
"content": response.generations[0][0].text, | ||
}, | ||
"logprobs": response.generations[0][0].generation_info.get("logprobs", ""), # type: ignore[union-attr] # noqa: E501 | ||
"finish_reason": response.generations[0][0].generation_info.get("finish_reason", ""), # type: ignore[union-attr] # noqa: E501 | ||
} | ||
] | ||
} | ||
self.response["body"].update({"usage": usage}) | ||
self.response["body"].update({"id": str(kwargs.get("run_id", ""))}) | ||
self.response["body"].update({"created": int(time.time())}) | ||
self.response["body"].update({"model": (response.llm_output or {}).get("model_name", "")}) # type: ignore[union-attr] # noqa: E501 | ||
self.response["body"].update({"system_fingerprint": (response.llm_output or {}).get("system_fingerprint", "")}) # type: ignore[union-attr] # noqa: E501 | ||
self.response["time"] = int(responseTime * 1000) | ||
self.response["headers"] = {} | ||
self.response["streamingMode"] = self.streamingMode | ||
|
||
self.log_object.update( | ||
{ | ||
"request": self.request, | ||
"response": self.response, | ||
} | ||
) | ||
|
||
self.portkey_logger.log(log_object=self.log_object) | ||
|
||
def on_chain_end( | ||
self, | ||
outputs: Dict[str, Any], | ||
**kwargs: Any, | ||
) -> None: | ||
"""Run when chain ends running.""" | ||
pass | ||
|
||
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None: | ||
"""Record the error and its status code for the next log entry.""" | ||
self.responseBody = error | ||
self.responseStatus = error.status_code # type: ignore[attr-defined] | ||
|
||
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None: | ||
"""Record the error and its status code for the next log entry.""" | ||
self.responseBody = error | ||
self.responseStatus = error.status_code # type: ignore[attr-defined] | ||
|
||
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None: | ||
self.responseBody = error | ||
self.responseStatus = error.status_code # type: ignore[attr-defined] | ||
pass | ||
|
||
def on_text(self, text: str, **kwargs: Any) -> None: | ||
pass | ||
|
||
def on_agent_finish(self, finish: Any, **kwargs: Any) -> None: | ||
pass | ||
|
||
def on_llm_new_token(self, token: str, **kwargs: Any) -> None: | ||
"""Mark the run as streaming once a token arrives.""" | ||
self.streamingMode = True | ||
|
||
def on_tool_start( | ||
self, | ||
serialized: Dict[str, Any], | ||
input_str: str, | ||
**kwargs: Any, | ||
) -> None: | ||
pass | ||
|
||
def on_agent_action(self, action: Any, **kwargs: Any) -> Any: | ||
"""Do nothing.""" | ||
pass | ||
|
||
def on_tool_end( | ||
self, | ||
output: Any, | ||
observation_prefix: Optional[str] = None, | ||
llm_prefix: Optional[str] = None, | ||
**kwargs: Any, | ||
) -> None: | ||
pass |
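One fragile spot in `on_llm_start` above is `message.split(":", 1)`, which raises `ValueError` for any prompt line that has no colon (for example, the second line of a multi-line message). A more tolerant parser could look like the sketch below. `parse_prompt` and its `default_role` parameter are hypothetical and not part of this PR, and the "alphabetic prefix before the first colon" rule is a heuristic assumption, so a line such as a URL would still be misread as a role.

```python
from typing import Dict, List


def parse_prompt(prompt: str, default_role: str = "user") -> List[Dict[str, str]]:
    """Split a flattened prompt into role/content records.

    Lines without a "Role:" prefix are appended to the previous record
    (or start a record with default_role) instead of raising ValueError.
    """
    records: List[Dict[str, str]] = []
    for line in prompt.split("\n"):
        role, sep, content = line.partition(":")
        if sep and role.strip().isalpha():
            # Looks like "Role: content"; start a new record.
            records.append({"role": role.strip().lower(), "content": content.strip()})
        elif records:
            # Continuation of the previous message.
            records[-1]["content"] += "\n" + line
        elif line.strip():
            # No role seen yet; fall back to the default role.
            records.append({"role": default_role, "content": line.strip()})
    return records


print(parse_prompt("System: be brief\nHuman: What is 2+2?\nshow work"))
```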
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,3 @@ | ||
from .completions import PortkeyLLM | ||
from .portkey_llama_callback import PortkeyLlamaindex | ||
|
||
__all__ = ["PortkeyLLM"] | ||
__all__ = ["PortkeyLlamaindex"] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
import time | ||
from typing import Any, Dict, List, Optional | ||
from portkey_ai.api_resources.apis.logger import Logger | ||
from datetime import datetime | ||
|
||
try: | ||
from llama_index.core.callbacks.base_handler import ( | ||
BaseCallbackHandler as LlamaIndexBaseCallbackHandler, | ||
) | ||
from llama_index.core.utilities.token_counting import TokenCounter | ||
except ModuleNotFoundError: | ||
raise ModuleNotFoundError( | ||
"Please install llama-index to use Portkey Callback Handler" | ||
) | ||
except ImportError: | ||
raise ImportError("Please pip install llama-index to use Portkey Callback Handler") | ||
|
||
|
||
class PortkeyLlamaindex(LlamaIndexBaseCallbackHandler): | ||
startTimestamp: float = 0 | ||
endTimestamp: float = 0 | ||
|
||
def __init__( | ||
self, | ||
api_key: str, | ||
) -> None: | ||
super().__init__( | ||
event_starts_to_ignore=[], | ||
event_ends_to_ignore=[], | ||
) | ||
|
||
self.api_key = api_key | ||
|
||
self.portkey_logger = Logger(api_key=api_key) | ||
|
||
self._token_counter = TokenCounter() | ||
self.completion_tokens = 0 | ||
self.prompt_tokens = 0 | ||
self.token_llm = 0 | ||
|
||
self.log_object: Dict[str, Any] = {} | ||
self.prompt_records: Any = [] | ||
|
||
self.request: Any = {} | ||
self.response: Any = {} | ||
|
||
self.responseTime: int = 0 | ||
self.streamingMode: bool = False | ||
|
||
if not api_key: | ||
raise ValueError("Please provide an API key to use PortkeyCallbackHandler") | ||
|
||
def on_event_start( # type: ignore[return] | ||
self, | ||
event_type: Any, | ||
payload: Optional[Dict[str, Any]] = None, | ||
event_id: str = "", | ||
parent_id: str = "", | ||
**kwargs: Any, | ||
) -> str: | ||
"""Run when an event starts and return id of event.""" | ||
|
||
if event_type == "llm": | ||
self.llm_event_start(payload) | ||
|
||
def on_event_end( | ||
self, | ||
event_type: Any, | ||
payload: Optional[Dict[str, Any]] = None, | ||
event_id: str = "", | ||
**kwargs: Any, | ||
) -> None: | ||
"""Run when an event ends.""" | ||
|
||
if event_type == "llm": | ||
self.llm_event_stop(payload, event_id) | ||
|
||
def start_trace(self, trace_id: Optional[str] = None) -> None: | ||
"""Run when an overall trace is launched.""" | ||
self.startTimestamp = float(datetime.now().timestamp()) | ||
|
||
def end_trace( | ||
self, | ||
trace_id: Optional[str] = None, | ||
trace_map: Optional[Dict[str, List[str]]] = None, | ||
) -> None: | ||
"""Run when an overall trace is exited.""" | ||
|
||
def llm_event_start(self, payload: Any) -> None: | ||
if "messages" in payload: | ||
chunks = payload.get("messages", {}) | ||
self.prompt_tokens = self._token_counter.estimate_tokens_in_messages(chunks) | ||
messages = payload.get("messages", {}) | ||
self.prompt_records = [ | ||
{"role": m.role.value, "content": m.content} for m in messages | ||
] | ||
self.request["method"] = "POST" | ||
self.request["url"] = payload.get("serialized", {}).get( | ||
"api_base", "chat/completions" | ||
) | ||
self.request["provider"] = payload.get("serialized", {}).get("class_name", "") | ||
self.request["headers"] = {} | ||
self.request["body"] = {"messages": self.prompt_records} | ||
self.request["body"].update( | ||
{"model": payload.get("serialized", {}).get("model", "")} | ||
) | ||
self.request["body"].update( | ||
{"temperature": payload.get("serialized", {}).get("temperature", "")} | ||
) | ||
|
||
return None | ||
|
||
def llm_event_stop(self, payload: Any, event_id) -> None: | ||
self.endTimestamp = float(datetime.now().timestamp()) | ||
responseTime = self.endTimestamp - self.startTimestamp | ||
|
||
data = payload.get("response", {}) | ||
|
||
chunks = payload.get("messages", {}) | ||
self.completion_tokens = self._token_counter.estimate_tokens_in_messages(chunks) | ||
self.token_llm = self.prompt_tokens + self.completion_tokens | ||
self.response["status"] = 200 | ||
self.response["body"] = { | ||
"choices": [ | ||
{ | ||
"index": 0, | ||
"message": { | ||
"role": data.message.role.value, | ||
"content": data.message.content, | ||
}, | ||
"logprobs": data.logprobs, | ||
"finish_reason": "done", | ||
} | ||
] | ||
} | ||
self.response["body"].update( | ||
{ | ||
"usage": { | ||
"prompt_tokens": self.prompt_tokens, | ||
"completion_tokens": self.completion_tokens, | ||
"total_tokens": self.token_llm, | ||
} | ||
} | ||
) | ||
self.response["body"].update({"id": event_id}) | ||
self.response["body"].update({"created": int(time.time())}) | ||
self.response["body"].update({"model": data.raw.get("model", "")}) | ||
self.response["time"] = int(responseTime * 1000) | ||
self.response["headers"] = {} | ||
self.response["streamingMode"] = self.streamingMode | ||
|
||
self.log_object.update( | ||
{ | ||
"request": self.request, | ||
"response": self.response, | ||
} | ||
) | ||
self.portkey_logger.log(log_object=self.log_object) | ||
|
||
return None |
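Both handlers in this PR assemble the same OpenAI-style response body by hand: a single `choices` entry plus `usage`, `id`, `created`, and `model`. As an illustration only, the shared shape could be factored into one function. `build_response_body` is a hypothetical name and is not part of the merged code.

```python
import time
from typing import Any, Dict, Optional


def build_response_body(
    content: str,
    run_id: str,
    model: str = "",
    usage: Optional[Dict[str, int]] = None,
    role: str = "assistant",
    finish_reason: str = "",
    logprobs: Any = None,
) -> Dict[str, Any]:
    """Build the response body that both callback handlers log."""
    return {
        "choices": [
            {
                "index": 0,
                "message": {"role": role, "content": content},
                "logprobs": logprobs,
                "finish_reason": finish_reason,
            }
        ],
        "usage": usage or {},
        "id": run_id,
        "created": int(time.time()),  # Unix time in whole seconds
        "model": model,
    }


body = build_response_body(
    "Hello!",
    run_id="run-1",
    model="example-model",
    usage={"prompt_tokens": 3, "completion_tokens": 2, "total_tokens": 5},
)
print(body["choices"][0]["message"])
```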
What is the unit of responseTime here? seconds or ms?
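On the units: in both handlers `responseTime` is the difference of two `datetime.now().timestamp()` values, so it is in float seconds, while the logged `response["time"]` is `int(responseTime * 1000)`, which is whole milliseconds. A small sketch of that conversion follows; `elapsed_ms` is a hypothetical helper name used only for illustration.

```python
from datetime import datetime


def elapsed_ms(start_ts: float, end_ts: float) -> int:
    """Convert a difference of POSIX timestamps (seconds) to whole milliseconds."""
    return int((end_ts - start_ts) * 1000)


start = datetime.now().timestamp()  # float seconds since the epoch
end = datetime.now().timestamp()
print(elapsed_ms(start, end))
```

Keeping both timestamps as floats matters here: truncating either one to an integer number of seconds before subtracting can skew the result by up to a second.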