Commit 129fc90 (1 parent: 5ed39ce)
Showing 6 changed files with 155 additions and 83 deletions.
@@ -1,91 +1,33 @@
-from tenacity import (
-    retry,
-    stop_after_attempt,
-    wait_random_exponential,
-)
-import openai
-import time
 import os
 import re
-import random

 from utils import *
+from llm_settings.openai_models import *
+from llm_settings.gemini_models import *
+from llm_settings.deepinfra_models import *

-openai.api_key = api_key
-
-@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
-def chat(
-    model,  # gpt-4, gpt-4-0314, gpt-4-32k, gpt-4-32k-0314, gpt-3.5-turbo, gpt-3.5-turbo-0301
-    messages,  # [{"role": "system"/"user"/"assistant", "content": "Hello!", "name": "example"}]
-    temperature=temperature,  # [0, 2]: Lower values -> more focused and deterministic; Higher values -> more random.
-    n=1,  # Chat completion choices to generate for each input message.
-    max_tokens=1024,  # The maximum number of tokens to generate in the chat completion.
-    delay=delay_time  # Seconds to sleep after each request.
-):
-    time.sleep(delay)
-
-    response = openai.ChatCompletion.create(
-        model=model,
-        messages=messages,
-        temperature=temperature,
-        n=n,
-        max_tokens=max_tokens
-    )
-
-    if n == 1:
-        return response['choices'][0]['message']['content']
-    else:
-        return [i['message']['content'] for i in response['choices']]
-
-
-@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
-def completion(
-    model,  # text-davinci-003, text-davinci-002, text-curie-001, text-babbage-001, text-ada-001
-    prompt,  # The prompt(s) to generate completions for, encoded as a string, array of strings, array of tokens, or array of token arrays.
-    temperature=temperature,  # [0, 2]: Lower values -> more focused and deterministic; Higher values -> more random.
-    n=1,  # Completions to generate for each prompt.
-    max_tokens=1024,  # The maximum number of tokens to generate in the chat completion.
-    delay=delay_time  # Seconds to sleep after each request.
-):
-    time.sleep(delay)
-
-    response = openai.Completion.create(
-        model=model,
-        prompt=prompt,
-        temperature=temperature,
-        n=n,
-        max_tokens=max_tokens
-    )
-
-    if n == 1:
-        return response['choices'][0]['text']
-    else:
-        response = response['choices']
-        response.sort(key=lambda x: x['index'])
-        return [i['text'] for i in response['choices']]

 def print_prompt(inputs, response):
     os.makedirs("records", exist_ok=True)
     with open(f"records/records.txt", 'a') as f:
         f.write(f"{inputs}\n----\n")
         f.write(f"{response}\n====\n")
     return

-def gpt_request(model, inputs):
-    json_format = r'({.*})'
-
-    if model == 'text-davinci-003':
-        response = completion(model, inputs).strip()
-        print_prompt(inputs, response)
-        match = re.search(json_format, response, re.DOTALL)
-        return str(match)
-    elif model in ['gpt-3.5-turbo', 'gpt-4']:
-        response = chat(model, inputs).strip()
-        print_prompt(inputs, response)
-        match = re.search(json_format, response, re.DOTALL)
-        if match:
-            return match.group(1).strip()
-        else:
-            ""
+
+def llm_request(model, inputs):
+    if model.startswith("gpt"):
+        response = gpt_chat(model, inputs).strip()
+
+    elif model.startswith("gemini"):
+        response = gemini_chat(model, inputs).strip()
+
+    elif model.startswith("meta-llama"):
+        response = deepinfra_chat(model, inputs).strip()

     else:
         raise ValueError("The model is not supported or does not exist.")

     print_prompt(inputs, response)

     return response
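The three provider modules added below all post-process the raw model output with extract_json_str, which is imported via `from global_functions import *` but is not shown in this diff. A minimal sketch of such a helper, mirroring the r'({.*})' regex already used in the old gpt_request, could look like the following; the exact implementation is an assumption, not part of the commit.

import re


def extract_json_str(response):
    # Hypothetical helper: pull the first {...} block out of a model reply,
    # mirroring the json_format regex used above; return "" if nothing matches.
    match = re.search(r'({.*})', response, re.DOTALL)
    return match.group(1).strip() if match else ""


# Example dispatch through the new entry point (made-up prompt):
# reply = llm_request("gpt-4", [{"role": "user", "content": "Reply with a JSON object."}])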
@@ -0,0 +1,35 @@
from tenacity import (
    retry,
    stop_after_attempt,
    wait_random_exponential,
)
import time
from openai import OpenAI

from utils import *
from global_functions import *

openai = OpenAI(
    api_key=infradeep_api_key,
    base_url="https://api.deepinfra.com/v1/openai",
)

@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
def deepinfra_chat(
    model,  # meta-llama/Meta-Llama-3.1-70B-Instruct, mistralai/Mixtral-8x7B-Instruct-v0.1, Qwen/Qwen2-72B-Instruct
    prompt,  # The prompt(s) to generate completions for, encoded as a string, array of strings, array of tokens, or array of token arrays.
    temperature=temperature,  # [0, 2]: Lower values -> more focused and deterministic; Higher values -> more random.
    n=1,  # Completions to generate for each prompt.
    max_tokens=1024,  # The maximum number of tokens to generate in the chat completion.
    delay=delay_time  # Seconds to sleep after each request.
):
    time.sleep(delay)

    response = openai.chat.completions.create(
        model=model,
        messages=prompt,
        temperature=temperature,
        stream=False,
    )

    return extract_json_str(response.choices[0].message.content)
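For orientation, a call into this module might look like the sketch below. The messages follow the standard OpenAI chat schema that DeepInfra's OpenAI-compatible endpoint accepts; the prompt text is a made-up example, and temperature and delay_time are assumed to come from utils.

# Hypothetical usage of deepinfra_chat (not part of the commit).
messages = [
    {"role": "system", "content": "Answer strictly as a JSON object."},
    {"role": "user", "content": 'Give the capital of France as {"capital": ...}.'},
]
json_str = deepinfra_chat("meta-llama/Meta-Llama-3.1-70B-Instruct", messages)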
@@ -0,0 +1,35 @@
from tenacity import (
    retry,
    stop_after_attempt,
    wait_random_exponential,
)
import time
import google.generativeai as genai

from utils import *
from global_functions import *

genai.configure(api_key=google_api_key)

@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
def gemini_chat(
    model,  # gemini-1.0-pro, gemini-1.0-pro-001, gemini-1.0-pro-latest, gemini-1.0-pro-vision-latest, gemini-pro, gemini-pro-vision
    messages,  # [{'role': 'user', 'parts': "In one sentence, explain how a computer works to a young child."}, {'role': 'model', 'parts': "A computer is like a very smart machine that can understand and follow our instructions, help us with our work, and even play games with us!"}]
    temperature=temperature,  # [0, 2]: Lower values -> more focused and deterministic; Higher values -> more random.
    n=1,  # Chat response choices to generate for each input message.
    max_tokens=1024,  # The maximum number of tokens to generate in the chat completion.
    delay=delay_time  # Seconds to sleep after each request.
):
    time.sleep(delay)
    model = genai.GenerativeModel(model)
    response = model.generate_content(
        messages,
        generation_config=genai.types.GenerationConfig(
            # Only one candidate for now.
            candidate_count=n,
            # stop_sequences=['x'],
            max_output_tokens=max_tokens,
            temperature=temperature)
    )

    return extract_json_str(response.text)
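Note that Gemini's chat format differs from OpenAI's: each message uses 'parts' instead of 'content', and the assistant role is named 'model'. A hypothetical call, with a made-up prompt and google_api_key, temperature, and delay_time assumed to come from utils, would look like this:

# Hypothetical usage of gemini_chat (not part of the commit).
messages = [
    {"role": "user", "parts": 'Give the capital of France as {"capital": ...}.'},
]
json_str = gemini_chat("gemini-1.0-pro", messages)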
@@ -0,0 +1,34 @@
from tenacity import (
    retry,
    stop_after_attempt,
    wait_random_exponential,
)
import time
from openai import OpenAI

from utils import *
from global_functions import *

openai = OpenAI(api_key=openai_api_key)

@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
def gpt_chat(
    model,  # gpt-4, gpt-4-0314, gpt-4-32k, gpt-4-32k-0314, gpt-3.5-turbo, gpt-3.5-turbo-0301
    messages,  # [{"role": "system"/"user"/"assistant", "content": "Hello!", "name": "example"}]
    temperature=temperature,  # [0, 2]: Lower values -> more focused and deterministic; Higher values -> more random.
    n=1,  # Chat completion choices to generate for each input message.
    max_tokens=1024,  # The maximum number of tokens to generate in the chat completion.
    delay=delay_time  # Seconds to sleep after each request.
):
    time.sleep(delay)

    response = openai.chat.completions.create(
        model=model,
        messages=messages,
        temperature=temperature,
        n=n,
        max_tokens=max_tokens
    )

    return extract_json_str(response.choices[0].message.content)
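All of the modules above star-import their configuration from utils, which is not part of this diff. Judging by the names they reference, it presumably defines something along the lines of the sketch below; every value here is a placeholder assumption.

import os

# Hypothetical utils.py contents inferred from the names used above; not in this commit.
openai_api_key = os.environ.get("OPENAI_API_KEY", "")        # used by gpt_chat
google_api_key = os.environ.get("GOOGLE_API_KEY", "")        # used by gemini_chat
infradeep_api_key = os.environ.get("DEEPINFRA_API_KEY", "")  # used by deepinfra_chat
temperature = 0.0   # default sampling temperature shared by all wrappers
delay_time = 1      # seconds to sleep before each request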