Commit
Merge pull request #200 from Portkey-AI/fix/llamaIndexPayload
Fix: handle a None-type LlamaIndex payload
VisargD authored Aug 13, 2024
2 parents 25e4436 + 5165a6f commit db9e921
Showing 8 changed files with 41 additions and 28 deletions.
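In summary, the commit does two things: it renames the callback handlers (PortkeyLangchain → LangchainCallbackHandler, PortkeyLlamaindex → LlamaIndexCallbackHandler) and moves them to new top-level packages, and it makes the LlamaIndex handler tolerant of a None event payload. A minimal sketch of the new import paths, taken from the __init__.py diffs below; the api_key constructor argument is visible in the diffs, and any further arguments are assumed optional here:

# New public import paths introduced by this commit.
from portkey_ai.langchain import LangchainCallbackHandler
from portkey_ai.llamaindex import LlamaIndexCallbackHandler

# Hypothetical usage; "YOUR_PORTKEY_API_KEY" is a placeholder, and
# constructor arguments beyond api_key are not shown in this diff.
langchain_handler = LangchainCallbackHandler(api_key="YOUR_PORTKEY_API_KEY")
llamaindex_handler = LlamaIndexCallbackHandler(api_key="YOUR_PORTKEY_API_KEY")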
3 changes: 3 additions & 0 deletions portkey_ai/langchain/__init__.py
@@ -0,0 +1,3 @@
+from .portkey_langchain_callback_handler import LangchainCallbackHandler
+
+__all__ = ["LangchainCallbackHandler"]
portkey_ai/langchain/portkey_langchain_callback_handler.py
@@ -13,7 +13,7 @@
     raise ImportError("Please pip install langchain-core to use PortkeyLangchain")


-class PortkeyLangchain(BaseCallbackHandler):
+class LangchainCallbackHandler(BaseCallbackHandler):
     def __init__(
         self,
         api_key: str,
3 changes: 3 additions & 0 deletions portkey_ai/llamaindex/__init__.py
@@ -0,0 +1,3 @@
+from .portkey_llama_callback_handler import LlamaIndexCallbackHandler
+
+__all__ = ["LlamaIndexCallbackHandler"]
portkey_ai/llamaindex/portkey_llama_callback_handler.py
@@ -23,7 +23,7 @@
     raise ImportError("Please pip install llama-index to use Portkey Callback Handler")


-class PortkeyLlamaindex(LlamaIndexBaseCallbackHandler):
+class LlamaIndexCallbackHandler(LlamaIndexBaseCallbackHandler):
     def __init__(
         self,
         api_key: str,
@@ -126,30 +126,38 @@ def on_event_end(
"""Run when an event ends."""
span_id = event_id

if event_type == "llm":
response_payload = self.llm_event_end(payload, event_id)
elif event_type == "embedding":
response_payload = self.embedding_event_end(payload, event_id)
elif event_type == "agent_step":
response_payload = self.agent_step_event_end(payload, event_id)
elif event_type == "function_call":
response_payload = self.function_call_event_end(payload, event_id)
elif event_type == "query":
response_payload = self.query_event_end(payload, event_id)
elif event_type == "retrieve":
response_payload = self.retrieve_event_end(payload, event_id)
elif event_type == "templating":
response_payload = self.templating_event_end(payload, event_id)
if payload is None:
response_payload = {}
if span_id in self.event_map:
event = self.event_map[event_id]
start_time = event["start_time"]
end_time = int(datetime.now().timestamp())
total_time = (end_time - start_time) * 1000
response_payload["response_time"] = total_time
else:
response_payload = payload
if event_type == "llm":
response_payload = self.llm_event_end(payload, event_id)
elif event_type == "embedding":
response_payload = self.embedding_event_end(payload, event_id)
elif event_type == "agent_step":
response_payload = self.agent_step_event_end(payload, event_id)
elif event_type == "function_call":
response_payload = self.function_call_event_end(payload, event_id)
elif event_type == "query":
response_payload = self.query_event_end(payload, event_id)
elif event_type == "retrieve":
response_payload = self.retrieve_event_end(payload, event_id)
elif event_type == "templating":
response_payload = self.templating_event_end(payload, event_id)
else:
response_payload = payload

self.event_map[span_id]["response"] = response_payload

self.event_array.append(self.event_map[span_id])

def start_trace(self, trace_id: Optional[str] = None) -> None:
"""Run when an overall trace is launched."""

if trace_id == "index_construction":
self.global_trace_id = self.metadata.get("traceId", str(uuid4())) # type: ignore [union-attr]

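This hunk is the substance of the fix: LlamaIndex can call on_event_end with payload=None, and the old code dispatched that None straight into the event-type serializers. The new branch short-circuits and records only the elapsed time for the span. A standalone sketch of the added guard, with the handler's event_map bookkeeping simplified and timestamps assumed to be integer epoch seconds as above:

from datetime import datetime

# Simplified stand-in for the handler's span bookkeeping: start_time is
# stored as integer epoch seconds when the event begins.
event_map = {"span-1": {"start_time": int(datetime.now().timestamp())}}

def build_response_payload(payload, span_id):
    # Mirrors the guard added above: a None payload yields a stub
    # response carrying only the elapsed time in milliseconds.
    if payload is None:
        response_payload = {}
        if span_id in event_map:
            start_time = event_map[span_id]["start_time"]
            end_time = int(datetime.now().timestamp())
            response_payload["response_time"] = (end_time - start_time) * 1000
        return response_payload
    return payload  # the real handler dispatches on event_type here

print(build_response_payload(None, "span-1"))          # e.g. {'response_time': 0}
print(build_response_payload({"ok": True}, "span-1"))  # {'ok': True}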
@@ -230,7 +238,7 @@ def llm_event_end(self, payload: Any, event_id) -> Any:
         )
         self.response["body"].update({"id": event_id})
         self.response["body"].update({"created": int(time.time())})
-        self.response["body"].update({"model": data.raw.get("model", "")})
+        self.response["body"].update({"model": getattr(data, "model", "")})
         self.response["headers"] = {}
         self.response["streamingMode"] = self.streamingMode

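The second hunk hardens the model-name lookup: data.raw.get("model", "") assumes data.raw is always a dict, which fails when raw is None or a non-dict object, whereas getattr(data, "model", "") degrades to an empty string. A small sketch of the difference, using a hypothetical response object:

class FakeChatResponse:
    """Hypothetical stand-in for the `data` object inside llm_event_end."""
    model = "gpt-4o-mini"  # assumed attribute; placeholder value
    raw = None             # the failure mode: raw is not a dict here

data = FakeChatResponse()

# Old lookup: data.raw.get("model", "") raises AttributeError when raw is None.
# New lookup: attribute access with a default never raises.
print(getattr(data, "model", ""))      # -> gpt-4o-mini
print(getattr(object(), "model", ""))  # -> "" (attribute missing)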
6 changes: 4 additions & 2 deletions portkey_ai/llms/langchain/__init__.py
@@ -1,5 +1,7 @@
 from .chat import ChatPortkey
 from .completion import PortkeyLLM
-from .portkey_langchain_callback import PortkeyLangchain

-__all__ = ["ChatPortkey", "PortkeyLLM", "PortkeyLangchain"]
+__all__ = [
+    "ChatPortkey",
+    "PortkeyLLM",
+]
3 changes: 0 additions & 3 deletions portkey_ai/llms/llama_index/__init__.py
@@ -1,3 +0,0 @@
-from .portkey_llama_callback import PortkeyLlamaindex
-
-__all__ = ["PortkeyLlamaindex"]
4 changes: 2 additions & 2 deletions tests/test_llm_langchain.py
@@ -6,7 +6,7 @@
 import pytest

 from tests.utils import read_json_file
-from portkey_ai.llms.langchain import PortkeyLangchain
+from portkey_ai.langchain import LangchainCallbackHandler
 from langchain.chat_models import ChatOpenAI
 from langchain_core.prompts import ChatPromptTemplate
 from langchain.chains import LLMChain
@@ -15,7 +15,7 @@


 class TestLLMLangchain:
-    client = PortkeyLangchain
+    client = LangchainCallbackHandler
     parametrize = pytest.mark.parametrize("client", [client], ids=["strict"])
     models = read_json_file("./tests/models.json")

4 changes: 2 additions & 2 deletions tests/test_llm_llamaindex.py
@@ -6,7 +6,7 @@
 import pytest

 from tests.utils import read_json_file
-from portkey_ai.llms.llama_index import PortkeyLlamaindex
+from portkey_ai.llamaindex import LlamaIndexCallbackHandler


 from llama_index.llms.openai import OpenAI
@@ -24,7 +24,7 @@


 class TestLLMLlamaindex:
-    client = PortkeyLlamaindex
+    client = LlamaIndexCallbackHandler
     parametrize = pytest.mark.parametrize("client", [client], ids=["strict"])
     models = read_json_file("./tests/models.json")

