From 5f0a0ddd5ffcca8b63a4efc692571cfb87abd6bd Mon Sep 17 00:00:00 2001 From: visargD Date: Fri, 12 Jan 2024 18:52:43 +0530 Subject: [PATCH] feat: update classes with latest openai signature --- portkey_ai/api_resources/utils.py | 55 ++++++++++++++++++++++++++++--- 1 file changed, 51 insertions(+), 4 deletions(-) diff --git a/portkey_ai/api_resources/utils.py b/portkey_ai/api_resources/utils.py index a4825b5b..038821b0 100644 --- a/portkey_ai/api_resources/utils.py +++ b/portkey_ai/api_resources/utils.py @@ -116,15 +116,40 @@ class Options(BaseModel): json_body: Optional[Mapping[str, Any]] = None -class Message(TypedDict): +class FunctionCall(BaseModel): + name: str + arguments: str + +class ToolCall(BaseModel): + id: str + function: FunctionCall + type: str + +class DeltaToolCallFunction(BaseModel): + arguments: Optional[str] = None + name: Optional[str] = None + + +class DeltaToolCall(BaseModel): + index: int + id: Optional[str] = None + function: Optional[DeltaToolCallFunction] = None + type: Optional[str] = None + +class Message(BaseModel): role: str - content: str + content: Optional[str] = None + tool_calls: Optional[List[ToolCall]] = None class Function(BaseModel): name: str description: str - parameters: str + parameters: Dict[str, object] + +class Tool(BaseModel): + function: Function + type: str class RetrySettings(TypedDict): @@ -149,7 +174,8 @@ class ModelParams(BaseModel): timeout: Union[float, None] = None functions: Optional[List[Function]] = None function_call: Optional[Union[None, str, Function]] = None - logprobs: Optional[int] = None + logprobs: Optional[bool] = None + top_logprobs: Optional[int] = None echo: Optional[bool] = None stop: Optional[Union[str, List[str]]] = None presence_penalty: Optional[int] = None @@ -158,6 +184,9 @@ class ModelParams(BaseModel): logit_bias: Optional[Dict[str, int]] = None user: Optional[str] = None organization: Optional[str] = None + tool_choice: Optional[Union[None, str]] = None, + tools: Optional[List[Tool]] = None, + stream: Optional[bool] = False class OverrideParams(ModelParams, ConversationInput): @@ -261,6 +290,7 @@ def __str__(self): class Delta(BaseModel, extra="allow"): role: Optional[str] = None content: Optional[str] = "" + tool_calls: Optional[List[DeltaToolCall]] = None def __str__(self): return json.dumps(self.dict(), indent=4) @@ -303,12 +333,29 @@ def __getitem__(self, key): def get(self, key: str, default: Optional[Any] = None): return getattr(self, key, None) or default +class TopLogprob(BaseModel): + token: str + bytes: Optional[List[int]] = None + logprob: float + + +class ChatCompletionTokenLogprob(BaseModel): + token: str + bytes: Optional[List[int]] = None + logprob: float + top_logprobs: List[TopLogprob] + + +class ChoiceLogprobs(BaseModel): + content: Optional[List[ChatCompletionTokenLogprob]] = None + # Models for Chat Non-stream class ChatChoice(BaseModel, extra="allow"): index: Optional[int] = None message: Optional[Message] = None finish_reason: Optional[str] = None + logprobs: Optional[ChoiceLogprobs] = None def __str__(self): return json.dumps(self.dict(), indent=4)