Skip to content

Commit

Permalink
chore: polish code
Browse files Browse the repository at this point in the history
  • Loading branch information
jmoreira-valory committed Feb 12, 2024
1 parent 54dd3d2 commit 7c021dd
Showing 1 changed file with 26 additions and 73 deletions.
99 changes: 26 additions & 73 deletions tools/resolve_market/resolve_market.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,33 +17,23 @@
#
# ------------------------------------------------------------------------------

"""This module implements a Mech tool for binary predictions."""
"""This module implements a Mech tool for binary predictions.
This module tries to mimic the current logic on the market-creator service
(https://github.com/valory-xyz/market-creator) for resolving closed markets.
"""

import json
import os
import requests
import logging
from collections import defaultdict
from concurrent.futures import Future, ThreadPoolExecutor
from heapq import nlargest
from itertools import islice
from string import punctuation
from typing import Any, Dict, Generator, List, Optional, Tuple, Callable
from string import Template


import tiktoken
from openai import OpenAI

import requests
from typing import TypedDict

import spacy
import logging

from spacy import Language
from spacy.cli import download
from spacy.lang.en import STOP_WORDS
from spacy.tokens import Doc, Span

from openai import OpenAI

client: Optional[OpenAI] = None
Expand All @@ -66,27 +56,13 @@ def __exit__(self, exc_type, exc_value, traceback) -> None:
client = None

DEFAULT_OPENAI_SETTINGS = {
"max_tokens": 500,
"max_tokens": 700,
"temperature": 0.7,
}
MAX_TOKENS = {
"gpt-3.5-turbo": 4096,
"gpt-4": 8192,
}
ALLOWED_TOOLS = [
"close_market",
]
TOOL_TO_ENGINE = {tool: "gpt-4" for tool in ALLOWED_TOOLS}
# the default number of URLs to fetch online information for
DEFAULT_NUM_URLS = defaultdict(lambda: 3)
DEFAULT_NUM_URLS["close_market"] = 7
# the default number of words to fetch online information for
DEFAULT_NUM_WORDS: Dict[str, Optional[int]] = defaultdict(lambda: 300)
DEFAULT_NUM_WORDS["close_market"] = None
# how much of the initial content will be kept during summarization
DEFAULT_COMPRESSION_FACTOR = 0.05
# the vocabulary to use for the summarization
DEFAULT_VOCAB = "en_core_web_sm"

NEWSAPI_ENDPOINT = "https://newsapi.org/v2"
TOP_HEADLINES = "top-headlines"
Expand All @@ -95,10 +71,6 @@ def __exit__(self, exc_type, exc_value, traceback) -> None:
ARTICLE_LIMIT = 1_000
ADDITIONAL_INFO_LIMIT = 5_000
HTTP_OK = 200
ANSWER_NO, ANSWER_YES = (
"0x0000000000000000000000000000000000000000000000000000000000000000",
"0x0000000000000000000000000000000000000000000000000000000000000001",
)

