Skip to content

Commit

Permalink
fix: updated feedbacks response, added streaming in post and added ov…
Browse files Browse the repository at this point in the history
…er ride function in utils
  • Loading branch information
noble-varghese committed Dec 1, 2023
1 parent ac4b600 commit 9755993
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 13 deletions.
3 changes: 1 addition & 2 deletions portkey_ai/api_resources/apis/feedback.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@ def create(
weight: Optional[float] = None,
metadata: Optional[Dict[str, Any]] = None
) -> None:
body = dict(trace_id=trace_id, value=value,
weight=weight, metadata=metadata)
body = dict(trace_id=trace_id, value=value, weight=weight, metadata=metadata)
return self._post(
PortkeyApiPaths.FEEDBACK_API,
body=body,
Expand Down
35 changes: 33 additions & 2 deletions portkey_ai/api_resources/apis/post.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, Union
from typing import Any, Dict, Union, overload, Literal

from portkey_ai.api_resources.base_client import APIClient

Expand All @@ -10,10 +10,41 @@ class Post(APIResource):
def __init__(self, client: APIClient) -> None:
super().__init__(client)

@overload
def create(
self,
*,
url: str,
stream: Literal[True],
**kwargs,
) -> Stream[Dict[str, Any]]:
...

@overload
def create(
self,
*,
url: str,
stream: Literal[False] = False,
**kwargs,
) -> Dict[str, Any]:
...

@overload
def create(
self,
*,
url: str,
stream: bool = False,
**kwargs,
) -> Union[Dict[str, Any], Stream[Dict[str, Any]]]:
...

def create(
self,
*,
url: str,
stream: bool = False,
**kwargs,
) -> Union[Dict[str, Any], Stream[Dict[str, Any]]]:
return self._post(
Expand All @@ -22,6 +53,6 @@ def create(
params=None,
cast_to=dict,
stream_cls=Stream[dict],
stream=False,
stream=stream,
headers={},
)
18 changes: 9 additions & 9 deletions portkey_ai/api_resources/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def __str__(self):


# Models for Chat Stream
class Delta(BaseModel):
class Delta(BaseModel, extra="allow"):
role: Optional[str] = None
content: Optional[str] = ""

Expand All @@ -272,7 +272,7 @@ def get(self, key: str, default: Optional[Any] = None):
return getattr(self, key, None) or default


class StreamChoice(BaseModel):
class StreamChoice(BaseModel, extra="allow"):
index: Optional[int] = None
delta: Union[Delta, Dict[Any, Any]] = {}
finish_reason: Optional[str] = None
Expand All @@ -287,7 +287,7 @@ def __getitem__(self, key):
return getattr(self, key, None)


class ChatCompletionChunk(BaseModel):
class ChatCompletionChunk(BaseModel, extra="allow"):
id: Optional[str] = None
object: Optional[str] = None
created: Optional[int] = None
Expand All @@ -305,7 +305,7 @@ def get(self, key: str, default: Optional[Any] = None):


# Models for Chat Non-stream
class ChatChoice(BaseModel):
class ChatChoice(BaseModel, extra="allow"):
index: Optional[int] = None
message: Optional[Message] = None
finish_reason: Optional[str] = None
Expand All @@ -320,7 +320,7 @@ def get(self, key: str, default: Optional[Any] = None):
return getattr(self, key, None) or default


class Usage(BaseModel):
class Usage(BaseModel, extra="allow"):
prompt_tokens: Optional[int] = None
completion_tokens: Optional[int] = None
total_tokens: Optional[int] = None
Expand All @@ -335,7 +335,7 @@ def get(self, key: str, default: Optional[Any] = None):
return getattr(self, key, None) or default


class ChatCompletions(BaseModel):
class ChatCompletions(BaseModel, extra="allow"):
id: Optional[str] = None
object: Optional[str] = None
created: Optional[int] = None
Expand All @@ -354,7 +354,7 @@ def get(self, key: str, default: Optional[Any] = None):


# Models for text completion Non-stream
class TextChoice(BaseModel):
class TextChoice(BaseModel, extra="allow"):
index: Optional[int] = None
text: Optional[str] = None
logprobs: Any
Expand All @@ -370,7 +370,7 @@ def get(self, key: str, default: Optional[Any] = None):
return getattr(self, key, None) or default


class TextCompletion(BaseModel):
class TextCompletion(BaseModel, extra="allow"):
id: Optional[str] = None
object: Optional[str] = None
created: Optional[int] = None
Expand All @@ -389,7 +389,7 @@ def get(self, key: str, default: Optional[Any] = None):


# Models for text completion stream
class TextCompletionChunk(BaseModel):
class TextCompletionChunk(BaseModel, extra="allow"):
id: Optional[str] = None
object: Optional[str] = None
created: Optional[int] = None
Expand Down

0 comments on commit 9755993

Please sign in to comment.