DALL-E request mech tool #239

Merged 4 commits on Jul 2, 2024
7 changes: 4 additions & 3 deletions packages/gnosis/customs/omen_tools/component.yaml
Expand Up @@ -2,12 +2,13 @@ name: omen_tools
author: gnosis
version: 0.1.0
type: custom
description: Collection of tools to prepare requests for interacting with prediction markets on Omen.
description: Collection of tools to prepare requests for interacting with prediction
markets on Omen.
license: Apache-2.0
aea_version: '>=1.0.0, <2.0.0'
fingerprint:
__init__.py: bafybeibbn67pnrrm4qm3n3kbelvbs3v7fjlrjniywmw2vbizarippidtvi
prediction_sum_url_content.py: bafybeieywowx265yycgf5735bw4zyabfy6ivwnntl6smxa2hicktipgeby
omen_buy_sell.py: bafybeid3zaursxt2nkm2u7x7u4wlodg2ulzlieu5xxsfxjyxzi3vbcezdm
fingerprint_ignore_patterns: []
entry_point: omen_buy_sell.py
callable: run
Expand All @@ -26,4 +27,4 @@ dependencies:
langchain_community:
version: ==0.2.1
openai:
version: ==1.30.2
version: ==1.30.2
1 change: 1 addition & 0 deletions packages/packages.json
Expand Up @@ -28,6 +28,7 @@
"custom/valory/prediction_langchain/0.1.0": "bafybeihhii7veepp6ovkmqjnkp6euhkwm52obabgdltdj34ikisfd7yvqi",
"custom/victorpolisetty/gemini_request/0.1.0": "bafybeig5x6b5jtanet2q5sk7er7fdzpippbvh4q5p7uxmxpriq66omjnaq",
"custom/gnosis/omen_tools/0.1.0": "bafybeifxrawgu6m3dgsxvj7jrhxzr5gwi3zjk2m4gltkr5w3hxjjbla6nu",
"custom/victorpolisetty/dalle_request/0.1.0": "bafybeicgjdvgamkgjebdrowrxdil3aghsbcm7epup6aqidikvjpmvomn6q",
"protocol/valory/acn_data_share/0.1.0": "bafybeih5ydonnvrwvy2ygfqgfabkr47s4yw3uqxztmwyfprulwfsoe7ipq",
"protocol/valory/websocket_client/0.1.0": "bafybeifjk254sy65rna2k32kynzenutujwqndap2r222afvr3zezi27mx4",
"contract/valory/agent_mech/0.1.0": "bafybeiah6b5epo2hlvzg5rr2cydgpp2waausoyrpnoarf7oa7bw33rex34",
2 changes: 1 addition & 1 deletion packages/valory/services/mech/service.yaml
Expand Up @@ -7,7 +7,7 @@ license: Apache-2.0
fingerprint:
  README.md: bafybeif7ia4jdlazy6745ke2k2x5yoqlwsgwr6sbztbgqtwvs3ndm2p7ba
fingerprint_ignore_patterns: []
agent: valory/mech:0.1.0:bafybeid2hlmwtoze3xhhqqayisdg6xzxzgf42zw7hccdaszur7qfgp5vu4
agent: valory/mech:0.1.0:bafybeih2oex4yt4mmiyarp2ivkqqfscoavyb7metchciiqmqdmv7lhyutq
number_of_agents: 4
deployment:
  agent:
19 changes: 19 additions & 0 deletions packages/victorpolisetty/customs/dalle_request/__init__.py
@@ -0,0 +1,19 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# ------------------------------------------------------------------------------
#
# Copyright 2024 Valory AG
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ------------------------------------------------------------------------------
18 changes: 18 additions & 0 deletions packages/victorpolisetty/customs/dalle_request/component.yaml
@@ -0,0 +1,18 @@
name: dalle_request
author: victorpolisetty
version: 0.1.0
type: custom
description: A tool that runs a prompt against the OpenAI DALL-E API.
license: Apache-2.0
aea_version: '>=1.0.0, <2.0.0'
fingerprint:
  __init__.py: bafybeicokooiqnmkldoi5tx6zv6svtjoak5ghj5o3vsi4iliq2pvqaa6uy
  dalle_request.py: bafybeicagzzicf7o6u6iotbvxfacdvntn5bpa4fptmst2iakjwaz7os2ry
fingerprint_ignore_patterns: []
entry_point: dalle_request.py
callable: run
dependencies:
  openai:
    version: ==1.30.2
  tiktoken:
    version: ==0.7.0
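
Note: entry_point and callable identify the module file and the function the mech loads for this component. A minimal sketch of how such a component could be resolved dynamically is below; the loader shown is an assumption for illustration, not the mech's actual loading code.

# Hypothetical loader sketch: resolve a custom component's entry_point and callable.
import importlib.util
import yaml


def load_tool(component_dir: str):
    """Load the function named by `callable` from the file named by `entry_point`."""
    with open(f"{component_dir}/component.yaml") as fh:
        meta = yaml.safe_load(fh)
    spec = importlib.util.spec_from_file_location(
        meta["name"], f"{component_dir}/{meta['entry_point']}"
    )
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return getattr(module, meta["callable"])

# Example (hypothetical usage): run = load_tool("packages/victorpolisetty/customs/dalle_request")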
122 changes: 122 additions & 0 deletions packages/victorpolisetty/customs/dalle_request/dalle_request.py
@@ -0,0 +1,122 @@
import functools
from typing import Any, Dict, Optional, Tuple, Callable
from openai import OpenAI, RateLimitError
from tiktoken import encoding_for_model

client: Optional[OpenAI] = None
MechResponse = Tuple[str, Optional[str], Optional[Dict[str, Any]], Any, Any]


def with_key_rotation(func: Callable):
    @functools.wraps(func)
    def wrapper(*args, **kwargs) -> MechResponse:
        api_keys = kwargs["api_keys"]
        retries_left: Dict[str, int] = api_keys.max_retries()

        def execute() -> MechResponse:
            """Retry the function with a new key."""
            try:
                result = func(*args, **kwargs)
                # Ensure the result is a tuple and has the correct length
                if isinstance(result, tuple) and len(result) == 4:
                    return result + (api_keys,)
                else:
                    raise ValueError("Function did not return a valid MechResponse tuple.")
            except RateLimitError as e:
                # try with a new key again
                if retries_left["openai"] <= 0 and retries_left["openrouter"] <= 0:
                    raise e
                retries_left["openai"] -= 1
                retries_left["openrouter"] -= 1
                api_keys.rotate("openai")
                api_keys.rotate("openrouter")
                return execute()
            except Exception as e:
                return str(e), "", None, None, api_keys

        mech_response = execute()
        return mech_response

    return wrapper


class OpenAIClientManager:
    """Client context manager for OpenAI."""

    def __init__(self, api_key: str):
        self.api_key = api_key

    def __enter__(self) -> OpenAI:
        global client
        if client is None:
            client = OpenAI(api_key=self.api_key)
        return client

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        global client
        if client is not None:
            client.close()
            client = None


def count_tokens(text: str, model: str) -> int:
    """Count the number of tokens in a text."""
    enc = encoding_for_model(model)
    return len(enc.encode(text))


DEFAULT_DALLE_SETTINGS = {
    "size": "1024x1024",
    "quality": "standard",
    "n": 1,
}
PREFIX = "dall-e"
ENGINES = {
    "text-to-image": ["-2", "-3"],
}
ALLOWED_MODELS = [PREFIX]
ALLOWED_TOOLS = [PREFIX + value for value in ENGINES["text-to-image"]]
ALLOWED_SIZE = ["1024x1024", "1024x1792", "1792x1024"]
ALLOWED_QUALITY = ["standard", "hd"]


@with_key_rotation
def run(**kwargs) -> Tuple[Optional[str], Optional[Dict[str, Any]], Any, Any]:
    """Run the task"""
    with OpenAIClientManager(kwargs["api_keys"]["openai"]):
        tool = kwargs["tool"]
        prompt = kwargs["prompt"]
        size = kwargs.get("size", DEFAULT_DALLE_SETTINGS["size"])
        quality = kwargs.get("quality", DEFAULT_DALLE_SETTINGS["quality"])
        n = kwargs.get("n", DEFAULT_DALLE_SETTINGS["n"])
        counter_callback = kwargs.get("counter_callback", None)
        if tool not in ALLOWED_TOOLS:
            return (
                f"Tool {tool} is not in the list of supported tools.",
                None,
                None,
                None,
            )
        if size not in ALLOWED_SIZE:
            return (
                f"Size {size} is not in the list of supported sizes.",
                None,
                None,
                None,
            )
        if quality not in ALLOWED_QUALITY:
            return (
                f"Quality {quality} is not in the list of supported qualities.",
                None,
                None,
                None,
            )

        response = client.images.generate(
            model=tool,
            prompt=prompt,
            size=size,
            quality=quality,
            n=n,
        )
        return response.data[0].url, prompt, None, counter_callback
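
For reviewers trying the tool locally: the decorated run entry point takes keyword arguments and, because of with_key_rotation, returns a five-element tuple ending with the api_keys object. The snippet below is a sketch under assumptions — the FakeKeyChain class and the key value are placeholders for the mech's real rotating key container, which exposes max_retries() and rotate().

# Sketch only: FakeKeyChain and the key value are assumptions, not the mech's real key container.
from packages.victorpolisetty.customs.dalle_request import dalle_request


class FakeKeyChain(dict):
    """Hypothetical stand-in for the mech's rotating API-key container."""

    def max_retries(self):
        # One retry budget per provider, as with_key_rotation expects.
        return {"openai": 1, "openrouter": 1}

    def rotate(self, provider):
        # A real container would swap in the next available key here.
        pass


result = dalle_request.run(
    tool="dall-e-3",                        # must be in ALLOWED_TOOLS
    prompt="A watercolor fox in a forest",
    size="1024x1024",                       # must be in ALLOWED_SIZE
    quality="standard",                     # must be in ALLOWED_QUALITY
    api_keys=FakeKeyChain(openai="sk-your-key"),
)
image_url, used_prompt, _, _, _ = result    # the wrapper appends api_keys as the fifth element
print(image_url)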
11 changes: 11 additions & 0 deletions tests/test_tools.py
Expand Up @@ -20,6 +20,7 @@
from typing import List, Any

from packages.gnosis.customs.omen_tools import omen_buy_sell
from packages.victorpolisetty.customs.dalle_request import dalle_request
from packages.napthaai.customs.prediction_request_rag import prediction_request_rag
from packages.napthaai.customs.prediction_request_rag_cohere import (
    prediction_request_rag_cohere,
@@ -175,3 +176,13 @@ def _validate_response(self, response: Any) -> None:
        super()._validate_response(response)
        expected_num_tx_params = 2
        assert len(response[2].keys()) == expected_num_tx_params

class TestDALLEGeneration(BaseToolTest):
    """Test DALL-E Generation."""

    tools = dalle_request.ALLOWED_TOOLS
    models = dalle_request.ALLOWED_MODELS
    prompts = [
        "Generate an image of a futuristic cityscape."
    ]
    tool_module = dalle_request
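
The base class is assumed to iterate the configured tools and prompts and invoke tool_module.run; a rough sketch of that assumed harness (not the repository's actual BaseToolTest) is:

# Assumed shape of the harness; BaseToolTest's real logic may differ.
def exercise(tool_module, tools, prompts, api_keys):
    """Call the tool once per (tool, prompt) pair and check the response shape."""
    for tool in tools:
        for prompt in prompts:
            response = tool_module.run(tool=tool, prompt=prompt, api_keys=api_keys)
            assert isinstance(response[0], str)  # the result string comes first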