Skip to content

Commit

Permalink
feat: apply tool fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
0xArdi committed Apr 8, 2024
1 parent ad73254 commit 73c0d88
Show file tree
Hide file tree
Showing 20 changed files with 373 additions and 230 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ license: Apache-2.0
aea_version: '>=1.0.0, <2.0.0'
fingerprint:
__init__.py: bafybeib36ew6vbztldut5xayk5553rylrq7yv4cpqyhwc5ktvd4cx67vwu
prediction_request_reasoning.py: bafybeifjm24tqil3nan37dbg4u7qm3xw3jxfu5eleg3cojmka67dwaclpa
prediction_request_reasoning.py: bafybeid3umzaz7qzxyf4pda6ffmayyhmr7fjx35bk4qzpsi33rtanddrem
fingerprint_ignore_patterns: []
entry_point: prediction_request_reasoning.py
callable: run
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -681,114 +681,117 @@ def extract_question(prompt: str) -> str:

def run(**kwargs) -> Tuple[str, Optional[str], Optional[Dict[str, Any]], Any]:
    """Run the prediction-with-reasoning task.

    Expects in ``kwargs``: ``tool``, ``prompt``, ``api_keys`` (with an
    ``openai`` key, optionally ``google_api_key``/``google_engine_id``),
    and optional ``num_urls``, ``counter_callback``, ``temperature``,
    ``max_tokens``, ``model`` and ``source_links``.

    Returns a 4-tuple ``(result, prompt_used, extra, counter_callback)``.
    On any failure the first element is an error message and the middle
    elements are empty/None; the callback (if provided) is still returned
    so the caller can keep accounting for tokens already consumed.
    """
    # Hoisted out of the try-block so the except-branch can safely return it
    # even when the failure happens before the body assigns anything.
    counter_callback = kwargs.get("counter_callback", None)
    try:
        with OpenAIClientManager(kwargs["api_keys"]["openai"]):
            tool = kwargs["tool"]
            prompt = extract_question(kwargs["prompt"])
            num_urls = kwargs.get("num_urls", DEFAULT_NUM_URLS[tool])
            api_keys = kwargs.get("api_keys", {})
            google_api_key = api_keys.get("google_api_key", None)
            google_engine_id = api_keys.get("google_engine_id", None)
            temperature = kwargs.get(
                "temperature", DEFAULT_OPENAI_SETTINGS["temperature"]
            )
            max_tokens = kwargs.get(
                "max_tokens", DEFAULT_OPENAI_SETTINGS["max_tokens"]
            )
            engine = kwargs.get("model", TOOL_TO_ENGINE[tool])
            print(f"ENGINE: {engine}")
            if tool not in ALLOWED_TOOLS:
                raise ValueError(f"Tool {tool} is not supported.")

            # Gather web/search context for the question; may update the
            # token-accounting callback as a side effect.
            (
                additional_information,
                queries,
                counter_callback,
            ) = fetch_additional_information(
                client=client,
                prompt=prompt,
                engine=engine,
                google_api_key=google_api_key,
                google_engine_id=google_engine_id,
                counter_callback=counter_callback,
                source_links=kwargs.get("source_links", None),
                num_urls=num_urls,
                temperature=temperature,
                max_tokens=max_tokens,
            )

            # Adjust the additional_information to fit within the token budget
            adjusted_info = adjust_additional_information(
                prompt=PREDICTION_PROMPT,
                additional_information=additional_information,
                model=engine,
            )

            # Reasoning prompt
            reasoning_prompt = REASONING_PROMPT.format(
                user_prompt=prompt, formatted_docs=adjusted_info
            )

            # Do reasoning
            messages = [
                {"role": "system", "content": SYSTEM_PROMPT},
                {
                    "role": "user",
                    "content": reasoning_prompt,
                },
            ]

            # Reasoning
            response_reasoning = client.chat.completions.create(
                model=engine,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
                n=1,
                timeout=150,
                stop=None,
            )

            # Extract the reasoning
            reasoning = response_reasoning.choices[0].message.content

            # Prediction prompt
            prediction_prompt = PREDICTION_PROMPT.format(
                user_prompt=prompt, reasoning=reasoning
            )

            # Make the prediction
            messages = [
                {"role": "system", "content": SYSTEM_PROMPT},
                {
                    "role": "user",
                    "content": prediction_prompt,
                },
            ]

            # NOTE(review): `token_counter` is NOT a valid Chat Completions
            # argument — it belongs to the counter_callback call below and
            # would raise TypeError if passed here.
            response = client.chat.completions.create(
                model=engine,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
                n=1,
                timeout=150,
                stop=None,
                functions=[Results.openai_schema],
                function_call={'name': 'Results'},
            )
            results = str(Results.from_response(response))

            # `Results` stringifies as space-separated "key=value" pairs;
            # convert them to a JSON object of floats.
            pairs = results.split()
            result_dict = {}
            for pair in pairs:
                key, value = pair.split("=")
                result_dict[key] = float(value)  # Convert value to float
            results = json.dumps(result_dict)

            if counter_callback is not None:
                # Account for the tokens of both completions at once.
                counter_callback(
                    input_tokens=response_reasoning.usage.prompt_tokens
                    + response.usage.prompt_tokens,
                    output_tokens=response_reasoning.usage.completion_tokens
                    + response.usage.completion_tokens,
                    model=engine,
                    token_counter=count_tokens,
                )
            return (
                results,
                reasoning_prompt + "////" + prediction_prompt,
                None,
                counter_callback,
            )
    except Exception as e:
        # Return the callback (not None) so the caller keeps its token
        # accounting, mirroring the error-return shape of the sibling tools.
        return (
            f"Invalid response. The following issue was encountered: {str(e)}",
            "",
            None,
            counter_callback,
        )
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ license: Apache-2.0
aea_version: '>=1.0.0, <2.0.0'
fingerprint:
__init__.py: bafybeibbn67pnrrm4qm3n3kbelvbs3v7fjlrjniywmw2vbizarippidtvi
prediction_request_sme.py: bafybeiesilq7jxjzwtvhrc4m3om5fpqqzimxtkf3s3hw7l2rmfna2uhjuy
prediction_request_sme.py: bafybeiahteuasjn632fvfqt4to772yreohzuorwlvqgrj4dqcfl662lnyi
fingerprint_ignore_patterns: []
entry_point: prediction_request_sme.py
callable: run
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,7 @@ def run(**kwargs) -> Tuple[str, Optional[str], Optional[Dict[str, Any]], Any]:
"Moderation flagged the prompt as in violation of terms.",
prediction_prompt,
None,
counter_callback,
)
messages = [
{"role": "system", "content": sme_introduction},
Expand Down
43 changes: 22 additions & 21 deletions packages/packages.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"dev": {
"custom/valory/native_transfer_request/0.1.0": "bafybeid22vi5xtavqhq5ir2kq6nakckm3tl72wcgftsq35ak3cboyn6eea",
"custom/valory/prediction_request_claude/0.1.0": "bafybeidmtovzewf3be6wzdsoozdyin2hvq2efw233arohv243f52jzapli",
"custom/valory/prediction_request_claude/0.1.0": "bafybeifbootge455l6jrz252zd63hxsikkdntroebv5r4lknuh2vbkpwzq",
"custom/valory/openai_request/0.1.0": "bafybeigew6ukd53n3z352wmr5xu6or3id7nsqn7vb47bxs4pg4qtkmbdiu",
"custom/valory/prediction_request_embedding/0.1.0": "bafybeifdhbbxmwf4q6pjhanubgrzhy7hetupyoekjyxvnubcccqxlkaqu4",
"custom/valory/resolve_market/0.1.0": "bafybeiaag2e7rsdr3bwg6mlmfyom4vctsdapohco7z45pxhzjymepz3rya",
Expand All @@ -11,25 +11,26 @@
"custom/jhehemann/prediction_sum_url_content/0.1.0": "bafybeibgyldsxjqh6wekmpftxo66xk2ufsltigfe76ccvxsmdu3ba3w2eu",
"custom/psouranis/optimization_by_prompting/0.1.0": "bafybeieb6czcv7hwc5bsosyx67gmibt43q4uh6uozjnrfm5lic2mdtrcay",
"custom/nickcom007/sme_generation_request/0.1.0": "bafybeieqlvcmdmybq7km6ycg3h7nzlnzlcizm3pjcra67rthtddxiqatd4",
"custom/nickcom007/prediction_request_sme/0.1.0": "bafybeiadh3pypp5uwvj7ghawmrxezm6wnfxok4y2dzmqo6jkkybcq7ky3m",
"custom/nickcom007/prediction_request_sme/0.1.0": "bafybeifmcbw7odtm46xjryzr4c575tfs3sgtb6bq25ywmut33chbz3gkqa",
"custom/napthaai/resolve_market_reasoning/0.1.0": "bafybeigutlttyivlf6yxdeesclv3dwxq6h7yj3varq63b6ujno3q6ytoje",
"custom/napthaai/prediction_request_rag/0.1.0": "bafybeibp3hfeywllhmepvqyah763go4i5vtvo4wwihy6h4x7sylsjm5cam",
"custom/napthaai/prediction_request_reasoning/0.1.0": "bafybeihrq7vvup2jztclucyrveb3nlgeb2pe5afgrxglgb7ji6b5jv5vtm",
"custom/valory/prepare_tx/0.1.0": "bafybeighlfdmykwbar6wuipeo66blv2vcckxyspvw2oscsjctowly5taf4",
"custom/napthaai/prediction_request_reasoning/0.1.0": "bafybeiexotxqt3o4mmbjr7gp74hhcsbfhvldlrcqbj573oliljqxeilh44",
"custom/valory/prepare_tx/0.1.0": "bafybeibxhhdpbdd3ma2jsu76egq56v2cjrx332fwjqbzgs6uvyqjqxcru4",
"custom/valory/short_maker/0.1.0": "bafybeif63rt4lkopu3rc3l7sg6tebrrwg2lxqufjx6dx4hoda5yzax43fa",
"protocol/valory/acn_data_share/0.1.0": "bafybeih5ydonnvrwvy2ygfqgfabkr47s4yw3uqxztmwyfprulwfsoe7ipq",
"protocol/valory/websocket_client/0.1.0": "bafybeih43mnztdv3v2hetr2k3gezg7d3yj4ur7cxdvcyaqhg65e52s5sf4",
"contract/valory/agent_mech/0.1.0": "bafybeidsau5x2vjofpcdzxkg7airwkrdag65ohtxcby2ut27tfjizgnrnm",
"contract/valory/agent_registry/0.1.0": "bafybeiargayav6yiztdnwzejoejstcx4idssch2h4f5arlgtzj3tgsgfmu",
"contract/valory/hash_checkpoint/0.1.0": "bafybeigv2bceirhy72yajxzibi4a5wrcfptfbkjbzzko6pqdq2f4dzr3xa",
"connection/valory/websocket_client/0.1.0": "bafybeiflmystocxaqblhpzqlcop2vkhsknpzjx2jomohomaxamwskeokzm",
"skill/valory/contract_subscription/0.1.0": "bafybeicyugrkx5glat4p4ezwf6i7oduh26eycfie6ftd4uxrknztzl3ik4",
"skill/valory/mech_abci/0.1.0": "bafybeihzip2inpeeygbqexpx4gcjwoiteqhytlmcu7funrfil4bhdgpb74",
"skill/valory/task_submission_abci/0.1.0": "bafybeiedymlwwaipot4za5bqp7wyu5izj2t3o5gwy3mzsbywtioapc75ni",
"skill/valory/task_execution/0.1.0": "bafybeihfusmz5vrtgsbveh3hvlgsrmrfyumcxulrnrcfmoraifwr3dzxhy",
"skill/valory/mech_abci/0.1.0": "bafybeiao3zhznxdey5b4azhyjyvr7dyczxcjqm3akbc3ma2r4edhd2gcdy",
"skill/valory/task_submission_abci/0.1.0": "bafybeicjy2saxrlg5ezeissbccb57mt3jfxmn3oli754hipm5mo4nra5ri",
"skill/valory/task_execution/0.1.0": "bafybeiapj6qhwzqrr7biebnwn2e2ugqsby2ml7z25fbhmdjnf3u4yp6bem",
"skill/valory/websocket_client/0.1.0": "bafybeidwntmkk4b2ixq5454ycbkknclqx7a6vpn7aqpm2nw3duszqrxvta",
"skill/valory/subscription_abci/0.1.0": "bafybeicsxdt3mv6idkn5gyaqljvgywgbo57zim6jlpip55fqqlr5rzhsxq",
"agent/valory/mech/0.1.0": "bafybeihotgdvdxsthgb3xb2bb4ac6dtjc5lnidzlpyt6bhzmmqtl255dsa",
"service/valory/mech/0.1.0": "bafybeidau7tk5hxagj4og4nflfm5prwdijegqdfxb6nmdqnimlp3f5rrze"
"skill/valory/subscription_abci/0.1.0": "bafybeihn6lbkxemurkiefbequprq5u5zzef77wleyjs7zzpf4aq3dtf4vm",
"agent/valory/mech/0.1.0": "bafybeifcb2o44phvd4o7de56qvfdkdrmxi4tq3jcp2dvf4pfkndwtuck6m",
"service/valory/mech/0.1.0": "bafybeihac7vg75uqizjzjqws7c4zgwi3y6lisgsztbmrwsbqu4ldydpb24"
},
"third_party": {
"protocol/valory/default/1.0.0": "bafybeifqcqy5hfbnd7fjv4mqdjrtujh2vx3p2xhe33y67zoxa6ph7wdpaq",
Expand All @@ -41,21 +42,21 @@
"protocol/valory/acn/1.1.0": "bafybeidluaoeakae3exseupaea4i3yvvk5vivyt227xshjlffywwxzcxqe",
"protocol/valory/ipfs/0.1.0": "bafybeiftxi2qhreewgsc5wevogi7yc5g6hbcbo4uiuaibauhv3nhfcdtvm",
"protocol/valory/tendermint/0.1.0": "bafybeig4mi3vmlv5zpbjbfuzcgida6j5f2nhrpedxicmrrfjweqc5r7cra",
"contract/valory/service_registry/0.1.0": "bafybeiby5x4wfdywlenmoudbykdxohpq2nifqxfep5niqgxrjyrekyahzy",
"contract/valory/gnosis_safe_proxy_factory/0.1.0": "bafybeie6ynnoavvk2fpbn426nlp32sxrj7pz5esgebtlezy4tmx5gjretm",
"contract/valory/gnosis_safe/0.1.0": "bafybeictjc7saviboxbsdcey3trvokrgo7uoh76mcrxecxhlvcrp47aqg4",
"contract/valory/service_registry/0.1.0": "bafybeicbxmbzt757lbmyh6762lrkcrp3oeum6dk3z7pvosixasifsk6xlm",
"contract/valory/gnosis_safe_proxy_factory/0.1.0": "bafybeib6podeifufgmawvicm3xyz3uaplbcrsptjzz4unpseh7qtcpar74",
"contract/valory/gnosis_safe/0.1.0": "bafybeibq77mgzhyb23blf2eqmia3kc6io5karedfzhntvpcebeqdzrgyqa",
"contract/valory/multisend/0.1.0": "bafybeig5byt5urg2d2bsecufxe5ql7f4mezg3mekfleeh32nmuusx66p4y",
"connection/valory/http_client/0.23.0": "bafybeih5vzo22p2umhqo52nzluaanxx7kejvvpcpdsrdymckkyvmsim6gm",
"connection/valory/abci/0.1.0": "bafybeifbnhe4f2bll3a5o3hqji3dqx4soov7hr266rdz5vunxgzo5hggbq",
"connection/valory/ipfs/0.1.0": "bafybeiflaxrnepfn4hcnq5pieuc7ki7d422y3iqb54lv4tpgs7oywnuhhq",
"connection/valory/abci/0.1.0": "bafybeiclexb6cnsog5yjz2qtvqyfnf7x5m7tpp56hblhk3pbocbvgjzhze",
"connection/valory/ipfs/0.1.0": "bafybeihndk6hohj3yncgrye5pw7b7w2kztj3avby5u5mfk2fpjh7hqphii",
"connection/valory/ledger/0.19.0": "bafybeic3ft7l7ca3qgnderm4xupsfmyoihgi27ukotnz7b5hdczla2enya",
"connection/valory/p2p_libp2p_client/0.1.0": "bafybeid3xg5k2ol5adflqloy75ibgljmol6xsvzvezebsg7oudxeeolz7e",
"connection/valory/http_server/0.22.0": "bafybeihpgu56ovmq4npazdbh6y6ru5i7zuv6wvdglpxavsckyih56smu7m",
"skill/valory/transaction_settlement_abci/0.1.0": "bafybeid57tozt5f3kgzmu22nbr3c3oy4p7bi2bu66rqsgnlylq6xgh2ixe",
"skill/valory/termination_abci/0.1.0": "bafybeie6h7j4hyhgj2wte64n3xyudxq4pgqcqjmslxi5tff4mb6vce2tay",
"skill/valory/abstract_round_abci/0.1.0": "bafybeigjrepaqpb3m7zunmt4hryos4vto4yyj3u6iyofdb2fotwho3bqvm",
"skill/valory/reset_pause_abci/0.1.0": "bafybeicm7onl72rfnn33pbvzwjpkl5gafeieyobfcnyresxz7kunjwmqea",
"skill/valory/registration_abci/0.1.0": "bafybeif3ln6eg53ebrfe6uicjew4uqp2ynyrcxkw5wi4jm3ixqv3ykte4a",
"skill/valory/abstract_abci/0.1.0": "bafybeihljirk3d4rgvmx2nmz3p2mp27iwh2o5euce5gccwjwrpawyjzuaq"
"skill/valory/transaction_settlement_abci/0.1.0": "bafybeibnsqrkzfm2sbjvtn7a7bsv7irikqv653nn7lfpc7mi43zijrhvom",
"skill/valory/termination_abci/0.1.0": "bafybeiejcthb2uvpjxgon5wxdzua3tm6ud66nrbrctzocwmbjuwwwx2rg4",
"skill/valory/abstract_round_abci/0.1.0": "bafybeieehcc2jqker6jmflxc4nnfcdc7k3s5zs2zbruwenb363dgt4xhai",
"skill/valory/reset_pause_abci/0.1.0": "bafybeiatxhn6r7wd35frh4h4i4twkpk5kkadbbnfgwnz6vampsspt5yy2i",
"skill/valory/registration_abci/0.1.0": "bafybeigpv56crn7hz6cqpicou7ulpz3by5uphcy6vm2rkqe5gkwuybcpe4",
"skill/valory/abstract_abci/0.1.0": "bafybeihat4giyc4bz6zopvahcj4iw53356pbtwfn7p4d5yflwly2qhahum"
}
}
Loading

0 comments on commit 73c0d88

Please sign in to comment.