Skip to content

Commit

Permalink
add unit test and key rotation
Browse files Browse the repository at this point in the history
  • Loading branch information
victorpolisetty committed Jun 23, 2024
1 parent 16093a1 commit 35ed92b
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 2 deletions.
42 changes: 40 additions & 2 deletions packages/victorpolisetty/customs/dalle_request/dalle_request.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,44 @@
from typing import Any, Dict, Optional, Tuple
import functools
from typing import Any, Dict, Optional, Tuple, Callable
from openai import OpenAI
from tiktoken import encoding_for_model

client: Optional[OpenAI] = None
MechResponse = Tuple[str, Optional[str], Optional[Dict[str, Any]], Any, Any]


def with_key_rotation(func: Callable):
@functools.wraps(func)
def wrapper(*args, **kwargs) -> MechResponse:
api_keys = kwargs["api_keys"]
retries_left: Dict[str, int] = api_keys.max_retries()

def execute() -> MechResponse:
"""Retry the function with a new key."""
try:
result = func(*args, **kwargs)
# Ensure the result is a tuple and has the correct length
if isinstance(result, tuple) and len(result) == 4:
return result + (api_keys,)
else:
raise ValueError("Function did not return a valid MechResponse tuple.")
except openai.error.RateLimitError as e:
# try with a new key again
if retries_left["openai"] <= 0 and retries_left["openrouter"] <= 0:
raise e
retries_left["openai"] -= 1
retries_left["openrouter"] -= 1
api_keys.rotate("openai")
api_keys.rotate("openrouter")
return execute()
except Exception as e:
return str(e), "", None, None, api_keys

mech_response = execute()
return mech_response

return wrapper


class OpenAIClientManager:
"""Client context manager for OpenAI."""
Expand All @@ -22,6 +58,7 @@ def __exit__(self, exc_type, exc_value, traceback) -> None:
client.close()
client = None


def count_tokens(text: str, model: str) -> int:
"""Count the number of tokens in a text."""
enc = encoding_for_model(model)
Expand All @@ -37,12 +74,13 @@ def count_tokens(text: str, model: str) -> int:
ENGINES = {
"text-to-image": ["-2", "-3"],
}
ALLOWED_MODELS = [PREFIX]
ALLOWED_TOOLS = [PREFIX + value for value in ENGINES["text-to-image"]]
ALLOWED_SIZE = ["1024x1024", "1024x1792", "1792x1024"]
ALLOWED_QUALITY = ["standard", "hd"]


# @with_key_rotation
@with_key_rotation
def run(**kwargs) -> Tuple[Optional[str], Optional[Dict[str, Any]], Any, Any]:
"""Run the task"""
with OpenAIClientManager(kwargs["api_keys"]["openai"]):
Expand Down
11 changes: 11 additions & 0 deletions tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from typing import List, Any

from packages.gnosis.customs.omen_tools import omen_buy_sell
from packages.victorpolisetty.customs.dalle_request import dalle_request
from packages.napthaai.customs.prediction_request_rag import prediction_request_rag
from packages.napthaai.customs.prediction_request_rag_cohere import (
prediction_request_rag_cohere,
Expand Down Expand Up @@ -175,3 +176,13 @@ def _validate_response(self, response: Any) -> None:
super()._validate_response(response)
expected_num_tx_params = 2
assert len(response[2].keys()) == expected_num_tx_params

class TestDALLEGeneration(BaseToolTest):
"""Test DALL-E Generation."""

tools = dalle_request.ALLOWED_TOOLS
models = dalle_request.ALLOWED_MODELS
prompts = [
"Generate an image of a futuristic cityscape."
]
tool_module = dalle_request

0 comments on commit 35ed92b

Please sign in to comment.