Skip to content

Commit

Permalink
fix: tool degradation
Browse files Browse the repository at this point in the history
  • Loading branch information
0xArdi committed Feb 14, 2024
1 parent f92a447 commit 5edd83a
Show file tree
Hide file tree
Showing 9 changed files with 29 additions and 29 deletions.
2 changes: 1 addition & 1 deletion tools/native_transfer_request/native_transfer_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def native_transfer(
# parse the response to get the transaction object string itself
parsed_txs = ast.literal_eval(response)
except SyntaxError:
return response, None, None
return response, None, None, None

# build the transaction object, unknowns are referenced from parsed_txs
transaction = {
Expand Down
12 changes: 6 additions & 6 deletions tools/openai_request/openai_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def count_tokens(text: str, model: str) -> int:
ALLOWED_TOOLS = [PREFIX + value for values in ENGINES.values() for value in values]


def run(**kwargs) -> Tuple[Optional[str], Optional[Dict[str, Any]], Any]:
def run(**kwargs) -> Tuple[Optional[str], Optional[Dict[str, Any]], Any, Any]:
"""Run the task"""
with OpenAIClientManager(kwargs["api_keys"]["openai"]):
max_tokens = kwargs.get("max_tokens", DEFAULT_OPENAI_SETTINGS["max_tokens"])
Expand All @@ -70,12 +70,12 @@ def run(**kwargs) -> Tuple[Optional[str], Optional[Dict[str, Any]], Any]:
tool = kwargs["tool"]
counter_callback = kwargs.get("counter_callback", None)
if tool not in ALLOWED_TOOLS:
return f"Tool {tool} is not in the list of supported tools.", None, None
return f"Tool {tool} is not in the list of supported tools.", None, None, None

engine = tool.replace(PREFIX, "")
moderation_result = client.moderations.create(input=prompt)
if moderation_result.results[0].flagged:
return "Moderation flagged the prompt as in violation of terms.", None, None
return "Moderation flagged the prompt as in violation of terms.", None, None, None

if engine in ENGINES["chat"]:
messages = [
Expand All @@ -91,9 +91,9 @@ def run(**kwargs) -> Tuple[Optional[str], Optional[Dict[str, Any]], Any]:
timeout=120,
stop=None,
)
return response.choices[0].message.content, prompt, None
return response.choices[0].message.content, prompt, None, None
response = client.completions.create(
engine=engine,
model=engine,
prompt=prompt,
temperature=temperature,
max_tokens=max_tokens,
Expand All @@ -102,4 +102,4 @@ def run(**kwargs) -> Tuple[Optional[str], Optional[Dict[str, Any]], Any]:
timeout=120,
presence_penalty=0,
)
return response.choices[0].text, prompt, counter_callback
return response.choices[0].text, prompt, None, counter_callback
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ def fetch_additional_information(
return "\n".join(["- " + text for text in texts])


def run(**kwargs) -> Tuple[str, Optional[str], Optional[Dict[str, Any]], Any]:
def run(**kwargs) -> Tuple[Optional[str], Optional[Dict[str, Any]], Any, Any]:
"""Run the task"""
with OpenAIClientManager(kwargs["api_keys"]["openai"]):
tool = kwargs["tool"]
Expand Down
10 changes: 5 additions & 5 deletions tools/prediction_request/prediction_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ def fetch_additional_information(
num_words: Optional[int],
counter_callback: Optional[Callable] = None,
source_links: Optional[List[str]] = None,
) -> str:
) -> Tuple[str, Any]:
"""Fetch additional information."""
url_query_prompt = URL_QUERY_PROMPT.format(user_prompt=prompt)
moderation_result = client.moderations.create(input=url_query_prompt)
Expand Down Expand Up @@ -363,7 +363,7 @@ def summarize(text: str, compression_factor: float, vocab: str) -> str:
return summary_text


def run(**kwargs) -> Tuple[Optional[str], Optional[Dict[str, Any]], Any]:
def run(**kwargs) -> Tuple[Optional[str], Any, Optional[Dict[str, Any]], Any]:
"""Run the task"""
with OpenAIClientManager(kwargs["api_keys"]["openai"]):
tool = kwargs["tool"]
Expand Down Expand Up @@ -409,7 +409,7 @@ def run(**kwargs) -> Tuple[Optional[str], Optional[Dict[str, Any]], Any]:
)
moderation_result = client.moderations.create(input=prediction_prompt)
if moderation_result.results[0].flagged:
return "Moderation flagged the prompt as in violation of terms.", None, None
return "Moderation flagged the prompt as in violation of terms.", None, None, None
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": prediction_prompt},
Expand All @@ -430,5 +430,5 @@ def run(**kwargs) -> Tuple[Optional[str], Optional[Dict[str, Any]], Any]:
model=engine,
token_counter=count_tokens,
)
return response.choices[0].message.content, prediction_prompt, counter_callback
return response.choices[0].message.content, prediction_prompt, None
return response.choices[0].message.content, prediction_prompt, None, counter_callback
return response.choices[0].message.content, prediction_prompt, None, None
8 changes: 4 additions & 4 deletions tools/prediction_request_claude/prediction_request_claude.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from googleapiclient.discovery import build

