Skip to content

Commit

Permalink
check model kwarg in tools
Browse files Browse the repository at this point in the history
  • Loading branch information
richardblythman committed Mar 27, 2024
1 parent 0a207db commit 359f163
Show file tree
Hide file tree
Showing 10 changed files with 45 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ def count_tokens(text: str, model: str) -> int:
"prediction-online-sum-url-content",
]
TOOL_TO_ENGINE = {
"prediction-offline-sum-url-content": "gpt-4",
"prediction-online-sum-url-content": "gpt-4",
"prediction-offline-sum-url-content": "gpt-4-0125-preview",
"prediction-online-sum-url-content": "gpt-4-0125-preview",
}


Expand Down Expand Up @@ -977,7 +977,7 @@ def fetch_additional_information(
google_api_key: str,
google_engine: str,
nlp,
engine: str = "gpt-3.5-turbo",
engine: str = "gpt-4-0125-preview",
temperature: float = 1.0,
max_compl_tokens: int = 500,
) -> str:
Expand Down Expand Up @@ -1092,7 +1092,7 @@ def run(**kwargs) -> Tuple[str, Optional[str], Optional[Dict[str, Any]], Any]:
nlp = spacy.load("en_core_web_sm")

# Get the LLM engine to be used
engine = TOOL_TO_ENGINE[tool]
engine = kwargs.get("model", TOOL_TO_ENGINE[tool])
print(f"ENGINE: {engine}")

# Extract the event question from the prompt
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,12 @@ def __exit__(self, exc_type, exc_value, traceback) -> None:
}
MAX_TOKENS = {
"gpt-3.5-turbo-0125": 16385,
"gpt-4": 8192,
"gpt-4-0125-preview": 8192,
}
ALLOWED_TOOLS = [
"prediction-request-rag",
]
TOOL_TO_ENGINE = {tool: "gpt-3.5-turbo-0125" for tool in ALLOWED_TOOLS}
TOOL_TO_ENGINE = {tool: "gpt-4-0125-preview" for tool in ALLOWED_TOOLS}
DEFAULT_NUM_URLS = defaultdict(lambda: 3)
DEFAULT_NUM_QUERIES = defaultdict(lambda: 3)
NUM_URLS_PER_QUERY = 5
Expand Down Expand Up @@ -596,7 +596,8 @@ def run(**kwargs) -> Tuple[str, Optional[str], Optional[Dict[str, Any]], Any]:
if tool not in ALLOWED_TOOLS:
raise ValueError(f"Tool {tool} is not supported.")

