chat.py (forked from huchenxucs/ChatDB)
import time

import openai
from dotenv import load_dotenv

from config import Config
import token_counter
from chatgpt import create_chat_completion
from logger import logger

load_dotenv()  # load OPENAI_API_KEY and friends from a local .env, if present

cfg = Config()
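# Config, token_counter, chatgpt, and logger are sibling modules from the
# upstream ChatDB repo and are not shown in this file. Only cfg.fast_llm_model
# and cfg.set_debug_mode(...) are relied on from Config below.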

def create_chat_message(role, content):
    """
    Create a chat message with the given role and content.

    Args:
        role (str): The role of the message sender, e.g., "system", "user", or "assistant".
        content (str): The content of the message.

    Returns:
        dict: A dictionary containing the role and content of the message.
    """
    return {"role": role, "content": content}

def generate_context(prompt, relevant_memory, full_message_history, model):
    """Build the initial context (currently just the system prompt) and count its tokens."""
    current_context = [
        create_chat_message("system", prompt),
        # create_chat_message(
        #     "system", f"The current time and date is {time.strftime('%c')}"),
        # create_chat_message(
        #     "system", f"This reminds you of these events from your past:\n{relevant_memory}\n\n"),
    ]

    # History messages are added later, newest first, until the token limit
    # is reached; start from the most recent message in the history.
    next_message_to_add_index = len(full_message_history) - 1
    insertion_index = len(current_context)
    # Count the tokens used by the context so far
    current_tokens_used = token_counter.count_message_tokens(current_context, model)
    return next_message_to_add_index, current_tokens_used, insertion_index, current_context
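# token_counter is not included in this file. As a rough sketch (an assumption
# based on OpenAI's published tiktoken cookbook recipe, not necessarily this
# project's actual implementation), a compatible counter could look like:
#
#   import tiktoken
#
#   def count_message_tokens(messages, model="gpt-3.5-turbo"):
#       encoding = tiktoken.encoding_for_model(model)
#       num_tokens = 0
#       for message in messages:
#           num_tokens += 4  # per-message framing overhead for gpt-3.5-turbo
#           for value in message.values():
#               num_tokens += len(encoding.encode(value))
#       return num_tokens + 2  # every reply is primed with <|start|>assistant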

# TODO: Change debug from hardcode to argument
def chat_with_ai(
        prompt,
        user_input,
        full_message_history,
        permanent_memory,
        token_limit):
    """
    Interact with the OpenAI API, sending the prompt, user input, message history, and permanent memory.

    Args:
        prompt (str): The prompt explaining the rules to the AI.
        user_input (str): The input from the user.
        full_message_history (list): The list of all messages sent between the user and the AI.
        permanent_memory (Obj): The memory object containing the permanent memory
            (currently unused; the memory lookup below is commented out).
        token_limit (int): The maximum number of tokens allowed in the API call.

    Returns:
        str: The AI's response.
    """
    while True:
        try:
            model = cfg.fast_llm_model  # TODO: Change model from hardcode to argument
            # Reserve 1000 tokens for the response
            send_token_limit = token_limit - 1000

            # relevant_memory = '' if len(full_message_history) == 0 else permanent_memory.get_relevant(str(full_message_history[-9:]), 10)
            # logger.debug(f'Memory Stats: {permanent_memory.get_stats()}')
            relevant_memory = None

            next_message_to_add_index, current_tokens_used, insertion_index, current_context = generate_context(
                prompt, relevant_memory, full_message_history, model)

            # while current_tokens_used > 2500:
            #     # remove memories until we are under 2500 tokens
            #     relevant_memory = relevant_memory[1:]
            #     next_message_to_add_index, current_tokens_used, insertion_index, current_context = generate_context(
            #         prompt, relevant_memory, full_message_history, model)

            # Account for the user input, which is appended to the context later
            current_tokens_used += token_counter.count_message_tokens(
                [create_chat_message("user", user_input)], model)

            # Walk the history backwards (newest first) and add messages until
            # adding the next one would exceed the send limit
            while next_message_to_add_index >= 0:
                # print(f"CURRENT TOKENS USED: {current_tokens_used}")
                message_to_add = full_message_history[next_message_to_add_index]
                tokens_to_add = token_counter.count_message_tokens([message_to_add], model)
                if current_tokens_used + tokens_to_add > send_token_limit:
                    break

                # Insert the message just after the system prompt(s), so the
                # context stays in chronological order
                current_context.insert(insertion_index, message_to_add)

                # Count the currently used tokens
                current_tokens_used += tokens_to_add

                # Move to the next most recent message in the full message history
                next_message_to_add_index -= 1

            # Append user input; its token count is already included above
            current_context.append(create_chat_message("user", user_input))

            # Calculate remaining tokens
            tokens_remaining = token_limit - current_tokens_used
            # assert tokens_remaining >= 0, "Tokens remaining is negative. This should never happen, please submit a bug report at https://www.github.com/Torantulino/Auto-GPT"

            # Debug print the current context
            # logger.debug(f"Token limit: {token_limit}")
            # logger.debug(f"Send Token Count: {current_tokens_used}")
            # logger.debug(f"Tokens remaining for response: {tokens_remaining}")
            # logger.debug("------------ CONTEXT SENT TO AI ---------------")
            # for message in current_context:
            #     # Skip printing the prompt
            #     # if message["role"] == "system" and message["content"] == prompt:
            #     #     continue
            #     logger.debug(f"{message['role'].capitalize()}: {message['content']}")
            #     logger.debug("")
            # logger.debug("----------- END OF CONTEXT ----------------")

            # TODO: use a model defined elsewhere, so that model can contain
            # temperature and other settings we care about
            # print(current_context)
            assistant_reply = create_chat_completion(
                model=model,
                messages=current_context,
                max_tokens=tokens_remaining,
            )

            # Update full message history
            full_message_history.append(create_chat_message("user", user_input))
            full_message_history.append(create_chat_message("assistant", assistant_reply))
            # logger.debug(f"{full_message_history[-1]['role'].capitalize()}: {full_message_history[-1]['content']}")
            # logger.debug("----------- END OF RESPONSE ----------------")

            return assistant_reply
        except openai.error.RateLimitError:
            # TODO: When we switch to langchain, this is built in
            print("Error: API Rate Limit Reached. Waiting 10 seconds...")
            time.sleep(10)
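# create_chat_completion is imported from chatgpt.py, which is not shown here.
# A minimal sketch of a compatible wrapper (an assumption; it targets the
# pre-1.0 openai SDK that openai.error.RateLimitError above implies) might be:
#
#   def create_chat_completion(model, messages, max_tokens=None, temperature=0):
#       response = openai.ChatCompletion.create(
#           model=model,
#           messages=messages,
#           max_tokens=max_tokens,
#           temperature=temperature,
#       )
#       return response["choices"][0]["message"]["content"]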

if __name__ == '__main__':
    cfg.set_debug_mode(False)
    full_msg_history = []
    while True:
        user_inp = input("> ")
        reply = chat_with_ai("You are ChatDB.", user_inp, full_msg_history, None, 1100)
        print(reply)