Fix/prompt response #110

Merged: 7 commits, Apr 5, 2024
portkey_ai/api_resources/apis/generation.py: 27 additions & 20 deletions
@@ -2,6 +2,11 @@
 import warnings
 from typing import Literal, Optional, Union, Mapping, Any, overload
 from portkey_ai.api_resources.base_client import APIClient, AsyncAPIClient
+from portkey_ai.api_resources.types.generation_type import (
+    PromptCompletion,
+    PromptCompletionChunk,
+    PromptRender,
+)
 from portkey_ai.api_resources.utils import (
     retrieve_config,
     GenericResponse,
@@ -88,14 +93,14 @@ def render(
         self,
         *,
         prompt_id: str,
-        variables: Optional[Mapping[str, Any]] = None,
+        variables: Mapping[str, Any],
         stream: bool = False,
         temperature: Optional[float] = None,
         max_tokens: Optional[int] = None,
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
         **kwargs,
-    ) -> GenericResponse:
+    ) -> PromptRender:
         """Prompt render Method"""
         body = {
             "variables": variables,
@@ -110,8 +115,8 @@ def render(
             f"/prompts/{prompt_id}/render",
             body=body,
             params=None,
-            cast_to=GenericResponse,
-            stream_cls=Stream[GenericResponse],
+            cast_to=PromptRender,
+            stream_cls=Stream[PromptRender],
             stream=False,
             headers={},
         )
@@ -128,29 +133,31 @@ async def render(
         self,
         *,
         prompt_id: str,
-        variables: Optional[Mapping[str, Any]] = None,
+        variables: Mapping[str, Any],
         stream: bool = False,
         temperature: Optional[float] = None,
         max_tokens: Optional[int] = None,
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
         **kwargs,
-    ) -> GenericResponse:
+    ) -> PromptRender:
         """Prompt render Method"""
         body = {
             "variables": variables,
             "temperature": temperature,
             "max_tokens": max_tokens,
             "top_k": top_k,
             "top_p": top_p,
             "stream": stream,
             **kwargs,
         }
         return await self._post(
             f"/prompts/{prompt_id}/render",
             body=body,
             params=None,
-            cast_to=GenericResponse,
+            cast_to=PromptRender,
             stream=False,
-            stream_cls=AsyncStream[GenericResponse],
+            stream_cls=AsyncStream[PromptRender],
             headers={},
         )

@@ -172,7 +179,7 @@ def create(
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
         **kwargs,
-    ) -> Stream[GenericResponse]:
+    ) -> Stream[PromptCompletionChunk]:
         ...

     @overload
@@ -188,7 +195,7 @@ def create(
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
         **kwargs,
-    ) -> GenericResponse:
+    ) -> PromptCompletion:
         ...

     @overload
@@ -204,7 +211,7 @@ def create(
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
         **kwargs,
-    ) -> Union[GenericResponse, Stream[GenericResponse]]:
+    ) -> Union[PromptCompletion, Stream[PromptCompletionChunk]]:
         ...

     def create(
@@ -219,7 +226,7 @@ def create(
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
         **kwargs,
-    ) -> Union[GenericResponse, Stream[GenericResponse]]:
+    ) -> Union[PromptCompletion, Stream[PromptCompletionChunk],]:
         """Prompt completions Method"""
         if config is None:
             config = retrieve_config()
@@ -236,8 +243,8 @@ def create(
             f"/prompts/{prompt_id}/completions",
             body=body,
             params=None,
-            cast_to=GenericResponse,
-            stream_cls=Stream[GenericResponse],
+            cast_to=PromptCompletion,
+            stream_cls=Stream[PromptCompletionChunk],
             stream=stream,
             headers={},
         )
@@ -260,7 +267,7 @@ async def create(
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
         **kwargs,
-    ) -> AsyncStream[GenericResponse]:
+    ) -> AsyncStream[PromptCompletionChunk]:
         ...

     @overload
@@ -276,7 +283,7 @@ async def create(
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
         **kwargs,
-    ) -> GenericResponse:
+    ) -> PromptCompletion:
         ...

     @overload
@@ -292,7 +299,7 @@ async def create(
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
         **kwargs,
-    ) -> Union[GenericResponse, AsyncStream[GenericResponse]]:
+    ) -> Union[PromptCompletion, AsyncStream[PromptCompletionChunk]]:
         ...

     async def create(
@@ -307,7 +314,7 @@ async def create(
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
         **kwargs,
-    ) -> Union[GenericResponse, AsyncStream[GenericResponse]]:
+    ) -> Union[PromptCompletion, AsyncStream[PromptCompletionChunk]]:
         """Prompt completions Method"""
         if config is None:
             config = retrieve_config()
@@ -324,8 +331,8 @@ async def create(
             f"/prompts/{prompt_id}/completions",
             body=body,
             params=None,
-            cast_to=GenericResponse,
-            stream_cls=AsyncStream[GenericResponse],
+            cast_to=PromptCompletion,
+            stream_cls=AsyncStream[PromptCompletionChunk],
             stream=stream,
             headers={},
         )
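
Usage note: after this change, prompts.render() requires `variables` and returns a typed PromptRender, and prompts.completions.create() returns a PromptCompletion (or a stream of PromptCompletionChunk) instead of GenericResponse. A minimal caller-side sketch under that reading; the client construction, the client.prompts attribute path, and the prompt ID are illustrative assumptions, not part of this diff:

    from portkey_ai import Portkey

    client = Portkey(api_key="YOUR_PORTKEY_API_KEY")  # illustrative setup

    # render() now returns PromptRender, and `variables` is required
    rendered = client.prompts.render(
        prompt_id="pp-example-123",      # hypothetical prompt ID
        variables={"city": "Paris"},
    )
    print(rendered.data.messages)        # typed PromptRenderData fields

    # create() with stream=False now returns a typed PromptCompletion
    completion = client.prompts.completions.create(
        prompt_id="pp-example-123",
        variables={"city": "Paris"},
    )
    print(completion.choices[0])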
portkey_ai/api_resources/common_types.py: 21 additions & 2 deletions
@@ -1,6 +1,11 @@
 from typing import TypeVar, Union

 import httpx

+from portkey_ai.api_resources.types.generation_type import (
+    PromptCompletionChunk,
+    PromptRender,
+)
 from .streaming import Stream, AsyncStream
 from .utils import GenericResponse
 from .types.chat_complete_type import ChatCompletionChunk
@@ -9,13 +14,27 @@
 StreamT = TypeVar(
     "StreamT",
     bound=Stream[
-        Union[ChatCompletionChunk, TextCompletionChunk, GenericResponse, httpx.Response]
+        Union[
+            ChatCompletionChunk,
+            TextCompletionChunk,
+            GenericResponse,
+            PromptCompletionChunk,
+            PromptRender,
+            httpx.Response,
+        ]
     ],
 )

 AsyncStreamT = TypeVar(
     "AsyncStreamT",
     bound=AsyncStream[
-        Union[ChatCompletionChunk, TextCompletionChunk, GenericResponse, httpx.Response]
+        Union[
+            ChatCompletionChunk,
+            TextCompletionChunk,
+            GenericResponse,
+            PromptCompletionChunk,
+            PromptRender,
+            httpx.Response,
+        ]
     ],
 )
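
Note: widening these TypeVar bounds is what lets the new stream_cls values in generation.py above (Stream[PromptCompletionChunk], AsyncStream[PromptRender]) satisfy StreamT/AsyncStreamT. A hedged sketch of the caller-side effect, with the same illustrative client and prompt ID as in the previous note:

    from portkey_ai import Portkey

    client = Portkey(api_key="YOUR_PORTKEY_API_KEY")  # illustrative setup

    # With stream=True, iteration should now yield typed PromptCompletionChunk items
    stream = client.prompts.completions.create(
        prompt_id="pp-example-123",      # hypothetical prompt ID
        variables={"city": "Paris"},
        stream=True,
    )
    for chunk in stream:
        if chunk.choices:                # choices is Optional on the chunk model
            print(chunk.choices[0])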
portkey_ai/api_resources/types/generation_type.py: 116 additions & 0 deletions (new file)
@@ -0,0 +1,116 @@
+import json
+from typing import Dict, Optional, Union
+import httpx
+
+from portkey_ai.api_resources.types.chat_complete_type import (
+    ChatCompletionMessage,
+    Choice,
+    StreamChoice,
+    Usage,
+)
+from portkey_ai.api_resources.types.complete_type import Logprobs, TextChoice
+
+from .utils import parse_headers
+from typing import List, Any
+from pydantic import BaseModel
+
+
+class PromptCompletion(BaseModel):
+    id: Optional[str]
+    choices: List[Choice]
+    created: Optional[int]
+    model: Optional[str]
+    object: Optional[str]
+    system_fingerprint: Optional[str] = None
+    usage: Optional[Usage] = None
+    index: Optional[int] = None
+    text: Optional[str] = None
+    logprobs: Optional[Logprobs] = None
+    finish_reason: Optional[str] = None
+    _headers: Optional[httpx.Headers] = None
+
+    def __str__(self):
+        return json.dumps(self.dict(), indent=4)
+
+    def __getitem__(self, key):
+        return getattr(self, key, None)
+
+    def get(self, key: str, default: Optional[Any] = None):
+        return getattr(self, key, None) or default
+
+    def get_headers(self) -> Optional[Dict[str, str]]:
+        return parse_headers(self._headers)
+
+
+class PromptCompletionChunk(BaseModel):
+    id: Optional[str] = None
+    object: Optional[str] = None
+    created: Optional[int] = None
+    model: Optional[str] = None
+    provider: Optional[str] = None
+    choices: Optional[Union[List[TextChoice], List[StreamChoice]]]
+
+    def __str__(self):
+        return json.dumps(self.dict(), indent=4)
+
+    def __getitem__(self, key):
+        return getattr(self, key, None)
+
+    def get(self, key: str, default: Optional[Any] = None):
+        return getattr(self, key, None) or default
+
+
+FunctionParameters = Dict[str, object]
+
+
+class Function(BaseModel):
+    name: Optional[str]
+    description: Optional[str] = None
+    parameters: Optional[FunctionParameters] = None
+
+
+class Tool(BaseModel):
+    function: Function
+    type: Optional[str]
+
+
+class PromptRenderData(BaseModel):
+    messages: Optional[List[ChatCompletionMessage]] = None
+    prompt: Optional[str] = None
+    model: Optional[str] = None
+    suffix: Optional[str] = None
+    max_tokens: Optional[int] = None
+    temperature: Optional[float] = None
+    top_k: Optional[int] = None
+    top_p: Optional[float] = None
+    n: Optional[int] = None
+    stop_sequences: Optional[List[str]] = None
+    timeout: Union[float, None] = None
+    functions: Optional[List[Function]] = None
+    function_call: Optional[Union[None, str, Function]] = None
+    logprobs: Optional[bool] = None
+    top_logprobs: Optional[int] = None
+    echo: Optional[bool] = None
+    stop: Optional[Union[str, List[str]]] = None
+    presence_penalty: Optional[int] = None
+    frequency_penalty: Optional[int] = None
+    best_of: Optional[int] = None
+    logit_bias: Optional[Dict[str, int]] = None
+    user: Optional[str] = None
+    organization: Optional[str] = None
+    tool_choice: Optional[Union[None, str]] = None
+    tools: Optional[List[Tool]] = None
+
+
+class PromptRender(BaseModel):
+    success: Optional[bool] = True
+    data: PromptRenderData
+
+    def __str__(self):
+        return json.dumps(self.dict(), indent=4)
+
+    def __getitem__(self, key):
+        return getattr(self, key, None)
+
+    def get(self, key: str, default: Optional[Any] = None):
+        return getattr(self, key, None) or default
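
Note: the __getitem__/get/__str__ helpers keep these Pydantic models compatible with the dict-style access that GenericResponse callers may already rely on. A small sketch against the models defined above (the field values are illustrative):

    from portkey_ai.api_resources.types.generation_type import (
        PromptRender,
        PromptRenderData,
    )

    render_data = PromptRenderData(prompt="Tell me about {{city}}")
    resp = PromptRender(success=True, data=render_data)

    assert resp["data"] is resp.data            # __getitem__ falls back to getattr
    assert resp.get("missing", "n/a") == "n/a"  # get() returns the default
    print(resp)                                 # __str__ pretty-prints via json.dumps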
portkey_ai/api_resources/utils.py: 6 additions & 1 deletion
@@ -15,6 +15,11 @@
     TextCompletionChunk,
     TextCompletion,
 )
+from portkey_ai.api_resources.types.generation_type import (
+    PromptCompletion,
+    PromptCompletionChunk,
+    PromptRender,
+)
 from .exceptions import (
     APIStatusError,
     BadRequestError,
@@ -56,7 +61,7 @@ class CacheType(str, Enum, metaclass=MetaEnum):

 ResponseT = TypeVar(
     "ResponseT",
-    bound="Union[ChatCompletionChunk, ChatCompletions, TextCompletion, TextCompletionChunk, GenericResponse, httpx.Response]", # noqa: E501
+    bound="Union[ChatCompletionChunk, ChatCompletions, TextCompletion, TextCompletionChunk, GenericResponse, PromptCompletion, PromptCompletionChunk, PromptRender, httpx.Response]", # noqa: E501
 )