engine = TOOL_TO_ENGINE[tool]
engine = kwargs.get("model", TOOL_TO_ENGINE[tool])
print(f"ENGINE: {engine}")
additional_information, counter_callback = fetch_additional_information(
client=client,
prompt=prompt,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -691,8 +691,8 @@ def run(**kwargs) -> Tuple[str, Optional[str], Optional[Dict[str, Any]], Any]:
google_engine_id = api_keys.get("google_engine_id", None)
temperature = kwargs.get("temperature", DEFAULT_OPENAI_SETTINGS["temperature"])
max_tokens = kwargs.get("max_tokens", DEFAULT_OPENAI_SETTINGS["max_tokens"])
engine = TOOL_TO_ENGINE[tool]

engine = kwargs.get("model", TOOL_TO_ENGINE[tool])
print(f"ENGINE: {engine}")
if tool not in ALLOWED_TOOLS:
raise ValueError(f"Tool {tool} is not supported.")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,12 @@ def __exit__(self, exc_type, exc_value, traceback) -> None:
}
MAX_TOKENS = {
"gpt-3.5-turbo-0125": 16385,
"gpt-4": 8192,
"gpt-4-0125-preview": 8192,
}
ALLOWED_TOOLS = [
"prediction-url-cot",
]
TOOL_TO_ENGINE = {tool: "gpt-3.5-turbo-0125" for tool in ALLOWED_TOOLS}
TOOL_TO_ENGINE = {tool: "gpt-4-0125-preview" for tool in ALLOWED_TOOLS}
DEFAULT_NUM_URLS = defaultdict(lambda: 3)
DEFAULT_NUM_QUERIES = defaultdict(lambda: 3)
NUM_URLS_PER_QUERY = 5
Expand Down Expand Up @@ -429,7 +429,7 @@ def fetch_additional_information(
def adjust_doc_tokens(
doc: Document,
max_tokens: int,
engine: str = "gpt-3.5-turbo"
engine: str = "gpt-4-0125-preview"
) -> Document:
"""Adjust the number of tokens in the document."""
if count_tokens(doc.text, engine) > max_tokens:
Expand Down Expand Up @@ -532,8 +532,8 @@ def run(**kwargs) -> Tuple[str, Optional[str], Optional[Dict[str, Any]], Any]:
if tool not in ALLOWED_TOOLS:
raise ValueError(f"Tool {tool} is not supported.")

engine = TOOL_TO_ENGINE[tool]

engine = kwargs.get("model", TOOL_TO_ENGINE[tool])
print(f"ENGINE: {engine}")
# fetch additional information from the web
additional_information, counter_callback = fetch_additional_information(
client=client,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,16 +68,16 @@ def count_tokens(text: str, model: str) -> int:
"temperature": 0.,
}
MAX_TOKENS = {
"gpt-3.5-turbo": 4096,
"gpt-4": 8192,
"gpt-3.5-turbo-0125": 4096,
"gpt-4-0125-preview": 8192,
}
ALLOWED_TOOLS = [
"prediction-offline-sme",
"prediction-online-sme",
]
TOOL_TO_ENGINE = {
"prediction-offline-sme": "gpt-3.5-turbo",
"prediction-online-sme": "gpt-3.5-turbo",
"prediction-offline-sme": "gpt-4-0125-preview",
"prediction-online-sme": "gpt-4-0125-preview",
}

PREDICTION_PROMPT = """
Expand Down Expand Up @@ -435,7 +435,8 @@ def run(**kwargs) -> Tuple[str, Optional[str], Optional[Dict[str, Any]], Any]:
if tool not in ALLOWED_TOOLS:
raise ValueError(f"Tool {tool} is not supported.")

engine = TOOL_TO_ENGINE[tool]
engine = kwargs.get("model", TOOL_TO_ENGINE[tool])
print(f"ENGINE: {engine}")

try:
sme, sme_introduction, counter_callback = get_sme_role(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ def count_tokens(text: str, model: str) -> int:
]

TOOL_TO_ENGINE = {
"strong-sme-generator": "gpt-4",
"normal-sme-generator": "gpt-3.5-turbo",
"strong-sme-generator": "gpt-4-0125-preview",
"normal-sme-generator": "gpt-3.5-turbo-0125",
}

SME_GENERATION_SYSTEM_PROMPT = """
Expand Down Expand Up @@ -125,7 +125,8 @@ def run(**kwargs) -> Tuple[str, Optional[str], Optional[Dict[str, Any]], Any]:
if tool not in ALLOWED_TOOLS:
raise ValueError(f"tool must be one of {ALLOWED_TOOLS}")

engine = TOOL_TO_ENGINE[tool]
engine = kwargs.get("model", TOOL_TO_ENGINE[tool])
print(f"ENGINE: {engine}")

market_question = SME_GENERATION_MARKET_PROMPT.format(question=prompt)
system_prompt = SME_GENERATION_SYSTEM_PROMPT
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -528,8 +528,8 @@ def run(**kwargs) -> Tuple[Optional[str], Any, Optional[Dict[str, Any]], Any]:
if tool not in ALLOWED_TOOLS:
raise ValueError(f"TOOL {tool} is not supported.")

model = TOOL_TO_ENGINE[tool]

model = kwargs.get("model", TOOL_TO_ENGINE[tool])
print(f"ENGINE: {model}")
queries, counter_callback = generate_subqueries(query=prompt, limit=initial_subqueries_limit, api_key=openai_api_key, model=model, counter_callback=counter_callback)
queries, counter_callback = rerank_subqueries(queries=queries, goal=prompt, api_key=openai_api_key, model=model, counter_callback=counter_callback)
queries = queries[:subqueries_limit]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@ def count_tokens(text: str, model: str) -> int:
"deepmind-optimization",
]
TOOL_TO_ENGINE = {
"deepmind-optimization-strong": "gpt-4",
"deepmind-optimization": "gpt-3.5-turbo",
"deepmind-optimization-strong": "gpt-4-0125-preview",
"deepmind-optimization": "gpt-3.5-turbo-0125",
}

