Merge pull request #210 from valory-xyz/feat/tool-param-component
feat: add tool params
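
Summary of the change, as reflected in the diff below: default-model selection for the prediction tools moves out of hard-coded Python constants and into each component's YAML configuration. The DEFAULT_MODEL and TOOL_TO_ENGINE constants are deleted, a params section with a default_model entry is added to each component YAML, the engine parameter is renamed to model throughout, and run() now reads the model directly from kwargs. Function-signature defaults that previously referenced DEFAULT_MODEL now name a concrete LLM_SETTINGS key ("gpt-4-0125-preview" or "cohere/command-r-plus").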
richardblythman authored Apr 24, 2024
2 parents a823a98 + 0db6879 commit 80051bd
Showing 16 changed files with 113 additions and 109 deletions.
prediction_request_rag (component YAML)
@@ -7,10 +7,12 @@ license: Apache-2.0
 aea_version: '>=1.0.0, <2.0.0'
 fingerprint:
   __init__.py: bafybeibt7f7crtwvmkg7spy3jhscmlqltvyblzp32g6gj44v7tlo5lycuq
-  prediction_request_rag.py: bafybeie4rafsx7m7mfkdgd2n2ovfm3gwk5ybjr5cz6s7r3oz545ag2wwuy
+  prediction_request_rag.py: bafybeibtvuddvbhjlyd4sbk7rwz4mcsr4hiigfgrpdhzwa6vn6bhb6fboy
 fingerprint_ignore_patterns: []
 entry_point: prediction_request_rag.py
 callable: run
+params:
+  default_model: claude-3-sonnet-20240229
 dependencies:
   google-api-python-client:
     version: ==2.95.0
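
For context, a minimal sketch of how a runner might consume the new params section when building the kwargs for the component's run callable. This is an illustrative assumption, not code from this PR; the helper names and config path are hypothetical.

import yaml  # PyYAML

def load_default_model(config_path):
    """Return params.default_model from a component YAML, or None if absent."""
    with open(config_path) as f:
        config = yaml.safe_load(f) or {}
    return config.get("params", {}).get("default_model")

def build_run_kwargs(config_path, **kwargs):
    """Inject the YAML default_model when the caller does not pass a model."""
    if kwargs.get("model") is None:
        kwargs["model"] = load_default_model(config_path)
    return kwargs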
prediction_request_rag.py
@@ -222,8 +222,6 @@ def embeddings(self, model, input):
"prediction-request-rag",
]
ALLOWED_MODELS = list(LLM_SETTINGS.keys())
DEFAULT_MODEL = "claude-3-haiku-20240307"
TOOL_TO_ENGINE = {tool: DEFAULT_MODEL for tool in ALLOWED_TOOLS}
DEFAULT_NUM_URLS = defaultdict(lambda: 3)
DEFAULT_NUM_QUERIES = defaultdict(lambda: 3)
NUM_URLS_PER_QUERY = 5
@@ -299,11 +297,11 @@ def count_tokens(text: str, model: str) -> int:
 
 def multi_queries(
     prompt: str,
-    engine: str,
+    model: str,
     num_queries: int,
     counter_callback: Optional[Callable[[int, int, str], None]] = None,
-    temperature: Optional[float] = LLM_SETTINGS[DEFAULT_MODEL]["temperature"],
-    max_tokens: Optional[int] = LLM_SETTINGS[DEFAULT_MODEL]["default_max_tokens"],
+    temperature: Optional[float] = LLM_SETTINGS["gpt-4-0125-preview"]["temperature"],
+    max_tokens: Optional[int] = LLM_SETTINGS["gpt-4-0125-preview"]["default_max_tokens"],
 ) -> List[str]:
     """Generate multiple queries for fetching information from the web."""
     url_query_prompt = URL_QUERY_PROMPT.format(
@@ -316,7 +314,7 @@ def multi_queries(
     ]
 
     response = client.completions(
-        model=engine,
+        model=model,
         messages=messages,
         temperature=temperature,
         max_tokens=max_tokens,
@@ -325,7 +323,7 @@ def multi_queries(
         counter_callback(
             input_tokens=response.usage.prompt_tokens,
             output_tokens=response.usage.completion_tokens,
-            model=engine,
+            model=model,
             token_counter=count_tokens,
         )
     queries = parser_query_response(response.content, num_queries=num_queries)
@@ -539,15 +537,15 @@ def recursive_character_text_splitter(text, max_tokens, overlap):
 
 def fetch_additional_information(
     prompt: str,
-    engine: str,
+    model: str,
     google_api_key: Optional[str],
     google_engine_id: Optional[str],
     counter_callback: Optional[Callable[[int, int, str], None]] = None,
     source_links: Optional[List[str]] = None,
     num_urls: Optional[int] = DEFAULT_NUM_URLS,
     num_queries: Optional[int] = DEFAULT_NUM_QUERIES,
-    temperature: Optional[float] = LLM_SETTINGS[DEFAULT_MODEL]["temperature"],
-    max_tokens: Optional[int] = LLM_SETTINGS[DEFAULT_MODEL]["default_max_tokens"],
+    temperature: Optional[float] = LLM_SETTINGS["gpt-4-0125-preview"]["temperature"],
+    max_tokens: Optional[int] = LLM_SETTINGS["gpt-4-0125-preview"]["default_max_tokens"],
 ) -> Tuple[str, Callable[[int, int, str], None]]:
     """Fetch additional information to help answer the user prompt."""
 
@@ -556,7 +554,7 @@ def fetch_additional_information(
     try:
         queries, counter_callback = multi_queries(
             prompt=prompt,
-            engine=engine,
+            model=model,
             num_queries=num_queries,
             counter_callback=counter_callback,
             temperature=temperature,
@@ -674,14 +672,13 @@ def run(**kwargs) -> Tuple[Optional[str], Any, Optional[Dict[str, Any]], Any]:
         kwargs["api_keys"], kwargs["llm_provider"], embedding_provider="openai"
     ):
         tool = kwargs["tool"]
-        model = kwargs.get("model", TOOL_TO_ENGINE[tool])
         prompt = extract_question(kwargs["prompt"])
-        engine = kwargs.get("model", TOOL_TO_ENGINE[tool])
-        print(f"ENGINE: {engine}")
+        model = kwargs.get("model")
+        print(f"MODEL: {model}")
         max_tokens = kwargs.get(
-            "max_tokens", LLM_SETTINGS[engine]["default_max_tokens"]
+            "max_tokens", LLM_SETTINGS[model]["default_max_tokens"]
         )
-        temperature = kwargs.get("temperature", LLM_SETTINGS[engine]["temperature"])
+        temperature = kwargs.get("temperature", LLM_SETTINGS[model]["temperature"])
         num_urls = kwargs.get("num_urls", DEFAULT_NUM_URLS[tool])
         num_queries = kwargs.get("num_queries", DEFAULT_NUM_QUERIES[tool])
         counter_callback = kwargs.get("counter_callback", None)
@@ -699,7 +696,7 @@ def run(**kwargs) -> Tuple[Optional[str], Any, Optional[Dict[str, Any]], Any]:
 
         additional_information, counter_callback = fetch_additional_information(
             prompt=prompt,
-            engine=engine,
+            model=model,
             google_api_key=google_api_key,
             google_engine_id=google_engine_id,
             counter_callback=counter_callback,
@@ -723,7 +720,7 @@ def run(**kwargs) -> Tuple[Optional[str], Any, Optional[Dict[str, Any]], Any]:
         ]
 
         response = client.completions(
-            model=engine,
+            model=model,
             messages=messages,
             temperature=temperature,
             max_tokens=max_tokens,
@@ -733,7 +730,7 @@ def run(**kwargs) -> Tuple[Optional[str], Any, Optional[Dict[str, Any]], Any]:
             counter_callback(
                 input_tokens=response.usage.prompt_tokens,
                 output_tokens=response.usage.completion_tokens,
-                model=engine,
+                model=model,
                 token_counter=count_tokens,
             )
 
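Note that run() now does model = kwargs.get("model") with no fallback, so if a caller omits the model and the surrounding framework does not inject the YAML default_model, LLM_SETTINGS[model] is looked up with None and raises a KeyError. A sketch of a defensive guard, offered here as an illustration and not part of this diff:

# Illustrative guard, not part of this diff: fail early with a clear message
# instead of a KeyError on LLM_SETTINGS[None].
model = kwargs.get("model")
if model not in LLM_SETTINGS:
    raise ValueError(
        f"Missing or unknown model {model!r}; expected one of {sorted(LLM_SETTINGS)}"
    )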
prediction_request_rag_cohere (component YAML)
@@ -8,10 +8,12 @@ license: Apache-2.0
 aea_version: '>=1.0.0, <2.0.0'
 fingerprint:
   __init__.py: bafybeiekjzoy2haayvkiwhb2u2epflpqxticud34mma3gdhfzgu36lxwiq
-  prediction_request_rag_cohere.py: bafybeidt7vvlrapi2y4b4ytqqi2s2ro5oi5xfc62a2rftkwbat2fmw5hme
+  prediction_request_rag_cohere.py: bafybeig4oq3tdjuz2la2pz232u5m7347q7gplu5pw4vebbxteuiqw6hh3u
 fingerprint_ignore_patterns: []
 entry_point: prediction_request_rag_cohere.py
 callable: run
+params:
+  default_model: cohere/command-r-plus
 dependencies:
   google-api-python-client:
     version: ==2.95.0
prediction_request_rag_cohere.py
@@ -192,8 +192,6 @@ def embeddings(self, model, input):
"prediction-request-rag-cohere",
]
ALLOWED_MODELS = list(LLM_SETTINGS.keys())
DEFAULT_MODEL = "cohere/command-r-plus"
TOOL_TO_ENGINE = {tool: DEFAULT_MODEL for tool in ALLOWED_TOOLS}
DEFAULT_NUM_URLS = defaultdict(lambda: 3)
DEFAULT_NUM_QUERIES = defaultdict(lambda: 3)
NUM_URLS_PER_QUERY = 5
@@ -274,11 +272,11 @@ def count_tokens(text: str, model: str) -> int:
 
 def multi_queries(
     prompt: str,
-    engine: str,
+    model: str,
     num_queries: int,
     counter_callback: Optional[Callable[[int, int, str], None]] = None,
-    temperature: Optional[float] = LLM_SETTINGS[DEFAULT_MODEL]["temperature"],
-    max_tokens: Optional[int] = LLM_SETTINGS[DEFAULT_MODEL]["default_max_tokens"],
+    temperature: Optional[float] = LLM_SETTINGS["cohere/command-r-plus"]["temperature"],
+    max_tokens: Optional[int] = LLM_SETTINGS["cohere/command-r-plus"]["default_max_tokens"],
 ) -> List[str]:
     """Generate multiple queries for fetching information from the web."""
     url_query_prompt = URL_QUERY_PROMPT.format(
@@ -291,7 +289,7 @@ def multi_queries(
     ]
 
     response = client.completions(
-        model=engine,
+        model=model,
         messages=messages,
         temperature=temperature,
         max_tokens=max_tokens,
@@ -300,7 +298,7 @@ def multi_queries(
         counter_callback(
             input_tokens=response.usage.prompt_tokens,
             output_tokens=response.usage.completion_tokens,
-            model=engine,
+            model=model,
             token_counter=count_tokens,
         )
     queries = parser_query_response(response.content, num_queries=num_queries)
@@ -514,15 +512,15 @@ def recursive_character_text_splitter(text, max_tokens, overlap):
 
 def fetch_additional_information(
     prompt: str,
-    engine: str,
+    model: str,
     google_api_key: Optional[str],
     google_engine_id: Optional[str],
     counter_callback: Optional[Callable[[int, int, str], None]] = None,
     source_links: Optional[List[str]] = None,
     num_urls: Optional[int] = DEFAULT_NUM_URLS,
     num_queries: Optional[int] = DEFAULT_NUM_QUERIES,
-    temperature: Optional[float] = LLM_SETTINGS[DEFAULT_MODEL]["temperature"],
-    max_tokens: Optional[int] = LLM_SETTINGS[DEFAULT_MODEL]["default_max_tokens"],
+    temperature: Optional[float] = LLM_SETTINGS["cohere/command-r-plus"]["temperature"],
+    max_tokens: Optional[int] = LLM_SETTINGS["cohere/command-r-plus"]["default_max_tokens"],
 ) -> Tuple[str, Callable[[int, int, str], None]]:
     """Fetch additional information to help answer the user prompt."""
 
@@ -531,7 +529,7 @@ def fetch_additional_information(
     try:
         queries, counter_callback = multi_queries(
             prompt=prompt,
-            engine=engine,
+            model=model,
             num_queries=num_queries,
             counter_callback=counter_callback,
             temperature=temperature,
@@ -648,14 +646,13 @@ def run(**kwargs) -> Tuple[Optional[str], Any, Optional[Dict[str, Any]], Any]:
         kwargs["api_keys"], kwargs["llm_provider"], embedding_provider="openai"
     ):
         tool = kwargs["tool"]
-        model = kwargs.get("model", TOOL_TO_ENGINE[tool])
+        model = kwargs.get("model")
         prompt = extract_question(kwargs["prompt"])
-        engine = kwargs.get("model", TOOL_TO_ENGINE[tool])
-        print(f"ENGINE: {engine}")
+        print(f"MODEL: {model}")
         max_tokens = kwargs.get(
-            "max_tokens", LLM_SETTINGS[engine]["default_max_tokens"]
+            "max_tokens", LLM_SETTINGS[model]["default_max_tokens"]
        )
-        temperature = kwargs.get("temperature", LLM_SETTINGS[engine]["temperature"])
+        temperature = kwargs.get("temperature", LLM_SETTINGS[model]["temperature"])
         num_urls = kwargs.get("num_urls", DEFAULT_NUM_URLS[tool])
         num_queries = kwargs.get("num_queries", DEFAULT_NUM_QUERIES[tool])
         counter_callback = kwargs.get("counter_callback", None)
@@ -673,7 +670,7 @@ def run(**kwargs) -> Tuple[Optional[str], Any, Optional[Dict[str, Any]], Any]:
 
         additional_information, counter_callback = fetch_additional_information(
             prompt=prompt,
-            engine=engine,
+            model=model,
             google_api_key=google_api_key,
             google_engine_id=google_engine_id,
             counter_callback=counter_callback,
@@ -697,7 +694,7 @@ def run(**kwargs) -> Tuple[Optional[str], Any, Optional[Dict[str, Any]], Any]:
         ]
 
         response = client.completions(
-            model=engine,
+            model=model,
             messages=messages,
             temperature=temperature,
             max_tokens=max_tokens,
@@ -707,7 +704,7 @@ def run(**kwargs) -> Tuple[Optional[str], Any, Optional[Dict[str, Any]], Any]:
             counter_callback(
                 input_tokens=response.usage.prompt_tokens,
                 output_tokens=response.usage.completion_tokens,
-                model=engine,
+                model=model,
                 token_counter=count_tokens,
             )
 
prediction_request_reasoning (component YAML)
@@ -7,10 +7,12 @@ license: Apache-2.0
 aea_version: '>=1.0.0, <2.0.0'
 fingerprint:
   __init__.py: bafybeib36ew6vbztldut5xayk5553rylrq7yv4cpqyhwc5ktvd4cx67vwu
-  prediction_request_reasoning.py: bafybeibyncgeeyrlcqhdbsdzzama7aah44jzedqo3zcplvuc45goykwjli
+  prediction_request_reasoning.py: bafybeidiabgnlc453spgrdn7rhhl2xc3aa6zqeukkw2bthndbugtjf6bya
 fingerprint_ignore_patterns: []
 entry_point: prediction_request_reasoning.py
 callable: run
+params:
+  default_model: claude-3-sonnet-20240229
 dependencies:
   google-api-python-client:
     version: ==2.95.0
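With params in place, a caller can still override the YAML default by passing model explicitly. A hypothetical invocation for illustration only — the kwargs names match those read in run() above, but the argument values and the api_keys structure are placeholders:

# Hypothetical call; values are placeholders, kwargs names come from run() above.
result = run(
    tool="prediction-request-reasoning",
    prompt="Will event X resolve YES by 2025?",
    model="claude-3-sonnet-20240229",  # overrides params.default_model
    api_keys=my_api_keys,  # structure depends on the framework
    llm_provider="anthropic",
)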
(Diffs for the remaining changed files were not loaded.)