URL_QUERY_PROMPT_TEMPLATE = """
You are an LLM inside a multi-agent system that takes in a prompt of a user requesting a binary outcome
Expand Down Expand Up @@ -173,17 +145,18 @@ class CloseMarketBehaviourMock:

params: Object
context: Object
kwargs: Dict[str, Any]

def __init__(
self,
market_closing_newsapi_api_key: str,
newsapi_endpoint: str,
**kwargs
):
self.kwargs=kwargs
self.context = Object()
self.context.logger = logging.getLogger(__name__)
self.params = Object()
self.params.market_closing_newsapi_api_key = market_closing_newsapi_api_key
self.params.newsapi_endpoint = newsapi_endpoint
self.params.market_closing_newsapi_api_key = kwargs.get("api_keys", {})["newsapi"]
self.params.newsapi_endpoint = NEWSAPI_ENDPOINT

def get_http_response(
self,
Expand All @@ -207,6 +180,7 @@ def _parse_llm_output(
) -> Optional[Dict[str, Any]]:
"""Parse the llm output to json."""
try:
output = output.replace("`", "")
json_data = json.loads(output)
if required_fields is not None:
for field in required_fields:
Expand Down Expand Up @@ -241,7 +215,7 @@ def do_llm_request(self, **kwargs) -> str:
temperature = kwargs.get("temperature", DEFAULT_OPENAI_SETTINGS["temperature"])
counter_callback = kwargs.get("counter_callback", None)
prompt = kwargs.get("prompt")
engine = "gpt-4"
engine = TOOL_TO_ENGINE.get(kwargs["tool"])
moderation_result = client.moderations.create(input=prompt)
if moderation_result.results[0].flagged:
return "Moderation flagged the prompt as in violation of terms.", None, None
Expand All @@ -265,7 +239,7 @@ def do_llm_request(self, **kwargs) -> str:
return res


def _get_answer(self, question: str, **kwargs) -> Optional[str]:
def _get_answer(self, question: str) -> Optional[str]:
"""Get an answer for the provided questions"""

# An initial query is made to Newsapi to detect ratelimit issue
Expand All @@ -286,7 +260,8 @@ def _get_answer(self, question: str, **kwargs) -> Optional[str]:
}

prompt = URL_QUERY_PROMPT_TEMPLATE.format(**prompt_values)
kwargs1 = {"prompt": prompt, **kwargs}
kwargs1 = {"prompt": prompt}
kwargs1.update(self.kwargs)
llm_response_message = self.do_llm_request(**kwargs1)

result_str = llm_response_message.value.replace("OUTPUT:", "").rstrip().lstrip()
Expand Down Expand Up @@ -320,7 +295,8 @@ def _get_answer(self, question: str, **kwargs) -> Optional[str]:

# llm request message
prompt = OUTCOME_PROMPT_TEMPLATE.format(**prompt_values)
kwargs2 = {"prompt": prompt, **kwargs}
kwargs2 = {"prompt": prompt}
kwargs2.update(self.kwargs)
llm_response_message = self.do_llm_request(**kwargs2)

result_str = llm_response_message.value.replace("OUTPUT:", "").rstrip().lstrip()
Expand All @@ -338,7 +314,7 @@ def _get_answer(self, question: str, **kwargs) -> Optional[str]:

def _get_news(
self, query: str
) -> Generator[None, None, Optional[List[Dict[str, Any]]]]:
) -> List[Dict[str, Any]]:
"""Auxiliary method to collect data from endpoint."""

headers = {"X-Api-Key": self.params.market_closing_newsapi_api_key}
Expand Down Expand Up @@ -372,48 +348,25 @@ def _get_news(
def run(**kwargs) -> Tuple[Optional[str], Optional[Dict[str, Any]], Any]:
"""Run the task"""
tool = kwargs["tool"]
question = kwargs["question"]
max_tokens = kwargs.get("max_tokens", DEFAULT_OPENAI_SETTINGS["max_tokens"])
temperature = kwargs.get("temperature", DEFAULT_OPENAI_SETTINGS["temperature"])
num_urls = kwargs.get("num_urls", DEFAULT_NUM_URLS[tool])
num_words = kwargs.get("num_words", DEFAULT_NUM_WORDS[tool])
compression_factor = kwargs.get("compression_factor", DEFAULT_COMPRESSION_FACTOR)
vocab = kwargs.get("vocab", DEFAULT_VOCAB)
counter_callback = kwargs.get("counter_callback", None)
api_keys = kwargs.get("api_keys", {})
google_api_key = api_keys.get("google_api_key", None)
google_engine_id = api_keys.get("google_engine_id", None)

if tool not in ALLOWED_TOOLS:
raise ValueError(f"Tool {tool} is not supported.")

engine = TOOL_TO_ENGINE[tool]

market_behavior = CloseMarketBehaviourMock(
market_closing_newsapi_api_key=api_keys["newsapi"],
newsapi_endpoint=NEWSAPI_ENDPOINT,
)

kwargs.pop('question', None)
v = market_behavior._get_answer(question, **kwargs)
return v
market_behavior = CloseMarketBehaviourMock(**kwargs)
question = kwargs.pop('question', None)
result = market_behavior._get_answer(question)
return result


if __name__ == "__main__":
"""Example usage"""

newsapi_api_key = os.getenv("NEWSAPI_API_KEY")
openai_api_key = os.getenv("OPENAI_API_KEY")

my_kwargs = {
"tool": "close_market",
"question": "Will a cease-fire be implemented in the Gaza Strip by 5 February 2024?",
"max_tokens": 100,
"temperature": 0.7,
"num_urls": 5,
"num_words": 200,
"compression_factor": 2.0,
"vocab": ["word1", "word2", "word3"],
"counter_callback": None,
"api_keys": {
"newsapi": newsapi_api_key,
"openai": openai_api_key,
Expand Down

0 comments on commit 7c021dd

Please sign in to comment.