Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update prediction-request-claude for Claude 3 #205

Merged
merged 7 commits into from
Apr 5, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion packages/packages.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"dev": {
"custom/valory/native_transfer_request/0.1.0": "bafybeid22vi5xtavqhq5ir2kq6nakckm3tl72wcgftsq35ak3cboyn6eea",
"custom/valory/prediction_request_claude/0.1.0": "bafybeigpuw4z2xq6vxdsi27cwqomdvbbz364sgzzzkofjxowmd6shwrcoa",
"custom/valory/prediction_request_claude/0.1.0": "bafybeigzihxu3natj6twnnw4codojvjs33t3tqqjibtqltnka5ilpie7qi",
"custom/valory/openai_request/0.1.0": "bafybeigew6ukd53n3z352wmr5xu6or3id7nsqn7vb47bxs4pg4qtkmbdiu",
"custom/valory/prediction_request_embedding/0.1.0": "bafybeifdhbbxmwf4q6pjhanubgrzhy7hetupyoekjyxvnubcccqxlkaqu4",
"custom/valory/resolve_market/0.1.0": "bafybeiaag2e7rsdr3bwg6mlmfyom4vctsdapohco7z45pxhzjymepz3rya",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ license: Apache-2.0
aea_version: '>=1.0.0, <2.0.0'
fingerprint:
__init__.py: bafybeibbn67pnrrm4qm3n3kbelvbs3v7fjlrjniywmw2vbizarippidtvi
prediction_request_claude.py: bafybeievmbi5ck63om47nimumdqzflhcpqkqymmtauomoouszh2gs23coi
prediction_request_claude.py: bafybeiaamsdbgmcpjadljsfvopcrj67wmyc46sbph3zcdifi5mlqr4746u
fingerprint_ignore_patterns: []
entry_point: prediction_request_claude.py
callable: run
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,24 +26,38 @@

import anthropic
import requests
from anthropic import Anthropic, HUMAN_PROMPT, AI_PROMPT
from readability import Document
from tiktoken import encoding_for_model
from markdownify import markdownify as md
from googleapiclient.discovery import build

# How many search-result URLs to fetch and extract text from.
NUM_URLS_EXTRACT = 5
# Default word budget when summarising extracted page text.
DEFAULT_NUM_WORDS = 300
# Defaults for Anthropic messages.create; overridable via run(**kwargs).
# (The leftover, unclosed DEFAULT_OPENAI_SETTINGS dict from the pre-Claude-3
# version has been removed — it made this span invalid Python.)
DEFAULT_CLAUDE_SETTINGS = {
    "max_tokens": 1000,
    "temperature": 0,
}
# Maximum context window (in tokens) per supported Claude model.
# The previous values used `200_0000` (= 2,000,000) — an order-of-magnitude
# typo. Per Anthropic's model docs, claude-2 (2.0) has a 100K context;
# claude-2.1 and all Claude 3 models have 200K.
MAX_TOKENS = {
    'claude-2': 100_000,
    'claude-2.1': 200_000,
    'claude-3-haiku-20240307': 200_000,
    'claude-3-sonnet-20240229': 200_000,
    'claude-3-opus-20240229': 200_000,
}
# Tool names this module supports; run() raises ValueError for anything else.
ALLOWED_TOOLS = [
    "claude-prediction-offline",
    "claude-prediction-online",
]
# Claude model identifiers accepted via the optional "model" kwarg to run().
ALLOWED_MODELS = [
    "claude-2",
    "claude-2.1",
    "claude-3-haiku-20240307",
    "claude-3-sonnet-20240229",
    "claude-3-opus-20240229",
]
# Default engine per tool, used when the caller passes no "model" kwarg.
# The old claude-2 mappings left duplicate dict keys in this literal (a diff
# artifact); only the Claude 3 Sonnet defaults are kept.
TOOL_TO_ENGINE = {
    "claude-prediction-offline": "claude-3-sonnet-20240229",
    "claude-prediction-online": "claude-3-sonnet-20240229",
}

PREDICTION_PROMPT = """
Expand Down Expand Up @@ -124,13 +138,16 @@
* This is correct:"{{\n \"queries\": [\"term1\", \"term2\"]}}"
"""

# System prompt passed as `system=` to the Anthropic messages.create calls.
SYSTEM_PROMPT = """You are a world class algorithm for generating structured output from a given input."""

# Assistant prefill and stop sequence used by the old completions-style
# prompts to coax a fenced ```json block out of the model.
# NOTE(review): the new messages.create calls in this diff no longer pass
# ASSISTANT_TEXT or stop_sequences — confirm these constants are still needed.
ASSISTANT_TEXT = "```json"
STOP_SEQUENCES = ["```"]


def count_tokens(text: str, model: str) -> int:
    """Count the number of tokens in *text* for the given model.

    tiktoken ships no tokenizer mapping for Claude model names, so
    ``encoding_for_model`` raises ``KeyError`` for every entry in
    ALLOWED_MODELS; fall back to the ``cl100k_base`` encoding as an
    approximation in that case. The dead pre-Claude-3 line
    (``anthropic.Anthropic().count_tokens``) left in by the diff is removed.
    """
    try:
        enc = encoding_for_model(model)
    except KeyError:
        # Local import: the file only imports encoding_for_model at top level.
        from tiktoken import get_encoding
        enc = get_encoding("cl100k_base")
    return len(enc.encode(text))

def search_google(query: str, api_key: str, engine: str, num: int = 3) -> List[str]:
service = build("customsearch", "v1", developerKey=api_key)
Expand Down Expand Up @@ -228,27 +245,32 @@ def extract_texts(urls: List[str], num_words: int = 300) -> List[str]:


def fetch_additional_information(
client: anthropic.Anthropic,
prompt: str,
engine: str,
anthropic: Anthropic,
google_api_key: Optional[str],
google_engine: Optional[str],
num_urls: Optional[int],
num_words: Optional[int],
counter_callback: Optional[Callable] = None,
temperature: Optional[float] = DEFAULT_CLAUDE_SETTINGS["temperature"],
max_tokens: Optional[int] = DEFAULT_CLAUDE_SETTINGS["max_tokens"],
source_links: Optional[List[str]] = None,
) -> str:
"""Fetch additional information."""
url_query_prompt = URL_QUERY_PROMPT.format(user_prompt=prompt)
url_query_prompt = f"{HUMAN_PROMPT}{url_query_prompt}{AI_PROMPT}{ASSISTANT_TEXT}"
completion = anthropic.completions.create(
messages = [
{"role": "user", "content": url_query_prompt},
]
response = client.messages.create(
model=engine,
max_tokens_to_sample=300,
prompt=url_query_prompt,
stop_sequences=STOP_SEQUENCES,
messages=messages,
system=SYSTEM_PROMPT,
temperature=temperature,
max_tokens=max_tokens,
)
try:
json_data = json.loads(completion.completion)
json_data = json.loads(response.content[0].text)
except json.JSONDecodeError:
json_data = {}

Expand Down Expand Up @@ -288,25 +310,33 @@ def fetch_additional_information(
def run(**kwargs) -> Tuple[str, Optional[str], Optional[Dict[str, Any]], Any]:
"""Run the task"""
tool = kwargs["tool"]
model = kwargs.get("model", TOOL_TO_ENGINE[tool])
prompt = kwargs["prompt"]
anthropic = Anthropic(api_key=kwargs["api_keys"]["anthropic"])
max_tokens = kwargs.get("max_tokens", DEFAULT_CLAUDE_SETTINGS["max_tokens"])
temperature = kwargs.get("temperature", DEFAULT_CLAUDE_SETTINGS["temperature"])
num_urls = kwargs.get("num_urls", NUM_URLS_EXTRACT)
num_words = kwargs.get("num_words", DEFAULT_NUM_WORDS)
counter_callback = kwargs.get("counter_callback", None)
api_keys = kwargs.get("api_keys", {})
google_api_key = api_keys.get("google_api_key", None)
google_engine_id = api_keys.get("google_engine_id", None)
client = anthropic.Anthropic(api_key=api_keys["anthropic"])

# Make sure the model is supported
if model not in ALLOWED_MODELS:
raise ValueError(f"Model {model} not supported.")

if tool not in ALLOWED_TOOLS:
raise ValueError(f"Tool {tool} is not supported.")

engine = TOOL_TO_ENGINE[tool]
engine = kwargs.get("model", TOOL_TO_ENGINE[tool])
print(f"ENGINE: {engine}")

if tool == "claude-prediction-online":
additional_information, counter_callback = fetch_additional_information(
client=client,
prompt=prompt,
engine=engine,
anthropic=anthropic,
google_api_key=google_api_key,
google_engine=google_engine_id,
num_urls=num_urls,
Expand All @@ -316,23 +346,34 @@ def run(**kwargs) -> Tuple[str, Optional[str], Optional[Dict[str, Any]], Any]:
)
else:
additional_information = ""

# Make the prediction
prediction_prompt = PREDICTION_PROMPT.format(
user_prompt=prompt, additional_information=additional_information
)
prediction_prompt = f"{HUMAN_PROMPT}{prediction_prompt}{AI_PROMPT}{ASSISTANT_TEXT}"

completion = anthropic.completions.create(
messages = [
{
"role": "user",
"content": prediction_prompt,
},
]

response_prediction = client.messages.create(
model=engine,
max_tokens_to_sample=300,
prompt=prediction_prompt,
stop_sequences=STOP_SEQUENCES,
messages=messages,
system=SYSTEM_PROMPT,
temperature=temperature,
max_tokens=max_tokens,
)
if counter_callback is not None:
prediction = response_prediction.content[0].text

if counter_callback:
counter_callback(
input_tokens=response_prediction.usage.input_tokens,
output_tokens=response_prediction.usage.output_tokens,
model=engine,
input_prompt=prediction_prompt,
output_prompt=completion.completion,
token_counter=count_tokens,
)

return completion.completion, prediction_prompt, None, counter_callback
return prediction, prediction_prompt, None, counter_callback
1 change: 0 additions & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ deps =
eth-utils==2.2.0
eth-abi==4.0.0
pycryptodome==3.18.0
openapi-core==0.15.0
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why?

openapi-spec-validator<0.5.0,>=0.4.0
anthropic==0.3.11
langchain==0.0.303
Expand Down
Loading