NUM_URLS_EXTRACT = 5
DEFAULT_NUM_WORDS: Dict[str, Optional[int]] = defaultdict(lambda: 300)
DEFAULT_NUM_WORDS = 300
DEFAULT_OPENAI_SETTINGS = {
"max_tokens": 500,
"temperature": 0.7,
Expand Down Expand Up @@ -259,7 +259,7 @@ def fetch_additional_information(
return "\n".join(["- " + text for text in texts]), None


def run(**kwargs) -> Tuple[str, Optional[str], Optional[Dict[str, Any]], Any]:
def run(**kwargs) -> Tuple[Optional[str], Any, Optional[Dict[str, Any]], Any]:
"""Run the task"""
tool = kwargs["tool"]
prompt = kwargs["prompt"]
Expand Down Expand Up @@ -308,6 +308,6 @@ def run(**kwargs) -> Tuple[str, Optional[str], Optional[Dict[str, Any]], Any]:
output_prompt=completion.completion,
token_counter=count_tokens,
)
return completion.completion, prediction_prompt, counter_callback
return completion.completion, prediction_prompt, None, counter_callback

return completion.completion, prediction_prompt, None
return completion.completion, prediction_prompt, None, None
Original file line number Diff line number Diff line change
Expand Up @@ -1105,7 +1105,7 @@ def fetch_additional_information(
return additional_informations


def run(**kwargs) -> Tuple[str, Optional[str], Optional[Dict[str, Any]], Any]:
def run(**kwargs) -> Tuple[Optional[str], Any, Optional[Dict[str, Any]], Any]:
"""
Run the task with the given arguments.
Expand Down Expand Up @@ -1187,7 +1187,7 @@ def run(**kwargs) -> Tuple[str, Optional[str], Optional[Dict[str, Any]], Any]:
# Perform moderation
moderation_result = client.moderations.create(input=prediction_prompt)
if moderation_result.results[0].flagged:
return "Moderation flagged the prompt as in violation of terms.", None, None
return "Moderation flagged the prompt as in violation of terms.", None, None, None

# Create messages for the OpenAI engine
messages = [
Expand All @@ -1205,4 +1205,4 @@ def run(**kwargs) -> Tuple[str, Optional[str], Optional[Dict[str, Any]], Any]:
timeout=150,
stop=None,
)
return response.choices[0].message.content, prediction_prompt, None
return response.choices[0].message.content, prediction_prompt, None, None
5 changes: 2 additions & 3 deletions tools/prediction_request_sme/prediction_request_sme.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,6 @@ def fetch_additional_information(
json_data["queries"],
api_key=google_api_key,
engine=google_engine,
num_urls=num_urls,
)
texts = extract_texts(urls, num_words)
else:
Expand Down Expand Up @@ -433,5 +432,5 @@ def run(**kwargs) -> Tuple[str, Optional[str], Optional[Dict[str, Any]], Any]:
model=engine,
token_counter=count_tokens,
)
return response.choices[0].message.content, prediction_prompt, counter_callback
return response.choices[0].message.content, prediction_prompt, None
return response.choices[0].message.content, prediction_prompt, None, counter_callback
return response.choices[0].message.content, prediction_prompt, None, None
Original file line number Diff line number Diff line change
Expand Up @@ -1049,7 +1049,7 @@ def fetch_additional_information(
return additional_informations


def run(**kwargs) -> Tuple[str, Optional[Dict[str, Any]], Any]:
def run(**kwargs) -> Tuple[str, Optional[str], Optional[Dict[str, Any]], Any]:
"""
Run the task with the given arguments.
Expand Down Expand Up @@ -1155,7 +1155,7 @@ def run(**kwargs) -> Tuple[str, Optional[Dict[str, Any]], Any]:
# Perform moderation
moderation_result = client.moderations.create(input=prediction_prompt)
if moderation_result.results[0].flagged:
return "Moderation flagged the prompt as in violation of terms.", None, None
return "Moderation flagged the prompt as in violation of terms.", None, None, None

# Create messages for the OpenAI engine
messages = [
Expand All @@ -1174,4 +1174,4 @@ def run(**kwargs) -> Tuple[str, Optional[Dict[str, Any]], Any]:
stop=None,
)
print(f"RESPONSE: {response}")
return response.choices[0].message.content, None, None
return response.choices[0].message.content, None, None, None
7 changes: 4 additions & 3 deletions tools/stability_ai_request/stabilityai_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,14 @@ class FinishReason(Enum):
ERROR = 2


def run(**kwargs: Any) -> Tuple[str, Optional[str], Optional[Dict[str, Any]]]:
def run(**kwargs) -> Tuple[str, Optional[str], Optional[Dict[str, Any]], Any]:
"""Run the task"""

api_key = kwargs["api_keys"]["stabilityai"]
tool = kwargs["tool"]
prompt = kwargs["prompt"]
if tool not in ALLOWED_TOOLS:
return f"Tool {tool} is not in the list of supported tools.", None
return f"Tool {tool} is not in the list of supported tools.", None, None, None

# Place content moderation request here if needed
engine = tool.replace(PREFIX, "")
Expand Down Expand Up @@ -118,9 +118,10 @@ def run(**kwargs: Any) -> Tuple[str, Optional[str], Optional[Dict[str, Any]]]:
json=json_params,
)
if response.status_code == 200:
return json.dumps(response.json()), None, None
return json.dumps(response.json()), None, None, None
return (
f"Error: Non-200 response ({response.status_code}): {response.text}",
None,
None,
None
)

0 comments on commit 5edd83a

Please sign in to comment.