PREDICTION_PROMPT_INSTRUCTIONS = """
Expand Down Expand Up @@ -222,7 +222,7 @@ def prompt_engineer(
init_instructions,
instructions_format,
iterations=3,
model_name="gpt-3.5-turbo",
model_name="gpt-4-0125-preview",
):

llm = OpenAILLM(model_name=model_name, openai_api_key=openai_api_key)
Expand Down Expand Up @@ -402,7 +402,8 @@ def run(**kwargs) -> Tuple[Optional[str], Optional[Dict[str, Any]], Any, Any]:
if tool not in ALLOWED_TOOLS:
raise ValueError(f"Tool {tool} is not supported.")

engine = TOOL_TO_ENGINE[tool]
engine = kwargs.get("model", TOOL_TO_ENGINE[tool])
print(f"ENGINE: {engine}")
additional_information = fetch_additional_information(
prompt=prompt,
engine=engine,
Expand Down
10 changes: 6 additions & 4 deletions packages/valory/customs/prediction_request/prediction_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,10 @@ def count_tokens(text: str, model: str) -> int:
"prediction-online-summarized-info",
]
MAX_TOKENS = {
"gpt-3.5-turbo": 4096,
"gpt-4": 8192,
"gpt-3.5-turbo-0125": 4096,
"gpt-4-0125-preview": 8192,
}
TOOL_TO_ENGINE = {tool: "gpt-3.5-turbo" for tool in ALLOWED_TOOLS}
TOOL_TO_ENGINE = {tool: "gpt-4-0125-preview" for tool in ALLOWED_TOOLS}
# the default number of URLs to fetch online information for
DEFAULT_NUM_URLS = defaultdict(lambda: 3)
DEFAULT_NUM_URLS["prediction-online-summarized-info"] = 7
Expand Down Expand Up @@ -436,7 +436,9 @@ def run(**kwargs) -> Tuple[str, Optional[str], Optional[Dict[str, Any]], Any]:
if tool not in ALLOWED_TOOLS:
raise ValueError(f"Tool {tool} is not supported.")

engine = TOOL_TO_ENGINE[tool]
engine = kwargs.get("model", TOOL_TO_ENGINE[tool])
print(f"ENGINE: {engine}")

if tool.startswith("prediction-online"):
additional_information, counter_callback = fetch_additional_information(
prompt,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ def count_tokens(text: str, model: str) -> int:
"prediction-sentence-embedding-bold",
]
TOOL_TO_ENGINE = {
"prediction-sentence-embedding-conservative": "gpt-3.5-turbo",
"prediction-sentence-embedding-bold": "gpt-4",
"prediction-sentence-embedding-conservative": "gpt-3.5-turbo-0125",
"prediction-sentence-embedding-bold": "gpt-4-0125-preview",
}


Expand Down Expand Up @@ -1040,7 +1040,7 @@ def fetch_additional_information(
google_api_key: str,
google_engine: str,
nlp,
engine: str = "gpt-3.5-turbo",
engine: str = "gpt-4-0125-preview",
temperature: float = 0.5,
max_compl_tokens: int = 500,
) -> str:
Expand All @@ -1053,7 +1053,7 @@ def fetch_additional_information(
google_api_key (str): The API key for the Google service.
google_engine (str): The Google engine to be used.
temperature (float): The temperature parameter for the engine.
engine (str): The openai engine. Defaults to "gpt-3.5-turbo".
engine (str): The openai engine. Defaults to "gpt-4-0125-preview".
temperature (float): The temperature parameter for the engine. Defaults to 1.0.
max_compl_tokens (int): The maximum number of tokens for the engine's response.
Expand Down Expand Up @@ -1142,7 +1142,8 @@ def run(**kwargs) -> Tuple[Optional[str], Any, Optional[Dict[str, Any]], Any]:
nlp = spacy.load("en_core_web_md")

# Get the LLM engine to be used
engine = TOOL_TO_ENGINE[tool]
engine = kwargs.get("model", TOOL_TO_ENGINE[tool])
print(f"ENGINE: {engine}")

# Extract the event question from the prompt
event_question = re.search(r"\"(.+?)\"", prompt).group(1)
Expand All @@ -1163,7 +1164,7 @@ def run(**kwargs) -> Tuple[Optional[str], Any, Optional[Dict[str, Any]], Any]:
# Fetch additional information
additional_information = fetch_additional_information(
event_question=event_question,
engine="gpt-3.5-turbo",
engine="gpt-4-0125-preview",
temperature=0.5,
max_compl_tokens=max_compl_tokens,
nlp=nlp,
Expand Down

0 comments on commit 359f163

Please sign in to comment.