diff --git a/portkey_ai/__init__.py b/portkey_ai/__init__.py index dbd9acc0..b7ed91f3 100644 --- a/portkey_ai/__init__.py +++ b/portkey_ai/__init__.py @@ -12,16 +12,20 @@ PortkeyResponse, ChatCompletions, Completion, + AsyncCompletion, Params, Config, RetrySettings, ChatCompletion, + AsyncChatCompletion, ChatCompletionChunk, TextCompletion, TextCompletionChunk, createHeaders, Prompts, + AsyncPrompts, Portkey, + AsyncPortkey, ) from portkey_ai.version import VERSION from portkey_ai.api_resources.global_constants import ( @@ -49,9 +53,11 @@ "Message", "ChatCompletions", "Completion", + "AsyncCompletion", "Params", "RetrySettings", "ChatCompletion", + "AsyncChatCompletion", "ChatCompletionChunk", "TextCompletion", "TextCompletionChunk", @@ -61,5 +67,7 @@ "PORTKEY_GATEWAY_URL", "createHeaders", "Prompts", + "AsyncPrompts", "Portkey", + "AsyncPortkey", ] diff --git a/portkey_ai/api_resources/__init__.py b/portkey_ai/api_resources/__init__.py index 1f277ed0..0971a5b7 100644 --- a/portkey_ai/api_resources/__init__.py +++ b/portkey_ai/api_resources/__init__.py @@ -1,10 +1,15 @@ """""" from .apis import ( Completion, + AsyncCompletion, ChatCompletion, + AsyncChatCompletion, Generations, + AsyncGenerations, Prompts, + AsyncPrompts, Feedback, + AsyncFeedback, createHeaders, ) from .utils import ( @@ -25,7 +30,7 @@ TextCompletion, TextCompletionChunk, ) -from .client import Portkey +from .client import Portkey, AsyncPortkey from portkey_ai.version import VERSION @@ -42,16 +47,22 @@ "Message", "ChatCompletions", "Completion", + "AsyncCompletion", "Params", "Config", "RetrySettings", "ChatCompletion", + "AsyncChatCompletion", "ChatCompletionChunk", "TextCompletion", "TextCompletionChunk", "Generations", + "AsyncGenerations", "Prompts", + "AsyncPrompts", "Feedback", + "AsyncFeedback", "createHeaders", "Portkey", + "AsyncPortkey", ] diff --git a/portkey_ai/api_resources/apis/__init__.py b/portkey_ai/api_resources/apis/__init__.py index fa66a77f..83545951 100644 --- a/portkey_ai/api_resources/apis/__init__.py +++ b/portkey_ai/api_resources/apis/__init__.py @@ -1,18 +1,25 @@ -from .chat_complete import ChatCompletion -from .complete import Completion -from .generation import Generations, Prompts -from .feedback import Feedback +from .chat_complete import ChatCompletion, AsyncChatCompletion +from .complete import Completion, AsyncCompletion +from .generation import Generations, AsyncGenerations, Prompts, AsyncPrompts +from .feedback import Feedback, AsyncFeedback from .create_headers import createHeaders -from .post import Post -from .embeddings import Embeddings +from .post import Post, AsyncPost +from .embeddings import Embeddings, AsyncEmbeddings __all__ = [ "Completion", + "AsyncCompletion", "ChatCompletion", + "AsyncChatCompletion", "Generations", + "AsyncGenerations", "Feedback", + "AsyncFeedback", "Prompts", + "AsyncPrompts", "createHeaders", "Post", + "AsyncPost", "Embeddings", + "AsyncEmbeddings", ] diff --git a/portkey_ai/api_resources/apis/api_resource.py b/portkey_ai/api_resources/apis/api_resource.py index 3ccf9a26..205b5213 100644 --- a/portkey_ai/api_resources/apis/api_resource.py +++ b/portkey_ai/api_resources/apis/api_resource.py @@ -1,4 +1,5 @@ -from portkey_ai.api_resources.base_client import APIClient +from portkey_ai.api_resources.base_client import APIClient, AsyncAPIClient +import asyncio class APIResource: @@ -17,3 +18,20 @@ def __init__(self, client: APIClient) -> None: def _post(self, *args, **kwargs): return self._client._post(*args, **kwargs) + + +class AsyncAPIResource: + 
_client: AsyncAPIClient + + def __init__(self, client: AsyncAPIClient) -> None: + self._client = client + # self._get = client.get + # self._patch = client.patch + # self._put = client.put + # self._delete = client.delete + + async def _post(self, *args, **kwargs): + return await self._client._post(*args, **kwargs) + + async def _sleep(self, seconds: float) -> None: + await asyncio.sleep(seconds) diff --git a/portkey_ai/api_resources/apis/chat_complete.py b/portkey_ai/api_resources/apis/chat_complete.py index 0f4ae2f6..4d0bc71b 100644 --- a/portkey_ai/api_resources/apis/chat_complete.py +++ b/portkey_ai/api_resources/apis/chat_complete.py @@ -2,7 +2,7 @@ import json from typing import Mapping, Optional, Union, overload, Literal, List -from portkey_ai.api_resources.base_client import APIClient +from portkey_ai.api_resources.base_client import APIClient, AsyncAPIClient from portkey_ai.api_resources.utils import ( PortkeyApiPaths, Message, @@ -10,11 +10,11 @@ ChatCompletions, ) -from portkey_ai.api_resources.streaming import Stream -from portkey_ai.api_resources.apis.api_resource import APIResource +from portkey_ai.api_resources.streaming import AsyncStream, Stream +from portkey_ai.api_resources.apis.api_resource import APIResource, AsyncAPIResource -__all__ = ["ChatCompletion"] +__all__ = ["ChatCompletion", "AsyncChatCompletion"] class ChatCompletion(APIResource): @@ -25,6 +25,14 @@ def __init__(self, client: APIClient) -> None: self.completions = Completions(client) +class AsyncChatCompletion(AsyncAPIResource): + completions: AsyncCompletions + + def __init__(self, client: AsyncAPIClient) -> None: + super().__init__(client) + self.completions = AsyncCompletions(client) + + class Completions(APIResource): def __init__(self, client: APIClient) -> None: super().__init__(client) @@ -107,3 +115,87 @@ def create( def _get_config_string(self, config: Union[Mapping, str]) -> str: return config if isinstance(config, str) else json.dumps(config) + + +class AsyncCompletions(AsyncAPIResource): + def __init__(self, client: AsyncAPIClient) -> None: + super().__init__(client) + + @overload + async def create( + self, + *, + messages: Optional[List[Message]] = None, + config: Optional[Union[Mapping, str]] = None, + stream: Literal[True], + temperature: Optional[float] = None, + max_tokens: Optional[int] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + **kwargs, + ) -> AsyncStream[ChatCompletionChunk]: + ... + + @overload + async def create( + self, + *, + messages: Optional[List[Message]] = None, + config: Optional[Union[Mapping, str]] = None, + stream: Literal[False] = False, + temperature: Optional[float] = None, + max_tokens: Optional[int] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + **kwargs, + ) -> ChatCompletions: + ... + + @overload + async def create( + self, + *, + messages: Optional[List[Message]] = None, + config: Optional[Union[Mapping, str]] = None, + stream: bool = False, + temperature: Optional[float] = None, + max_tokens: Optional[int] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + **kwargs, + ) -> Union[ChatCompletions, AsyncStream[ChatCompletionChunk]]: + ... 
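For reviewers, a minimal usage sketch of the async chat surface added above. The API key and virtual key values are placeholders, the model name mirrors the existing tests, and `asyncio.run` is just one way to drive the coroutine:

```python
import asyncio

from portkey_ai import AsyncPortkey


async def main() -> None:
    # Placeholder credentials; substitute a real Portkey API key and virtual key.
    portkey = AsyncPortkey(api_key="PORTKEY_API_KEY", virtual_key="openai-virtual-key")

    # stream=False (default): create() resolves to a ChatCompletions object.
    completion = await portkey.chat.completions.create(
        messages=[{"role": "user", "content": "Say this is a test"}],
        model="gpt-3.5-turbo",
    )
    print(completion.choices)

    # stream=True: create() resolves to an AsyncStream[ChatCompletionChunk].
    stream = await portkey.chat.completions.create(
        messages=[{"role": "user", "content": "Say this is a test"}],
        model="gpt-3.5-turbo",
        stream=True,
    )
    async for chunk in stream:
        print(chunk)


asyncio.run(main())
```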
+ + async def create( + self, + *, + messages: Optional[List[Message]] = None, + stream: bool = False, + temperature: Optional[float] = None, + max_tokens: Optional[int] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + **kwargs, + ) -> Union[ChatCompletions, AsyncStream[ChatCompletionChunk]]: + body = dict( + messages=messages, + temperature=temperature, + max_tokens=max_tokens, + top_k=top_k, + top_p=top_p, + stream=stream, + **kwargs, + ) + + return await self._post( + PortkeyApiPaths.CHAT_COMPLETE_API, + body=body, + params=None, + cast_to=ChatCompletions, + stream_cls=AsyncStream[ChatCompletionChunk], + stream=stream, + headers={}, + ) + + def _get_config_string(self, config: Union[Mapping, str]) -> str: + return config if isinstance(config, str) else json.dumps(config) diff --git a/portkey_ai/api_resources/apis/complete.py b/portkey_ai/api_resources/apis/complete.py index 8e8e503c..c87fbbd7 100644 --- a/portkey_ai/api_resources/apis/complete.py +++ b/portkey_ai/api_resources/apis/complete.py @@ -1,13 +1,13 @@ from typing import Optional, Union, overload, Literal -from portkey_ai.api_resources.base_client import APIClient +from portkey_ai.api_resources.base_client import APIClient, AsyncAPIClient from portkey_ai.api_resources.utils import ( PortkeyApiPaths, TextCompletion, TextCompletionChunk, ) -from portkey_ai.api_resources.streaming import Stream -from portkey_ai.api_resources.apis.api_resource import APIResource +from portkey_ai.api_resources.streaming import AsyncStream, Stream +from portkey_ai.api_resources.apis.api_resource import APIResource, AsyncAPIResource class Completion(APIResource): @@ -85,3 +85,80 @@ def create( stream=stream, headers={}, ) + + +class AsyncCompletion(AsyncAPIResource): + def __init__(self, client: AsyncAPIClient) -> None: + super().__init__(client) + + @overload + async def create( + self, + *, + prompt: Optional[str] = None, + stream: Literal[True], + temperature: Optional[float] = None, + max_tokens: Optional[int] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + **kwargs, + ) -> AsyncStream[TextCompletionChunk]: + ... + + @overload + async def create( + self, + *, + prompt: Optional[str] = None, + stream: Literal[False] = False, + temperature: Optional[float] = None, + max_tokens: Optional[int] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + **kwargs, + ) -> TextCompletion: + ... + + @overload + async def create( + self, + *, + prompt: Optional[str] = None, + stream: bool = False, + temperature: Optional[float] = None, + max_tokens: Optional[int] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + **kwargs, + ) -> Union[TextCompletion, AsyncStream[TextCompletionChunk]]: + ... 
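The text-completion counterpart mirrors the chat surface; a sketch under the same placeholder-credential assumption:

```python
import asyncio

from portkey_ai import AsyncPortkey


async def main() -> None:
    portkey = AsyncPortkey(api_key="PORTKEY_API_KEY", virtual_key="openai-virtual-key")

    # Streamed text completion: the awaited call yields an AsyncStream[TextCompletionChunk].
    stream = await portkey.completions.create(
        prompt="Say this is a test",
        model="gpt-3.5-turbo-instruct",
        max_tokens=64,
        stream=True,
    )
    async for chunk in stream:
        print(chunk)


asyncio.run(main())
```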
+ + async def create( + self, + *, + prompt: Optional[str] = None, + stream: bool = False, + temperature: Optional[float] = None, + max_tokens: Optional[int] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + **kwargs, + ) -> Union[TextCompletion, AsyncStream[TextCompletionChunk]]: + body = dict( + prompt=prompt, + temperature=temperature, + max_tokens=max_tokens, + top_k=top_k, + top_p=top_p, + stream=stream, + **kwargs, + ) + return await self._post( + PortkeyApiPaths.TEXT_COMPLETE_API, + body=body, + params=None, + cast_to=TextCompletion, + stream_cls=AsyncStream[TextCompletionChunk], + stream=stream, + headers={}, + ) diff --git a/portkey_ai/api_resources/apis/embeddings.py b/portkey_ai/api_resources/apis/embeddings.py index f081664e..dd35ed22 100644 --- a/portkey_ai/api_resources/apis/embeddings.py +++ b/portkey_ai/api_resources/apis/embeddings.py @@ -1,6 +1,6 @@ from typing import Optional -from portkey_ai.api_resources.apis.api_resource import APIResource -from portkey_ai.api_resources.base_client import APIClient +from portkey_ai.api_resources.apis.api_resource import APIResource, AsyncAPIResource +from portkey_ai.api_resources.base_client import APIClient, AsyncAPIClient from portkey_ai.api_resources.utils import PortkeyApiPaths, GenericResponse @@ -36,3 +36,37 @@ def create( stream=False, headers={}, ) + + +class AsyncEmbeddings(AsyncAPIResource): + def __init__(self, client: AsyncAPIClient) -> None: + super().__init__(client) + + async def create( + self, + *, + input: str, + model: Optional[str] = None, + dimensions: Optional[int] = None, + encoding_format: Optional[str] = None, + user: Optional[str] = None, + **kwargs + ) -> GenericResponse: + body = dict( + input=input, + model=model, + user=user, + dimensions=dimensions, + encoding_format=encoding_format, + **kwargs, + ) + + return await self._post( + PortkeyApiPaths.EMBEDDING_API, + body=body, + params=None, + cast_to=GenericResponse, + stream_cls=None, + stream=False, + headers={}, + ) diff --git a/portkey_ai/api_resources/apis/feedback.py b/portkey_ai/api_resources/apis/feedback.py index 8038ad5e..1d14479e 100644 --- a/portkey_ai/api_resources/apis/feedback.py +++ b/portkey_ai/api_resources/apis/feedback.py @@ -1,7 +1,7 @@ from typing import Optional, Dict, Any, List -from portkey_ai.api_resources.apis.api_resource import APIResource -from portkey_ai.api_resources.base_client import APIClient -from portkey_ai.api_resources.streaming import Stream +from portkey_ai.api_resources.apis.api_resource import APIResource, AsyncAPIResource +from portkey_ai.api_resources.base_client import APIClient, AsyncAPIClient +from portkey_ai.api_resources.streaming import AsyncStream, Stream from portkey_ai.api_resources.utils import GenericResponse, PortkeyApiPaths @@ -39,3 +39,39 @@ def bulk_create(self, *, feedbacks: List[Dict[str, Any]]) -> GenericResponse: stream=False, headers={}, ) + + +class AsyncFeedback(AsyncAPIResource): + def __init__(self, client: AsyncAPIClient) -> None: + super().__init__(client) + + async def create( + self, + *, + trace_id: Optional[str] = None, + value: Optional[int] = None, + weight: Optional[float] = None, + metadata: Optional[Dict[str, Any]] = None + ) -> GenericResponse: + body = dict(trace_id=trace_id, value=value, weight=weight, metadata=metadata) + return await self._post( + PortkeyApiPaths.FEEDBACK_API, + body=body, + params=None, + cast_to=GenericResponse, + stream_cls=AsyncStream[GenericResponse], + stream=False, + headers={}, + ) + + async def bulk_create(self, *, 
feedbacks: List[Dict[str, Any]]) -> GenericResponse: + body = feedbacks + return await self._post( + PortkeyApiPaths.FEEDBACK_API, + body=body, + params=None, + cast_to=GenericResponse, + stream_cls=AsyncStream[GenericResponse], + stream=False, + headers={}, + ) diff --git a/portkey_ai/api_resources/apis/generation.py b/portkey_ai/api_resources/apis/generation.py index ac365284..627a54c1 100644 --- a/portkey_ai/api_resources/apis/generation.py +++ b/portkey_ai/api_resources/apis/generation.py @@ -1,14 +1,14 @@ from __future__ import annotations import warnings from typing import Literal, Optional, Union, Mapping, Any, overload -from portkey_ai.api_resources.base_client import APIClient +from portkey_ai.api_resources.base_client import APIClient, AsyncAPIClient from portkey_ai.api_resources.utils import ( retrieve_config, GenericResponse, ) -from portkey_ai.api_resources.streaming import Stream -from portkey_ai.api_resources.apis.api_resource import APIResource +from portkey_ai.api_resources.streaming import AsyncStream, Stream +from portkey_ai.api_resources.apis.api_resource import APIResource, AsyncAPIResource class Generations(APIResource): @@ -44,6 +44,39 @@ def create( return response +class AsyncGenerations(AsyncAPIResource): + def __init__(self, client: AsyncAPIClient) -> None: + super().__init__(client) + + async def create( + self, + *, + prompt_id: str, + config: Optional[Union[Mapping, str]] = None, + variables: Optional[Mapping[str, Any]] = None, + ) -> Union[GenericResponse, AsyncStream[GenericResponse]]: + warning_message = "This API has been deprecated. Please use the Prompt API for the saved prompt." # noqa: E501 + warnings.warn( + warning_message, + DeprecationWarning, + stacklevel=2, + ) + if config is None: + config = retrieve_config() + body = {"variables": variables} + response = await self._post( + f"/v1/prompts/{prompt_id}/generate", + body=body, + mode=None, + params=None, + cast_to=GenericResponse, + stream_cls=AsyncStream[GenericResponse], + stream=False, + ) + response["warning"] = warning_message + return response + + class Prompts(APIResource): completions: Completions @@ -52,6 +85,14 @@ def __init__(self, client: APIClient) -> None: self.completions = Completions(client) +class AsyncPrompts(AsyncAPIResource): + completions: AsyncCompletions + + def __init__(self, client: AsyncAPIClient) -> None: + super().__init__(client) + self.completions = AsyncCompletions(client) + + class Completions(APIResource): def __init__(self, client: APIClient) -> None: super().__init__(client) @@ -138,3 +179,91 @@ def create( stream=stream, headers={}, ) + + +class AsyncCompletions(AsyncAPIResource): + def __init__(self, client: AsyncAPIClient) -> None: + super().__init__(client) + + @overload + async def create( + self, + *, + prompt_id: str, + variables: Optional[Mapping[str, Any]] = None, + config: Optional[Union[Mapping, str]] = None, + stream: Literal[True], + temperature: Optional[float] = None, + max_tokens: Optional[int] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + **kwargs, + ) -> AsyncStream[GenericResponse]: + ... + + @overload + async def create( + self, + *, + prompt_id: str, + variables: Optional[Mapping[str, Any]] = None, + config: Optional[Union[Mapping, str]] = None, + stream: Literal[False] = False, + temperature: Optional[float] = None, + max_tokens: Optional[int] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + **kwargs, + ) -> GenericResponse: + ... 
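A sketch of the async prompt-completions call introduced in this file. The prompt id and variables are hypothetical placeholders; when no config is passed, the implementation falls back to retrieve_config():

```python
import asyncio

from portkey_ai import AsyncPortkey


async def main() -> None:
    portkey = AsyncPortkey(api_key="PORTKEY_API_KEY", virtual_key="openai-virtual-key")

    # "pp-summary-prompt" is a made-up saved-prompt id; variables depend on the prompt template.
    response = await portkey.prompts.completions.create(
        prompt_id="pp-summary-prompt",
        variables={"topic": "observability"},
        max_tokens=64,
    )
    print(response)


asyncio.run(main())
```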
+ + @overload + async def create( + self, + *, + prompt_id: str, + variables: Optional[Mapping[str, Any]] = None, + config: Optional[Union[Mapping, str]] = None, + stream: bool = False, + temperature: Optional[float] = None, + max_tokens: Optional[int] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + **kwargs, + ) -> Union[GenericResponse, AsyncStream[GenericResponse]]: + ... + + async def create( + self, + *, + prompt_id: str, + variables: Optional[Mapping[str, Any]] = None, + config: Optional[Union[Mapping, str]] = None, + stream: bool = False, + temperature: Optional[float] = None, + max_tokens: Optional[int] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + **kwargs, + ) -> Union[GenericResponse, AsyncStream[GenericResponse]]: + """Prompt completions Method""" + if config is None: + config = retrieve_config() + body = { + "variables": variables, + "temperature": temperature, + "max_tokens": max_tokens, + "top_k": top_k, + "top_p": top_p, + "stream": stream, + **kwargs, + } + return await self._post( + f"/prompts/{prompt_id}/completions", + body=body, + params=None, + cast_to=GenericResponse, + stream_cls=AsyncStream[GenericResponse], + stream=stream, + headers={}, + ) diff --git a/portkey_ai/api_resources/apis/post.py b/portkey_ai/api_resources/apis/post.py index 00901e01..54f12f66 100644 --- a/portkey_ai/api_resources/apis/post.py +++ b/portkey_ai/api_resources/apis/post.py @@ -1,9 +1,9 @@ from typing import Union, overload, Literal -from portkey_ai.api_resources.base_client import APIClient +from portkey_ai.api_resources.base_client import APIClient, AsyncAPIClient -from portkey_ai.api_resources.streaming import Stream -from portkey_ai.api_resources.apis.api_resource import APIResource +from portkey_ai.api_resources.streaming import Stream, AsyncStream +from portkey_ai.api_resources.apis.api_resource import APIResource, AsyncAPIResource from portkey_ai.api_resources.utils import GenericResponse @@ -57,3 +57,55 @@ def create( stream=stream, headers={}, ) + + +class AsyncPost(AsyncAPIResource): + def __init__(self, client: AsyncAPIClient) -> None: + super().__init__(client) + + @overload + async def create( + self, + *, + url: str, + stream: Literal[True], + **kwargs, + ) -> AsyncStream[GenericResponse]: + ... + + @overload + async def create( + self, + *, + url: str, + stream: Literal[False] = False, + **kwargs, + ) -> GenericResponse: + ... + + @overload + async def create( + self, + *, + url: str, + stream: bool = False, + **kwargs, + ) -> Union[GenericResponse, AsyncStream[GenericResponse]]: + ... 
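AsyncPost backs the generic `AsyncPortkey.post` helper added in client.py; a sketch with an assumed URL path (substitute whichever gateway route you actually need) and placeholder credentials:

```python
import asyncio

from portkey_ai import AsyncPortkey


async def main() -> None:
    portkey = AsyncPortkey(api_key="PORTKEY_API_KEY", virtual_key="openai-virtual-key")

    # "/chat/completions" is an assumed route here; any gateway path plus a JSON body works the same way.
    response = await portkey.post(
        "/chat/completions",
        messages=[{"role": "user", "content": "Say this is a test"}],
        max_tokens=64,
    )
    print(response)


asyncio.run(main())
```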
+ + async def create( + self, + *, + url: str, + stream: bool = False, + **kwargs, + ) -> Union[GenericResponse, AsyncStream[GenericResponse]]: + return await self._post( + url, + body=kwargs, + params=None, + cast_to=GenericResponse, + stream_cls=AsyncStream[GenericResponse], + stream=stream, + headers={}, + ) diff --git a/portkey_ai/api_resources/base_client.py b/portkey_ai/api_resources/base_client.py index 0403ebc2..60628ffe 100644 --- a/portkey_ai/api_resources/base_client.py +++ b/portkey_ai/api_resources/base_client.py @@ -1,4 +1,5 @@ from __future__ import annotations +import asyncio import json from types import TracebackType @@ -27,8 +28,8 @@ ) from portkey_ai.version import VERSION from .utils import ResponseT, make_status_error, default_api_key, default_base_url -from .common_types import StreamT -from .streaming import Stream +from .common_types import StreamT, AsyncStreamT +from .streaming import Stream, AsyncStream class MissingStreamClassError(TypeError): @@ -376,3 +377,351 @@ def _make_status_error_from_response( err_msg = err_text or f"Error code: {response.status_code}" return make_status_error(err_msg, body=body, request=request, response=response) + + +class AsyncHttpxClientWrapper(httpx.AsyncClient): + def __del__(self) -> None: + try: + asyncio.get_running_loop().create_task(self.aclose()) + except Exception: + pass + + +class AsyncAPIClient: + _client: httpx.AsyncClient + _default_stream_cls: Union[type[AsyncStream[Any]], None] = None + + def __init__( + self, + *, + base_url: Optional[str] = None, + api_key: Optional[str] = None, + virtual_key: Optional[str] = None, + config: Optional[Union[Mapping, str]] = None, + provider: Optional[str] = None, + trace_id: Optional[str] = None, + metadata: Optional[str] = None, + **kwargs, + ) -> None: + self.api_key = api_key or default_api_key() + self.base_url = base_url or default_base_url() + self.virtual_key = virtual_key + self.config = config + self.provider = provider + self.trace_id = trace_id + self.metadata = metadata + self.kwargs = kwargs + + self.custom_headers = createHeaders( + virtual_key=virtual_key, + config=config, + provider=provider, + trace_id=trace_id, + metadata=metadata, + **kwargs, + ) + + self._client = AsyncHttpxClientWrapper( + base_url=self.base_url, + headers={ + "Accept": "application/json", + }, + ) + + self.response_headers: httpx.Headers | None = None + + def _serialize_header_values( + self, headers: Optional[Mapping[str, Any]] + ) -> Dict[str, str]: + if headers is None: + return {} + return { + f"{PORTKEY_HEADER_PREFIX}{k}": json.dumps(v) + if isinstance(v, (dict, list)) + else str(v) + for k, v in headers.items() + } + + @property + def custom_auth(self) -> Optional[httpx.Auth]: + return None + + @overload + async def _post( + self, + path: str, + *, + cast_to: Type[ResponseT], + body: Mapping[str, Any], + stream: Literal[False], + stream_cls: type[AsyncStreamT], + params: Mapping[str, str], + headers: Mapping[str, str], + ) -> ResponseT: + ... + + @overload + async def _post( + self, + path: str, + *, + cast_to: Type[ResponseT], + body: Mapping[str, Any], + stream: Literal[True], + stream_cls: type[AsyncStreamT], + params: Mapping[str, str], + headers: Mapping[str, str], + ) -> AsyncStreamT: + ... + + @overload + async def _post( + self, + path: str, + *, + cast_to: Type[ResponseT], + body: Mapping[str, Any], + stream: bool, + stream_cls: type[AsyncStreamT], + params: Mapping[str, str], + headers: Mapping[str, str], + ) -> Union[ResponseT, AsyncStreamT]: + ... 
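Because AsyncAPIClient implements `__aenter__`/`__aexit__` and an explicit `close()`, the async client can be scoped with `async with` instead of relying on the `__del__` fallback in AsyncHttpxClientWrapper; a sketch with placeholder credentials:

```python
import asyncio

from portkey_ai import AsyncPortkey


async def main() -> None:
    async with AsyncPortkey(
        api_key="PORTKEY_API_KEY", virtual_key="openai-virtual-key"
    ) as portkey:
        completion = await portkey.chat.completions.create(
            messages=[{"role": "user", "content": "Say this is a test"}],
            model="gpt-3.5-turbo",
        )
        print(completion.choices)
    # __aexit__ awaits close(), which closes the underlying httpx.AsyncClient.


asyncio.run(main())
```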
+ + async def _post( + self, + path: str, + *, + cast_to: Type[ResponseT], + body: Mapping[str, Any], + stream: bool, + stream_cls: type[AsyncStreamT], + params: Mapping[str, str], + headers: Mapping[str, str], + ) -> Union[ResponseT, AsyncStreamT]: + if path.endswith("/generate"): + opts = await self._construct_generate_options( + method="post", + url=path, + body=body, + stream=stream, + params=params, + headers=headers, + ) + else: + opts = await self._construct( + method="post", + url=path, + body=body, + stream=stream, + params=params, + headers=headers, + ) + + res = await self._request( + options=opts, + stream=stream, + cast_to=cast_to, + stream_cls=stream_cls, + ) + return res + + async def _construct_generate_options( + self, + *, + method: str, + url: str, + body: Any, + stream: bool, + params: Mapping[str, str], + headers: Mapping[str, str], + ) -> Options: + opts = Options.construct() + opts.method = method + opts.url = url + json_body = body + opts.json_body = remove_empty_values(json_body) + opts.headers = remove_empty_values(headers) + return opts + + async def _construct( + self, + *, + method: str, + url: str, + body: Mapping[str, Any], + stream: bool, + params: Mapping[str, str], + headers: Mapping[str, str], + ) -> Options: + opts = Options.construct() + opts.method = method + opts.url = url + opts.json_body = remove_empty_values(body) + opts.headers = remove_empty_values(headers) + return opts + + @property + def _default_headers(self) -> Mapping[str, str]: + return { + "Content-Type": "application/json", + f"{PORTKEY_HEADER_PREFIX}api-key": self.api_key, + f"{PORTKEY_HEADER_PREFIX}package-version": f"portkey-{VERSION}", + f"{PORTKEY_HEADER_PREFIX}runtime": platform.python_implementation(), + f"{PORTKEY_HEADER_PREFIX}runtime-version": platform.python_version(), + } + + def _build_headers(self, options: Options) -> httpx.Headers: + option_headers = options.headers or {} + headers_dict = self._merge_mappings( + self._default_headers, option_headers, self.custom_headers + ) + headers = httpx.Headers(headers_dict) + return headers + + def _merge_mappings( + self, + *args, + ) -> Dict[str, Any]: + """Merge two mappings of the given type + In cases with duplicate keys the second mapping takes precedence. + """ + mapped_headers = {} + for i in args: + mapped_headers.update(i) + return mapped_headers + + def is_closed(self) -> bool: + return self._client.is_closed + + async def close(self) -> None: + """Close the underlying HTTPX client. + + The client will *not* be usable after this. + """ + await self._client.aclose() + + async def __aenter__(self: Any) -> Any: + return self + + async def __aexit__( + self, + exc_type: Optional[BaseException], + exc: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + await self.close() + + async def _build_request(self, options: Options) -> httpx.Request: + headers = self._build_headers(options) + params = options.params + json_body = options.json_body + request = self._client.build_request( + method=options.method, + url=options.url, + headers=headers, + params=params, + json=json_body, + timeout=options.timeout, + ) + return request + + @overload + async def _request( + self, + *, + options: Options, + stream: Literal[False], + cast_to: Type[ResponseT], + stream_cls: Type[AsyncStreamT], + ) -> ResponseT: + ... + + @overload + async def _request( + self, + *, + options: Options, + stream: Literal[True], + cast_to: Type[ResponseT], + stream_cls: Type[AsyncStreamT], + ) -> AsyncStreamT: + ... 
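The streaming branch below relies on `typing.get_args` to recover the chunk type from the parametrized stream class (e.g. `AsyncStream[ChatCompletionChunk]`); a standalone illustration of that mechanism, independent of the SDK:

```python
from typing import Generic, TypeVar, get_args

T = TypeVar("T")


class DemoStream(Generic[T]):
    """Stand-in for AsyncStream; only the generic parameter matters here."""


class DemoChunk:
    pass


# get_args() on the parametrized alias returns the chunk type used to cast each SSE payload.
print(get_args(DemoStream[DemoChunk]))  # (<class '__main__.DemoChunk'>,)
```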
+ + @overload + async def _request( + self, + *, + options: Options, + stream: bool, + cast_to: Type[ResponseT], + stream_cls: Type[AsyncStreamT], + ) -> Union[ResponseT, AsyncStreamT]: + ... + + async def _request( + self, + *, + options: Options, + stream: bool, + cast_to: Type[ResponseT], + stream_cls: Type[AsyncStreamT], + ) -> Union[ResponseT, AsyncStreamT]: + request = await self._build_request(options) + try: + res = await self._client.send(request, auth=self.custom_auth, stream=stream) + res.raise_for_status() + except httpx.HTTPStatusError as err: # 4xx and 5xx errors + # If the response is streamed then we need to explicitly read the response + # to completion before attempting to access the response text. + await err.response.aread() + raise self._make_status_error_from_response(request, err.response) from None + except httpx.TimeoutException as err: + raise APITimeoutError(request=request) from err + except Exception as err: + raise APIConnectionError(request=request) from err + + self.response_headers = res.headers + if stream or res.headers["content-type"] == "text/event-stream": + if stream_cls is None: + raise MissingStreamClassError() + stream_response = stream_cls( + response=res, cast_to=self._extract_stream_chunk_type(stream_cls) + ) + return stream_response + + response = ( + cast( + ResponseT, + cast_to(**res.json()), + ) + if not isinstance(cast_to, httpx.Response) + else cast(ResponseT, res) + ) + response._headers = res.headers # type: ignore + return response + + def _extract_stream_chunk_type(self, stream_cls: Type) -> type: + args = get_args(stream_cls) + if not args: + raise TypeError( + f"Expected stream_cls to have been given a generic type argument, e.g. \ + Stream[Foo] but received {stream_cls}", + ) + return cast(type, args[0]) + + def _make_status_error_from_response( + self, + request: httpx.Request, + response: httpx.Response, + ) -> APIStatusError: + err_text = response.text.strip() + body = err_text + + try: + body = json.loads(err_text)["error"]["message"] + err_msg = f"Error code: {response.status_code} - {body}" + except Exception: + err_msg = err_text or f"Error code: {response.status_code}" + + return make_status_error(err_msg, body=body, request=request, response=response) diff --git a/portkey_ai/api_resources/client.py b/portkey_ai/api_resources/client.py index 6d311d92..d05be2b8 100644 --- a/portkey_ai/api_resources/client.py +++ b/portkey_ai/api_resources/client.py @@ -2,7 +2,7 @@ from typing import Mapping, Optional, Union from portkey_ai.api_resources import apis -from portkey_ai.api_resources.base_client import APIClient +from portkey_ai.api_resources.base_client import APIClient, AsyncAPIClient class Portkey(APIClient): @@ -70,3 +70,70 @@ def post(self, url: str, **kwargs): return apis.Post(self).create(url=url, **kwargs) with_options = copy + + +class AsyncPortkey(AsyncAPIClient): + completions: apis.AsyncCompletion + chat: apis.AsyncChatCompletion + generations: apis.AsyncGenerations + prompts: apis.AsyncPrompts + embeddings: apis.AsyncEmbeddings + + def __init__( + self, + *, + api_key: Optional[str] = None, + base_url: Optional[str] = None, + virtual_key: Optional[str] = None, + config: Optional[Union[Mapping, str]] = None, + provider: Optional[str] = None, + trace_id: Optional[str] = None, + metadata: Optional[str] = None, + **kwargs, + ) -> None: + super().__init__( + api_key=api_key, + base_url=base_url, + virtual_key=virtual_key, + config=config, + provider=provider, + trace_id=trace_id, + metadata=metadata, + **kwargs, + ) + + 
self.completions = apis.AsyncCompletion(self) + self.chat = apis.AsyncChatCompletion(self) + self.generations = apis.AsyncGenerations(self) + self.prompts = apis.AsyncPrompts(self) + self.embeddings = apis.AsyncEmbeddings(self) + self.feedback = apis.AsyncFeedback(self) + + def copy( + self, + *, + api_key: Optional[str] = None, + base_url: Optional[str] = None, + virtual_key: Optional[str] = None, + config: Optional[Union[Mapping, str]] = None, + provider: Optional[str] = None, + trace_id: Optional[str] = None, + metadata: Optional[str] = None, + **kwargs, + ) -> AsyncPortkey: + return self.__class__( + api_key=api_key or self.api_key, + base_url=base_url or self.base_url, + virtual_key=virtual_key or self.virtual_key, + config=config or self.config, + provider=provider or self.provider, + trace_id=trace_id or self.trace_id, + metadata=metadata or self.metadata, + **self.kwargs, + **kwargs, + ) + + async def post(self, url: str, **kwargs): + return await apis.AsyncPost(self).create(url=url, **kwargs) + + with_options = copy diff --git a/portkey_ai/api_resources/common_types.py b/portkey_ai/api_resources/common_types.py index 7c0e5652..0a2af85f 100644 --- a/portkey_ai/api_resources/common_types.py +++ b/portkey_ai/api_resources/common_types.py @@ -1,7 +1,7 @@ from typing import TypeVar, Union import httpx -from .streaming import Stream +from .streaming import Stream, AsyncStream from .utils import ChatCompletionChunk, TextCompletionChunk, GenericResponse StreamT = TypeVar( @@ -10,3 +10,10 @@ Union[ChatCompletionChunk, TextCompletionChunk, GenericResponse, httpx.Response] ], ) + +AsyncStreamT = TypeVar( + "AsyncStreamT", + bound=AsyncStream[ + Union[ChatCompletionChunk, TextCompletionChunk, GenericResponse, httpx.Response] + ], +) diff --git a/portkey_ai/api_resources/streaming.py b/portkey_ai/api_resources/streaming.py index 1e551a98..cb821c51 100644 --- a/portkey_ai/api_resources/streaming.py +++ b/portkey_ai/api_resources/streaming.py @@ -1,7 +1,7 @@ from __future__ import annotations import json -from typing import Any, Iterator, Generic, cast, Union, Type +from typing import Any, Iterator, AsyncIterator, Generic, cast, Union, Type import httpx @@ -74,6 +74,17 @@ def iter(self, iterator: Iterator[str]) -> Iterator[ServerSentEvent]: if sse is not None: yield sse + async def aiter( + self, iterator: AsyncIterator[str] + ) -> AsyncIterator[ServerSentEvent]: + """Given an async iterator that yields lines, + iterate over it & yield every event encountered""" + async for line in iterator: + line = line.rstrip("\n") + sse = self.decode(line) + if sse is not None: + yield sse + def decode(self, line: str) -> Union[ServerSentEvent, None]: # See: https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation # noqa: E501 @@ -177,3 +188,56 @@ def __stream__(self) -> Iterator[ResponseT]: response=self.response, request=self.response.request, ) + + +class AsyncStream(Generic[ResponseT]): + """Provides the core interface to iterate over a asynchronous stream response.""" + + response: httpx.Response + + def __init__(self, *, response: httpx.Response, cast_to: Type[ResponseT]) -> None: + self._cast_to = cast_to + self.response = response + self._decoder = SSEDecoder() + self._iterator = self.__stream__() + + async def __anext__(self) -> ResponseT: + return await self._iterator.__anext__() + + async def __aiter__(self) -> AsyncIterator[ResponseT]: + async for item in self._iterator: + yield item + + async def _iter_events(self) -> AsyncIterator[ServerSentEvent]: + async for 
sse in self._decoder.aiter(self.response.aiter_lines()): + yield sse + + async def __stream__(self) -> AsyncIterator[ResponseT]: + response = self.response + + async for sse in self._iter_events(): + if sse.data.startswith("[DONE]"): + break + if sse.event is None: + yield cast(ResponseT, self._cast_to(**sse.json())) if not isinstance( + self._cast_to, httpx.Response + ) else cast(ResponseT, sse) + + if sse.event == "ping": + continue + + if sse.event == "error": + body = sse.data + + try: + body = sse.json() + err_msg = f"{body}" + except Exception: + err_msg = sse.data or f"Error code: {response.status_code}" + + raise make_status_error( + err_msg, + body=body, + response=self.response, + request=self.response.request, + ) diff --git a/setup.cfg b/setup.cfg index 2cb536f5..00798242 100644 --- a/setup.cfg +++ b/setup.cfg @@ -41,6 +41,7 @@ dev = pytest==7.4.2 python-dotenv==1.0.0 ruff==0.0.292 + pytest-asyncio==0.23.5 [mypy] ignore_missing_imports = true diff --git a/tests/configs/chat_completions/loadbalance_and_fallback/anthropic_n_openai.json b/tests/configs/chat_completions/loadbalance_and_fallback/anthropic_n_openai.json index e3d8a401..4f2b3396 100644 --- a/tests/configs/chat_completions/loadbalance_and_fallback/anthropic_n_openai.json +++ b/tests/configs/chat_completions/loadbalance_and_fallback/anthropic_n_openai.json @@ -5,7 +5,7 @@ "targets": [ { "provider": "openai", - "virtual_key": "open-ai-apikey-3368e0" + "virtual_key": "vdkey-ff9e7c" }, { "strategy": { @@ -17,10 +17,10 @@ }, "targets": [ { - "virtual_key": "anthropic-419f08" + "virtual_key": "vdanthropic-87ad2c" }, { - "virtual_key": "open-ai-apikey-3368e0" + "virtual_key": "vdkey-ff9e7c" } ] } diff --git a/tests/configs/chat_completions/loadbalance_and_fallback/anyscale_n_openai.json b/tests/configs/chat_completions/loadbalance_and_fallback/anyscale_n_openai.json index 8642515e..e23100fb 100644 --- a/tests/configs/chat_completions/loadbalance_and_fallback/anyscale_n_openai.json +++ b/tests/configs/chat_completions/loadbalance_and_fallback/anyscale_n_openai.json @@ -5,7 +5,7 @@ "targets": [ { "provider": "openai", - "virtual_key": "open-ai-apikey-3368e0" + "virtual_key": "vdkey-ff9e7c" }, { "strategy": { @@ -17,10 +17,10 @@ }, "targets": [ { - "virtual_key": "anyscale-c24b93" + "virtual_key": "vdanyscale-354c5b" }, { - "virtual_key": "open-ai-apikey-3368e0" + "virtual_key": "vdkey-ff9e7c" } ] } diff --git a/tests/configs/chat_completions/loadbalance_and_fallback/azure_n_openai.json b/tests/configs/chat_completions/loadbalance_and_fallback/azure_n_openai.json index 241183c5..dc7eaf32 100644 --- a/tests/configs/chat_completions/loadbalance_and_fallback/azure_n_openai.json +++ b/tests/configs/chat_completions/loadbalance_and_fallback/azure_n_openai.json @@ -5,7 +5,7 @@ "targets": [ { "provider": "openai", - "virtual_key": "open-ai-apikey-3368e0" + "virtual_key": "vdkey-ff9e7c" }, { "strategy": { @@ -20,7 +20,7 @@ "virtual_key": "azure-api-key-993da0" }, { - "virtual_key": "open-ai-apikey-3368e0" + "virtual_key": "vdkey-ff9e7c" } ] } diff --git a/tests/configs/chat_completions/loadbalance_and_fallback/cohere_n_openai.json b/tests/configs/chat_completions/loadbalance_and_fallback/cohere_n_openai.json index 79212141..91216c2b 100644 --- a/tests/configs/chat_completions/loadbalance_and_fallback/cohere_n_openai.json +++ b/tests/configs/chat_completions/loadbalance_and_fallback/cohere_n_openai.json @@ -4,7 +4,7 @@ }, "targets": [ { - "virtual_key": "open-ai-apikey-3368e0" + "virtual_key": "vdkey-ff9e7c" }, { "strategy": { @@ -16,10 
+16,10 @@ }, "targets": [ { - "virtual_key": "cohere-api-key-fffe27" + "virtual_key": "vdcohere-1402b0" }, { - "virtual_key": "open-ai-apikey-3368e0" + "virtual_key": "vdkey-ff9e7c" } ] } diff --git a/tests/configs/chat_completions/loadbalance_with_two_apikeys/loadbalance_with_two_apikeys.json b/tests/configs/chat_completions/loadbalance_with_two_apikeys/loadbalance_with_two_apikeys.json index 14b58beb..7e8392d3 100644 --- a/tests/configs/chat_completions/loadbalance_with_two_apikeys/loadbalance_with_two_apikeys.json +++ b/tests/configs/chat_completions/loadbalance_with_two_apikeys/loadbalance_with_two_apikeys.json @@ -5,11 +5,11 @@ "targets": [ { "provider": "openai", - "virtual_key": "open-ai-apikey-3368e0" + "virtual_key": "vdkey-ff9e7c" }, { "provider": "anthropic", - "virtual_key": "anthropic-419f08" + "virtual_key": "vdanthropic-87ad2c" } ] } \ No newline at end of file diff --git a/tests/configs/chat_completions/single_provider/single_provider.json b/tests/configs/chat_completions/single_provider/single_provider.json index e15c4916..713c9374 100644 --- a/tests/configs/chat_completions/single_provider/single_provider.json +++ b/tests/configs/chat_completions/single_provider/single_provider.json @@ -1,4 +1,4 @@ { "provider": "openai", - "virtual_key": "open-ai-apikey-3368e0" + "virtual_key": "vdkey-ff9e7c" } \ No newline at end of file diff --git a/tests/configs/chat_completions/single_provider_with_vk_retry_cache/single_provider_with_vk_retry_cache.json b/tests/configs/chat_completions/single_provider_with_vk_retry_cache/single_provider_with_vk_retry_cache.json index 91700f9d..12a13b21 100644 --- a/tests/configs/chat_completions/single_provider_with_vk_retry_cache/single_provider_with_vk_retry_cache.json +++ b/tests/configs/chat_completions/single_provider_with_vk_retry_cache/single_provider_with_vk_retry_cache.json @@ -1,5 +1,5 @@ { - "virtual_key": "open-ai-apikey-3368e0", + "virtual_key": "vdkey-ff9e7c", "cache": { "mode": "semantic", "max_age": 60 diff --git a/tests/configs/chat_completions/single_with_basic_config/single_with_basic_config.json b/tests/configs/chat_completions/single_with_basic_config/single_with_basic_config.json index 6703a2fd..5f27cb16 100644 --- a/tests/configs/chat_completions/single_with_basic_config/single_with_basic_config.json +++ b/tests/configs/chat_completions/single_with_basic_config/single_with_basic_config.json @@ -1,3 +1,3 @@ { - "virtual_key": "open-ai-apikey-3368e0" + "virtual_key": "vdkey-ff9e7c" } \ No newline at end of file diff --git a/tests/configs/completions/loadbalance_and_fallback/anthropic_n_openai.json b/tests/configs/completions/loadbalance_and_fallback/anthropic_n_openai.json index e3d8a401..4f2b3396 100644 --- a/tests/configs/completions/loadbalance_and_fallback/anthropic_n_openai.json +++ b/tests/configs/completions/loadbalance_and_fallback/anthropic_n_openai.json @@ -5,7 +5,7 @@ "targets": [ { "provider": "openai", - "virtual_key": "open-ai-apikey-3368e0" + "virtual_key": "vdkey-ff9e7c" }, { "strategy": { @@ -17,10 +17,10 @@ }, "targets": [ { - "virtual_key": "anthropic-419f08" + "virtual_key": "vdanthropic-87ad2c" }, { - "virtual_key": "open-ai-apikey-3368e0" + "virtual_key": "vdkey-ff9e7c" } ] } diff --git a/tests/configs/completions/loadbalance_and_fallback/anyscale_n_openai.json b/tests/configs/completions/loadbalance_and_fallback/anyscale_n_openai.json index 8642515e..e23100fb 100644 --- a/tests/configs/completions/loadbalance_and_fallback/anyscale_n_openai.json +++ 
b/tests/configs/completions/loadbalance_and_fallback/anyscale_n_openai.json @@ -5,7 +5,7 @@ "targets": [ { "provider": "openai", - "virtual_key": "open-ai-apikey-3368e0" + "virtual_key": "vdkey-ff9e7c" }, { "strategy": { @@ -17,10 +17,10 @@ }, "targets": [ { - "virtual_key": "anyscale-c24b93" + "virtual_key": "vdanyscale-354c5b" }, { - "virtual_key": "open-ai-apikey-3368e0" + "virtual_key": "vdkey-ff9e7c" } ] } diff --git a/tests/configs/completions/loadbalance_and_fallback/azure_n_openai.json b/tests/configs/completions/loadbalance_and_fallback/azure_n_openai.json index 241183c5..dc7eaf32 100644 --- a/tests/configs/completions/loadbalance_and_fallback/azure_n_openai.json +++ b/tests/configs/completions/loadbalance_and_fallback/azure_n_openai.json @@ -5,7 +5,7 @@ "targets": [ { "provider": "openai", - "virtual_key": "open-ai-apikey-3368e0" + "virtual_key": "vdkey-ff9e7c" }, { "strategy": { @@ -20,7 +20,7 @@ "virtual_key": "azure-api-key-993da0" }, { - "virtual_key": "open-ai-apikey-3368e0" + "virtual_key": "vdkey-ff9e7c" } ] } diff --git a/tests/configs/completions/loadbalance_and_fallback/cohere_n_openai.json b/tests/configs/completions/loadbalance_and_fallback/cohere_n_openai.json index 79212141..91216c2b 100644 --- a/tests/configs/completions/loadbalance_and_fallback/cohere_n_openai.json +++ b/tests/configs/completions/loadbalance_and_fallback/cohere_n_openai.json @@ -4,7 +4,7 @@ }, "targets": [ { - "virtual_key": "open-ai-apikey-3368e0" + "virtual_key": "vdkey-ff9e7c" }, { "strategy": { @@ -16,10 +16,10 @@ }, "targets": [ { - "virtual_key": "cohere-api-key-fffe27" + "virtual_key": "vdcohere-1402b0" }, { - "virtual_key": "open-ai-apikey-3368e0" + "virtual_key": "vdkey-ff9e7c" } ] } diff --git a/tests/configs/completions/loadbalance_with_two_apikeys/loadbalance_with_two_apikeys.json b/tests/configs/completions/loadbalance_with_two_apikeys/loadbalance_with_two_apikeys.json index 14b58beb..7e8392d3 100644 --- a/tests/configs/completions/loadbalance_with_two_apikeys/loadbalance_with_two_apikeys.json +++ b/tests/configs/completions/loadbalance_with_two_apikeys/loadbalance_with_two_apikeys.json @@ -5,11 +5,11 @@ "targets": [ { "provider": "openai", - "virtual_key": "open-ai-apikey-3368e0" + "virtual_key": "vdkey-ff9e7c" }, { "provider": "anthropic", - "virtual_key": "anthropic-419f08" + "virtual_key": "vdanthropic-87ad2c" } ] } \ No newline at end of file diff --git a/tests/configs/completions/single_provider/single_provider.json b/tests/configs/completions/single_provider/single_provider.json index e15c4916..713c9374 100644 --- a/tests/configs/completions/single_provider/single_provider.json +++ b/tests/configs/completions/single_provider/single_provider.json @@ -1,4 +1,4 @@ { "provider": "openai", - "virtual_key": "open-ai-apikey-3368e0" + "virtual_key": "vdkey-ff9e7c" } \ No newline at end of file diff --git a/tests/configs/completions/single_provider_with_vk_retry_cache/single_provider_with_vk_retry_cache.json b/tests/configs/completions/single_provider_with_vk_retry_cache/single_provider_with_vk_retry_cache.json index 91700f9d..12a13b21 100644 --- a/tests/configs/completions/single_provider_with_vk_retry_cache/single_provider_with_vk_retry_cache.json +++ b/tests/configs/completions/single_provider_with_vk_retry_cache/single_provider_with_vk_retry_cache.json @@ -1,5 +1,5 @@ { - "virtual_key": "open-ai-apikey-3368e0", + "virtual_key": "vdkey-ff9e7c", "cache": { "mode": "semantic", "max_age": 60 diff --git a/tests/configs/completions/single_with_basic_config/single_with_basic_config.json 
b/tests/configs/completions/single_with_basic_config/single_with_basic_config.json index 6703a2fd..5f27cb16 100644 --- a/tests/configs/completions/single_with_basic_config/single_with_basic_config.json +++ b/tests/configs/completions/single_with_basic_config/single_with_basic_config.json @@ -1,3 +1,3 @@ { - "virtual_key": "open-ai-apikey-3368e0" + "virtual_key": "vdkey-ff9e7c" } \ No newline at end of file diff --git a/tests/models.json b/tests/models.json index cc50e242..c1f99861 100644 --- a/tests/models.json +++ b/tests/models.json @@ -16,15 +16,7 @@ "gpt-4-0613" ], "text": [ - "gpt-3.5-turbo-instruct", - "text-davinci-003", - "text-davinci-002", - "text-curie-001", - "text-babbage-001", - "text-ada-001", - "babbage-002", - "davinci-002", - "text-davinci-001" + "gpt-3.5-turbo-instruct" ] }, "anyscale": { diff --git a/tests/test_async_chat_complete.py b/tests/test_async_chat_complete.py new file mode 100644 index 00000000..7ed67393 --- /dev/null +++ b/tests/test_async_chat_complete.py @@ -0,0 +1,474 @@ +from __future__ import annotations +import inspect + +import os +from os import walk +from typing import Any, Dict, List +import pytest +from uuid import uuid4 +from portkey_ai import AsyncPortkey +from time import sleep +from dotenv import load_dotenv +from .utils import read_json_file + + +load_dotenv(override=True) +base_url = os.environ.get("PORTKEY_BASE_URL") +api_key = os.environ.get("PORTKEY_API_KEY") +virtual_api_key = os.environ.get("OPENAI_VIRTUAL_KEY") +CONFIGS_PATH = "./tests/configs/chat_completions" + + +def get_configs(folder_path) -> List[Dict[str, Any]]: + config_files = [] + for dirpath, _, file_names in walk(folder_path): + for f in file_names: + config_files.append(read_json_file(os.path.join(dirpath, f))) + + return config_files + + +class TestChatCompletions: + client = AsyncPortkey + parametrize = pytest.mark.parametrize("client", [client], ids=["strict"]) + models = read_json_file("./tests/models.json") + + def get_metadata(self): + return { + "case": "testing", + "function": inspect.currentframe().f_back.f_code.co_name, + "random_id": str(uuid4()), + } + + # -------------------------- + # Test-1 + t1_params = [] + t = [] + for k, v in models.items(): + for i in v["chat"]: + t.append((client, k, os.environ.get(v["env_variable"]), i)) + + t1_params.extend(t) + + @pytest.mark.asyncio + @pytest.mark.parametrize("client, provider, auth, model", t1_params) + async def test_method_single_with_vk_and_provider( + self, client: Any, provider: str, auth: str, model + ) -> None: + portkey = client( + base_url=base_url, + api_key=api_key, + provider=f"{provider}", + Authorization=f"Bearer {auth}", + trace_id=str(uuid4()), + metadata=self.get_metadata(), + ) + + await portkey.chat.completions.create( + messages=[{"role": "user", "content": "Say this is a test"}], + model=model, + max_tokens=245, + ) + + # -------------------------- + # Test -2 + t2_params = [] + for i in get_configs(f"{CONFIGS_PATH}/single_with_basic_config"): + t2_params.append((client, i)) + + @pytest.mark.asyncio + @pytest.mark.parametrize("client, config", t2_params) + async def test_method_single_with_basic_config( + self, client: Any, config: Dict + ) -> None: + """ + Test the creation of a chat completion with a virtual key using the specified + Portkey client. + + This test method performs the following steps: + 1. Creates a Portkey client instance with the provided base URL, API key, trace + ID, and configuration loaded from the 'single_provider_with_virtualkey.json' + file. + 2. 
Calls the Portkey client's chat.completions.create method to generate a + completion. + 3. Prints the choices from the completion. + + Args: + client (Portkey): The Portkey client instance used for the test. + + Raises: + Any exceptions raised during the test. + + Note: + - Ensure that the 'single_provider_with_virtualkey.json' file exists and + contains valid configuration data. + - Modify the 'model' parameter and the 'messages' content as needed for your + use case. + """ + portkey = client( + base_url=base_url, + api_key=api_key, + trace_id=str(uuid4()), + metadata=self.get_metadata(), + config=config, + ) + + await portkey.chat.completions.create( + messages=[{"role": "user", "content": "Say this is a test"}], + model="gpt-3.5-turbo", + ) + + # print(completion.choices) + # assert("True", "True") + + # assert_matches_type(TextCompletion, completion, path=["response"]) + + # -------------------------- + # Test-3 + t3_params = [] + for i in get_configs(f"{CONFIGS_PATH}/single_provider_with_vk_retry_cache"): + t3_params.append((client, i)) + + @pytest.mark.asyncio + @pytest.mark.parametrize("client, config", t3_params) + async def test_method_single_provider_with_vk_retry_cache( + self, client: Any, config: Dict + ) -> None: + # 1. Make a new cache the cache + # 2. Make a cache hit and see if the response contains the data. + random_id = str(uuid4()) + metadata = self.get_metadata() + portkey = client( + base_url=base_url, + api_key=api_key, + trace_id=random_id, + virtual_key=virtual_api_key, + metadata=metadata, + config=config, + ) + + await portkey.chat.completions.create( + messages=[{"role": "user", "content": "Say this is a test"}], + model="gpt-3.5-turbo", + ) + # Sleeping for the cache to reflect across the workers. The cache has an + # eventual consistency and not immediate consistency. 
+ sleep(20) + portkey_2 = client( + base_url=base_url, + api_key=api_key, + trace_id=random_id, + virtual_key=virtual_api_key, + metadata=metadata, + config=config, + ) + + portkey_2.chat.completions.create( + messages=[{"role": "user", "content": "Say this is a test"}], + model="gpt-3.5-turbo", + ) + + # -------------------------- + # Test-4 + t4_params = [] + for i in get_configs(f"{CONFIGS_PATH}/loadbalance_with_two_apikeys"): + t4_params.append((client, i)) + + @pytest.mark.asyncio + @pytest.mark.parametrize("client, config", t4_params) + async def test_method_loadbalance_with_two_apikeys( + self, client: Any, config: Dict + ) -> None: + portkey = client( + base_url=base_url, + api_key=api_key, + # virtual_key=virtual_api_key, + trace_id=str(uuid4()), + metadata=self.get_metadata(), + config=config, + ) + + completion = await portkey.chat.completions.create( + messages=[{"role": "user", "content": "Say this is a test"}], max_tokens=245 + ) + + print(completion.choices) + + # -------------------------- + # Test-5 + t5_params = [] + for i in get_configs(f"{CONFIGS_PATH}/loadbalance_and_fallback"): + t5_params.append((client, i)) + + @pytest.mark.asyncio + @pytest.mark.parametrize("client, config", t5_params) + async def test_method_loadbalance_and_fallback( + self, client: Any, config: Dict + ) -> None: + portkey = client( + base_url=base_url, + api_key=api_key, + trace_id=str(uuid4()), + config=config, + ) + + completion = await portkey.chat.completions.create( + messages=[ + { + "role": "user", + "content": "Say this is just a loadbalance and fallback test test", + } + ], + ) + + print(completion.choices) + + # -------------------------- + # Test-6 + t6_params = [] + for i in get_configs(f"{CONFIGS_PATH}/single_provider"): + t6_params.append((client, i)) + + @pytest.mark.asyncio + @pytest.mark.parametrize("client, config", t6_params) + async def test_method_single_provider(self, client: Any, config: Dict) -> None: + portkey = client( + base_url=base_url, + api_key=api_key, + trace_id=str(uuid4()), + config=config, + ) + + completion = await portkey.chat.completions.create( + messages=[{"role": "user", "content": "Say this is a test"}], + model="gpt-3.5-turbo", + ) + + print(completion.choices) + + +class TestChatCompletionsStreaming: + client = AsyncPortkey + parametrize = pytest.mark.parametrize("client", [client], ids=["strict"]) + models = read_json_file("./tests/models.json") + + def get_metadata(self): + return { + "case": "testing", + "function": inspect.currentframe().f_back.f_code.co_name, + "random_id": str(uuid4()), + } + + # -------------------------- + # Test-1 + t1_params = [] + t = [] + for k, v in models.items(): + for i in v["chat"]: + t.append((client, k, os.environ.get(v["env_variable"]), i)) + + t1_params.extend(t) + + @pytest.mark.asyncio + @pytest.mark.parametrize("client, provider, auth, model", t1_params) + async def test_method_single_with_vk_and_provider( + self, client: Any, provider: str, auth: str, model + ) -> None: + portkey = client( + base_url=base_url, + api_key=api_key, + provider=f"{provider}", + Authorization=f"Bearer {auth}", + trace_id=str(uuid4()), + metadata=self.get_metadata(), + ) + + await portkey.chat.completions.create( + messages=[{"role": "user", "content": "Say this is a test"}], + model=model, + max_tokens=245, + stream=True, + ) + + # -------------------------- + # Test -2 + t2_params = [] + for i in get_configs(f"{CONFIGS_PATH}/single_with_basic_config"): + t2_params.append((client, i)) + + @pytest.mark.asyncio + 
@pytest.mark.parametrize("client, config", t2_params) + async def test_method_single_with_basic_config( + self, client: Any, config: Dict + ) -> None: + """ + Test the creation of a chat completion with a virtual key using the specified + Portkey client. + + This test method performs the following steps: + 1. Creates a Portkey client instance with the provided base URL, API key, trace + ID, and configuration loaded from the 'single_provider_with_virtualkey.json' + file. + 2. Calls the Portkey client's chat.completions.create method to generate a + completion. + 3. Prints the choices from the completion. + + Args: + client (Portkey): The Portkey client instance used for the test. + + Raises: + Any exceptions raised during the test. + + Note: + - Ensure that the 'single_provider_with_virtualkey.json' file exists and + contains valid configuration data. + - Modify the 'model' parameter and the 'messages' content as needed for your + use case. + """ + portkey = client( + base_url=base_url, + api_key=api_key, + trace_id=str(uuid4()), + metadata=self.get_metadata(), + config=config, + ) + + await portkey.chat.completions.create( + messages=[{"role": "user", "content": "Say this is a test"}], + model="gpt-3.5-turbo", + stream=True, + ) + + # print(completion.choices) + # assert("True", "True") + + # assert_matches_type(TextCompletion, completion, path=["response"]) + + # -------------------------- + # Test-3 + t3_params = [] + for i in get_configs(f"{CONFIGS_PATH}/single_provider_with_vk_retry_cache"): + t3_params.append((client, i)) + + @pytest.mark.asyncio + @pytest.mark.parametrize("client, config", t3_params) + async def test_method_single_provider_with_vk_retry_cache( + self, client: Any, config: Dict + ) -> None: + # 1. Make a new cache the cache + # 2. Make a cache hit and see if the response contains the data. + random_id = str(uuid4()) + metadata = self.get_metadata() + portkey = client( + base_url=base_url, + api_key=api_key, + trace_id=random_id, + virtual_key=virtual_api_key, + metadata=metadata, + config=config, + ) + + await portkey.chat.completions.create( + messages=[{"role": "user", "content": "Say this is a test"}], + model="gpt-3.5-turbo", + stream=True, + ) + # Sleeping for the cache to reflect across the workers. The cache has an + # eventual consistency and not immediate consistency. 
+ sleep(20) + portkey_2 = client( + base_url=base_url, + api_key=api_key, + trace_id=random_id, + virtual_key=virtual_api_key, + metadata=metadata, + config=config, + ) + + portkey_2.chat.completions.create( + messages=[{"role": "user", "content": "Say this is a test"}], + model="gpt-3.5-turbo", + stream=True, + ) + + # -------------------------- + # Test-4 + t4_params = [] + for i in get_configs(f"{CONFIGS_PATH}/loadbalance_with_two_apikeys"): + t4_params.append((client, i)) + + @pytest.mark.asyncio + @pytest.mark.parametrize("client, config", t4_params) + async def test_method_loadbalance_with_two_apikeys( + self, client: Any, config: Dict + ) -> None: + portkey = client( + base_url=base_url, + api_key=api_key, + # virtual_key=virtual_api_key, + trace_id=str(uuid4()), + metadata=self.get_metadata(), + config=config, + ) + + completion = await portkey.chat.completions.create( + messages=[{"role": "user", "content": "Say this is a test"}], + max_tokens=245, + stream=True, + ) + + print(completion) + + # -------------------------- + # Test-5 + t5_params = [] + for i in get_configs(f"{CONFIGS_PATH}/loadbalance_and_fallback"): + t5_params.append((client, i)) + + @pytest.mark.asyncio + @pytest.mark.parametrize("client, config", t5_params) + async def test_method_loadbalance_and_fallback( + self, client: Any, config: Dict + ) -> None: + portkey = client( + base_url=base_url, + api_key=api_key, + trace_id=str(uuid4()), + config=config, + ) + + completion = await portkey.chat.completions.create( + messages=[ + { + "role": "user", + "content": "Say this is just a loadbalance and fallback test test", + } + ], + stream=True, + ) + + print(completion) + + # -------------------------- + # Test-6 + t6_params = [] + for i in get_configs(f"{CONFIGS_PATH}/single_provider"): + t6_params.append((client, i)) + + @pytest.mark.asyncio + @pytest.mark.parametrize("client, config", t6_params) + async def test_method_single_provider(self, client: Any, config: Dict) -> None: + portkey = client( + base_url=base_url, + api_key=api_key, + trace_id=str(uuid4()), + config=config, + ) + + completion = await portkey.chat.completions.create( + messages=[{"role": "user", "content": "Say this is a test"}], + model="gpt-3.5-turbo", + stream=True, + ) + + print(completion) diff --git a/tests/test_async_complete.py b/tests/test_async_complete.py new file mode 100644 index 00000000..0a64e9bb --- /dev/null +++ b/tests/test_async_complete.py @@ -0,0 +1,436 @@ +from __future__ import annotations +import inspect + +import os +from os import walk +from typing import Any, Dict, List +import pytest +from uuid import uuid4 +from portkey_ai import AsyncPortkey +from time import sleep +from dotenv import load_dotenv +from .utils import read_json_file + + +load_dotenv(override=True) +base_url = os.environ.get("PORTKEY_BASE_URL") +api_key = os.environ.get("PORTKEY_API_KEY") +virtual_api_key = os.environ.get("OPENAI_VIRTUAL_KEY") +CONFIGS_PATH = "./tests/configs/completions" + + +def get_configs(folder_path) -> List[Dict[str, Any]]: + config_files = [] + for dirpath, _, file_names in walk(folder_path): + for f in file_names: + config_files.append(read_json_file(os.path.join(dirpath, f))) + + return config_files + + +class TestChatCompletions: + client = AsyncPortkey + parametrize = pytest.mark.parametrize("client", [client], ids=["strict"]) + models = read_json_file("./tests/models.json") + + def get_metadata(self): + return { + "case": "testing", + "function": inspect.currentframe().f_back.f_code.co_name, + "random_id": str(uuid4()), + } + 
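Every async test in both new files follows the same minimal pattern: pytest-asyncio's strict-mode marker plus an awaited client call. A reduced sketch (env vars mirror the suite; the final assertion is illustrative only, since the real tests mostly assert by not raising):

```python
import os

import pytest

from portkey_ai import AsyncPortkey


@pytest.mark.asyncio
async def test_async_completion_smoke() -> None:
    # Same environment variables the suite loads via dotenv; values are assumed to be set.
    portkey = AsyncPortkey(
        base_url=os.environ.get("PORTKEY_BASE_URL"),
        api_key=os.environ.get("PORTKEY_API_KEY"),
        virtual_key=os.environ.get("OPENAI_VIRTUAL_KEY"),
    )
    completion = await portkey.completions.create(prompt="Say this is a test")
    assert completion is not None
```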
+    # --------------------------
+    # Test-1
+    t1_params = []
+    t = []
+    for k, v in models.items():
+        for i in v["text"]:
+            t.append((client, k, os.environ.get(v["env_variable"]), i))
+
+    t1_params.extend(t)
+
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize("client, provider, auth, model", t1_params)
+    async def test_method_single_with_vk_and_provider(
+        self, client: Any, provider: str, auth: str, model
+    ) -> None:
+        portkey = client(
+            base_url=base_url,
+            api_key=api_key,
+            provider=f"{provider}",
+            Authorization=f"Bearer {auth}",
+            trace_id=str(uuid4()),
+            metadata=self.get_metadata(),
+        )
+
+        await portkey.completions.create(
+            prompt="Say this is a test",
+            model=model,
+            max_tokens=245,
+        )
+
+    # --------------------------
+    # Test-2
+    t2_params = []
+    for i in get_configs(f"{CONFIGS_PATH}/single_with_basic_config"):
+        t2_params.append((client, i))
+
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize("client, config", t2_params)
+    async def test_method_single_with_basic_config(
+        self, client: Any, config: Dict
+    ) -> None:
+        """
+        Test the creation of a text completion with a basic config using the
+        specified Portkey client.
+
+        This test method performs the following steps:
+        1. Creates a Portkey client instance with the provided base URL, API key,
+        trace ID, and a configuration loaded from the
+        'single_with_basic_config' config folder.
+        2. Calls the client's completions.create method to generate a completion.
+        3. Awaits the response to ensure the request completes successfully.
+
+        Args:
+            client (AsyncPortkey): The Portkey client instance used for the test.
+
+        Raises:
+            Any exceptions raised during the test.
+
+        Note:
+            - Ensure that the 'single_with_basic_config' folder exists and
+            contains valid configuration data.
+            - Modify the 'prompt' and the completion parameters as needed for your
+            use case.
+        """
+        portkey = client(
+            base_url=base_url,
+            api_key=api_key,
+            trace_id=str(uuid4()),
+            metadata=self.get_metadata(),
+            config=config,
+        )
+
+        await portkey.completions.create(
+            prompt="Say this is a test",
+        )
+
+        # print(completion.choices)
+        # assert("True", "True")
+
+        # assert_matches_type(TextCompletion, completion, path=["response"])
+
+    # --------------------------
+    # Test-3
+    t3_params = []
+    for i in get_configs(f"{CONFIGS_PATH}/single_provider_with_vk_retry_cache"):
+        t3_params.append((client, i))
+
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize("client, config", t3_params)
+    async def test_method_single_provider_with_vk_retry_cache(
+        self, client: Any, config: Dict
+    ) -> None:
+        # 1. Make a request that creates a new cache entry.
+        # 2. Make a cache hit and see if the response contains the data.
+        random_id = str(uuid4())
+        metadata = self.get_metadata()
+        portkey = client(
+            base_url=base_url,
+            api_key=api_key,
+            trace_id=random_id,
+            virtual_key=virtual_api_key,
+            metadata=metadata,
+            config=config,
+        )
+
+        await portkey.completions.create(
+            prompt="Say this is a test",
+        )
+        # Sleeping for the cache to reflect across the workers. The cache has
+        # eventual consistency, not immediate consistency.
+        sleep(20)
+        portkey_2 = client(
+            base_url=base_url,
+            api_key=api_key,
+            trace_id=random_id,
+            virtual_key=virtual_api_key,
+            metadata=metadata,
+            config=config,
+        )
+
+        await portkey_2.completions.create(prompt="Say this is a test")
+
+    # --------------------------
+    # Test-4
+    t4_params = []
+    for i in get_configs(f"{CONFIGS_PATH}/loadbalance_with_two_apikeys"):
+        t4_params.append((client, i))
+
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize("client, config", t4_params)
+    async def test_method_loadbalance_with_two_apikeys(
+        self, client: Any, config: Dict
+    ) -> None:
+        portkey = client(
+            base_url=base_url,
+            api_key=api_key,
+            # virtual_key=virtual_api_key,
+            trace_id=str(uuid4()),
+            metadata=self.get_metadata(),
+            config=config,
+        )
+
+        completion = await portkey.completions.create(
+            prompt="Say this is a test", max_tokens=245
+        )
+
+        print(completion.choices)
+
+    # --------------------------
+    # Test-5
+    t5_params = []
+    for i in get_configs(f"{CONFIGS_PATH}/loadbalance_and_fallback"):
+        t5_params.append((client, i))
+
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize("client, config", t5_params)
+    async def test_method_loadbalance_and_fallback(
+        self, client: Any, config: Dict
+    ) -> None:
+        portkey = client(
+            base_url=base_url,
+            api_key=api_key,
+            trace_id=str(uuid4()),
+            config=config,
+        )
+
+        completion = await portkey.completions.create(
+            prompt="Say this is just a loadbalance and fallback test"
+        )
+
+        print(completion.choices)
+
+    # --------------------------
+    # Test-6
+    t6_params = []
+    for i in get_configs(f"{CONFIGS_PATH}/single_provider"):
+        t6_params.append((client, i))
+
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize("client, config", t6_params)
+    async def test_method_single_provider(self, client: Any, config: Dict) -> None:
+        portkey = client(
+            base_url=base_url,
+            api_key=api_key,
+            trace_id=str(uuid4()),
+            config=config,
+        )
+
+        completion = await portkey.completions.create(
+            prompt="Say this is a test",
+        )
+
+        print(completion.choices)
+
+
+class TestChatCompletionsStreaming:
+    client = AsyncPortkey
+    parametrize = pytest.mark.parametrize("client", [client], ids=["strict"])
+    models = read_json_file("./tests/models.json")
+
+    def get_metadata(self):
+        return {
+            "case": "testing",
+            "function": inspect.currentframe().f_back.f_code.co_name,
+            "random_id": str(uuid4()),
+        }
+
+    # --------------------------
+    # Test-1
+    t1_params = []
+    t = []
+    for k, v in models.items():
+        for i in v["text"]:
+            t.append((client, k, os.environ.get(v["env_variable"]), i))
+
+    t1_params.extend(t)
+
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize("client, provider, auth, model", t1_params)
+    async def test_method_single_with_vk_and_provider(
+        self, client: Any, provider: str, auth: str, model
+    ) -> None:
+        portkey = client(
+            base_url=base_url,
+            api_key=api_key,
+            provider=f"{provider}",
+            Authorization=f"Bearer {auth}",
+            trace_id=str(uuid4()),
+            metadata=self.get_metadata(),
+        )
+
+        await portkey.completions.create(
+            prompt="Say this is a test", model=model, max_tokens=245, stream=True
+        )
+
+    # --------------------------
+    # Test-2
+    t2_params = []
+    for i in get_configs(f"{CONFIGS_PATH}/single_with_basic_config"):
+        t2_params.append((client, i))
+
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize("client, config", t2_params)
+    async def test_method_single_with_basic_config(
+        self, client: Any, config: Dict
+    ) -> None:
+        """
+        Test the creation of a streaming text completion with a basic config using
+        the specified Portkey client.
+
+        This test method performs the following steps:
+        1. Creates a Portkey client instance with the provided base URL, API key,
+        trace ID, and a configuration loaded from the
+        'single_with_basic_config' config folder.
+        2. Calls the client's completions.create method with stream=True to
+        request a streaming completion.
+
+        Args:
+            client (AsyncPortkey): The Portkey client instance used for the test.
+
+        Raises:
+            Any exceptions raised during the test.
+
+        Note:
+            - Ensure that the 'single_with_basic_config' folder exists and
+            contains valid configuration data.
+            - Modify the 'prompt' and the completion parameters as needed for your
+            use case.
+        """
+        portkey = client(
+            base_url=base_url,
+            api_key=api_key,
+            trace_id=str(uuid4()),
+            metadata=self.get_metadata(),
+            config=config,
+        )
+
+        await portkey.completions.create(prompt="Say this is a test", stream=True)
+
+        # print(completion.choices)
+        # assert("True", "True")
+
+        # assert_matches_type(TextCompletion, completion, path=["response"])
+
+    # --------------------------
+    # Test-3
+    t3_params = []
+    for i in get_configs(f"{CONFIGS_PATH}/single_provider_with_vk_retry_cache"):
+        t3_params.append((client, i))
+
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize("client, config", t3_params)
+    async def test_method_single_provider_with_vk_retry_cache(
+        self, client: Any, config: Dict
+    ) -> None:
+        # 1. Make a request that creates a new cache entry.
+        # 2. Make a cache hit and see if the response contains the data.
+        random_id = str(uuid4())
+        metadata = self.get_metadata()
+        portkey = client(
+            base_url=base_url,
+            api_key=api_key,
+            trace_id=random_id,
+            virtual_key=virtual_api_key,
+            metadata=metadata,
+            config=config,
+        )
+
+        await portkey.completions.create(prompt="Say this is a test", stream=True)
+        # Sleeping for the cache to reflect across the workers. The cache has
+        # eventual consistency, not immediate consistency.
+        sleep(20)
+        portkey_2 = client(
+            base_url=base_url,
+            api_key=api_key,
+            trace_id=random_id,
+            virtual_key=virtual_api_key,
+            metadata=metadata,
+            config=config,
+        )
+
+        await portkey_2.completions.create(prompt="Say this is a test", stream=True)
+
+    # --------------------------
+    # Test-4
+    t4_params = []
+    for i in get_configs(f"{CONFIGS_PATH}/loadbalance_with_two_apikeys"):
+        t4_params.append((client, i))
+
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize("client, config", t4_params)
+    async def test_method_loadbalance_with_two_apikeys(
+        self, client: Any, config: Dict
+    ) -> None:
+        portkey = client(
+            base_url=base_url,
+            api_key=api_key,
+            # virtual_key=virtual_api_key,
+            trace_id=str(uuid4()),
+            metadata=self.get_metadata(),
+            config=config,
+        )
+
+        completion = await portkey.completions.create(
+            prompt="Say this is a test", max_tokens=245, stream=True
+        )
+
+        print(completion)
+
+    # --------------------------
+    # Test-5
+    t5_params = []
+    for i in get_configs(f"{CONFIGS_PATH}/loadbalance_and_fallback"):
+        t5_params.append((client, i))
+
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize("client, config", t5_params)
+    async def test_method_loadbalance_and_fallback(
+        self, client: Any, config: Dict
+    ) -> None:
+        portkey = client(
+            base_url=base_url,
+            api_key=api_key,
+            trace_id=str(uuid4()),
+            config=config,
+        )
+
+        completion = await portkey.completions.create(
+            prompt="Say this is just a loadbalance and fallback test", stream=True
+        )
+
+        print(completion)
+
+    # --------------------------
+    # Test-6
+    t6_params = []
+    for i in get_configs(f"{CONFIGS_PATH}/single_provider"):
+        t6_params.append((client, i))
+
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize("client, config", t6_params)
+    async def test_method_single_provider(self, client: Any, config: Dict) -> None:
+        portkey = client(
+            base_url=base_url,
+            api_key=api_key,
+            trace_id=str(uuid4()),
+            config=config,
+        )
+
+        completion = await portkey.completions.create(
+            prompt="Say this is a test", stream=True
+        )
+
+        print(completion)
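+        # Note: with stream=True the awaited result is an async stream object rather
+        # than a finished response; consuming the chunks would look like
+        # `async for chunk in completion: ...` (illustrative usage only, not
+        # asserted by this test).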