Skip to content

Commit

Permalink
clean code
Browse files (browse the repository at this point in the history)
  • Loading branch information
DoraDong-2023 committed Mar 15, 2024
1 parent 302d78f commit df26d83
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 19 deletions.
50 changes: 31 additions & 19 deletions src/deploy/inference_dialog_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from inference.utils import predict_by_similarity
from tqdm import tqdm
from deploy.utils import change_format
from gpt.utils import get_all_api_json, correct_pred

import logging
from datetime import datetime
Expand Down Expand Up @@ -263,6 +264,7 @@ def __init__(self):
self.centroids = pickle.load(f)
print('==>chitchat vectorizer loaded!')
self.retrieve_query_mode = "similar"
self.all_apis, self.all_apis_json = get_all_api_json(f"./data/standard_process/{self.LIB}/API_init.json")
print("Server ready")
def load_bert_model(self, load_mode='unfinetuned_bert'):
if load_mode=='unfinetuned_bert':
Expand Down Expand Up @@ -313,12 +315,7 @@ def reset_lib(self, lib_name):
self.executor.execute_api_call(f"np.seterr(under='ignore')", "import")
self.executor.execute_api_call(f"import warnings", "import")
self.executor.execute_api_call(f"warnings.filterwarnings('ignore')", "import")
#end_of_docstring_summary = re.compile(r'[{}\n]+'.format(re.escape(punctuation)))
#all_apis = {x: end_of_docstring_summary.split(self.API_composite[x]['Docstring'])[0].strip() for x in self.API_composite} #used before 231219
# 231219 changed API description
all_apis = {x: self.API_composite[x]['description'] for x in self.API_composite}
all_apis = list(all_apis.items())
self.description_json = {i[0]:i[1] for i in all_apis}
self.all_apis, self.all_apis_json = get_all_api_json(f"./data/standard_process/{lib_name}/API_init.json")
print('==>Successfully loading model!')
print('loading model cost: %s s', str(time.time()-t1))
reset_result = "Success"
Expand Down Expand Up @@ -604,7 +601,6 @@ def load_state(self, session_id):
self.__dict__.update(state)
print("State loaded from %s", file_name)
def run_pipeline(self, user_input, lib, top_k=3, files=[],conversation_started=True,session_id=""):

self.indexxxx = 2
#if session_id != self.session_id:
if True:
Expand Down Expand Up @@ -663,42 +659,58 @@ def run_pipeline(self, user_input, lib, top_k=3, files=[],conversation_started=T
retrieved_names = self.retrieve_names(user_input)
print("retrieved_names: %s", retrieved_names)
# produce prompt
description_jsons = {}
try:
for i in retrieved_names:
description_jsons[i] = self.description_json[i]
except Exception as e:
[callback.on_agent_action(block_id="log-" + str(self.indexxxx), task=f"The retrieved names is not in API_composite, please double check",task_title="API json Error ",) for callback in self.callbacks]
self.indexxxx += 1
if self.retrieve_query_mode=='similar':
instruction_shot_example = self.retriever.retrieve_similar_queries(user_input, shot_k=5)
else:
sampled_shuffled = random.sample(self.retriever.shuffled_data, 5)
instruction_shot_example = "".join(["\nInstruction: " + ex['query'] + "\nFunction: " + ex['gold'] for ex in sampled_shuffled])
similar_queries = ""
shot_k=5 # 5 examples
idx = 0
for iii in sampled_shuffled:
instruction = iii['query']
tmp_retrieved_api_list = self.retriever.retrieving(instruction, top_k=top_k)
# ensure the order won't affect performance
tmp_retrieved_api_list = random.sample(tmp_retrieved_api_list, len(tmp_retrieved_api_list))
# ensure the example is correct
if iii['gold'] in tmp_retrieved_api_list:
if idx<shot_k:
idx+=1
# only retain shot_k number of sampled_shuffled
tmp_str = "Instruction: " + instruction + "\nFunction: [" + iii['gold'] + "]"
new_function_candidates = [f"{i}:{api}, description: "+self.all_apis_json[api].replace('\n',' ') for i, api in enumerate(tmp_retrieved_api_list)]
similar_queries += "function candidates:\n" + "\n".join(new_function_candidates) + '\n' + tmp_str + "\n---\n"
instruction_shot_example = similar_queries
# 240315: substitute prompt
from gpt.utils import get_retrieved_prompt, get_nonretrieved_prompt
api_predict_init_prompt = get_retrieved_prompt()
api_predict_prompt = api_predict_init_prompt.format(query=user_input, retrieved_apis=json.dumps(description_jsons), similar_queries=instruction_shot_example)
retrieved_apis_prepare = ""
for idx, api in enumerate(retrieved_names):
retrieved_apis_prepare+=f"{idx}:" + api+", description: "+self.all_apis_json[api].replace('\n',' ')+"\n"
api_predict_prompt = api_predict_init_prompt.format(query=user_input, retrieved_apis=retrieved_apis_prepare, similar_queries=instruction_shot_example)
success = False
for attempt in range(3):
try:
response, _ = LLM_response(self.llm, self.tokenizer, api_predict_prompt, history=[], kwargs={}) # llm
print(f'==>Ask GPT: %s\n==>GPT response: %s', api_predict_prompt, response)
# hack for if GPT answers this or that
response = response.split(',')[0].split("(")[0].split(' or ')[0]
"""response = response.split(',')[0].split("(")[0].split(' or ')[0]
response = response.replace('{','').replace('}','').replace('"','').replace("'",'')
response = response.split(':')[0]# for robustness, sometimes gpt will return api:description
self.description_json[response]
response = response.split(':')[0]# for robustness, sometimes gpt will return api:description"""
response = correct_pred(response, self.LIB)
response = response.strip()
self.all_apis_json[response]
self.predicted_api_name = response
success = True
break
except Exception as e:
print('error during api prediction: ', e)
pass
#return
if not success:
[callback.on_tool_start() for callback in self.callbacks]
[callback.on_tool_end() for callback in self.callbacks]
[callback.on_agent_action(block_id="log-" + str(self.indexxxx),task=f"GPT can not return valid API name prediction, please redesign your prompt.",task_title="GPT predict Error ",) for callback in self.callbacks]
[callback.on_agent_action(block_id="log-" + str(self.indexxxx),task=f"GPT can not return valid API name prediction, please redesign your prompt.",task_title="GPT predict Error",) for callback in self.callbacks]
self.indexxxx += 1
return
print(f'length of ambiguous api list: {len(self.ambiguous_api)}')
Expand Down
1 change: 1 addition & 0 deletions src/gpt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ def correct_pred(pred, lib_name):
ans = pred[pred.find(lib_name):]
else:
ans = pred
ans = ans.strip()
return ans

def generate_custom_val_indices(api_ranges):
Expand Down

0 comments on commit df26d83

Please sign in to comment.