diff --git a/src/deploy/inference_dialog_server.py b/src/deploy/inference_dialog_server.py
index 45aea8a..feb4550 100644
--- a/src/deploy/inference_dialog_server.py
+++ b/src/deploy/inference_dialog_server.py
@@ -21,6 +21,7 @@ from inference.utils import predict_by_similarity
 from tqdm import tqdm
 from deploy.utils import change_format
+from gpt.utils import get_all_api_json, correct_pred
 import logging
 from datetime import datetime
@@ -263,6 +264,7 @@ def __init__(self):
             self.centroids = pickle.load(f)
         print('==>chitchat vectorizer loaded!')
         self.retrieve_query_mode = "similar"
+        self.all_apis, self.all_apis_json = get_all_api_json(f"./data/standard_process/{self.LIB}/API_init.json")
         print("Server ready")
     def load_bert_model(self, load_mode='unfinetuned_bert'):
         if load_mode=='unfinetuned_bert':
@@ -313,12 +315,7 @@ def reset_lib(self, lib_name):
         self.executor.execute_api_call(f"np.seterr(under='ignore')", "import")
         self.executor.execute_api_call(f"import warnings", "import")
         self.executor.execute_api_call(f"warnings.filterwarnings('ignore')", "import")
-        #end_of_docstring_summary = re.compile(r'[{}\n]+'.format(re.escape(punctuation)))
-        #all_apis = {x: end_of_docstring_summary.split(self.API_composite[x]['Docstring'])[0].strip() for x in self.API_composite} #used before 231219
-        # 231219 changed API description
-        all_apis = {x: self.API_composite[x]['description'] for x in self.API_composite}
-        all_apis = list(all_apis.items())
-        self.description_json = {i[0]:i[1] for i in all_apis}
+        self.all_apis, self.all_apis_json = get_all_api_json(f"./data/standard_process/{lib_name}/API_init.json")
         print('==>Successfully loading model!')
         print('loading model cost: %s s', str(time.time()-t1))
         reset_result = "Success"
@@ -604,7 +601,6 @@ def load_state(self, session_id):
         self.__dict__.update(state)
         print("State loaded from %s", file_name)
     def run_pipeline(self, user_input, lib, top_k=3, files=[],conversation_started=True,session_id=""):
-        self.indexxxx = 2
         #if session_id != self.session_id:
         if True:
@@ -663,42 +659,58 @@ def run_pipeline(self, user_input, lib, top_k=3, files=[],conversation_started=T
         retrieved_names = self.retrieve_names(user_input)
         print("retrieved_names: %s", retrieved_names)
         # produce prompt
-        description_jsons = {}
-        try:
-            for i in retrieved_names:
-                description_jsons[i] = self.description_json[i]
-        except Exception as e:
-            [callback.on_agent_action(block_id="log-" + str(self.indexxxx), task=f"The retrieved names is not in API_composite, please double check",task_title="API json Error ",) for callback in self.callbacks]
-            self.indexxxx += 1
         if self.retrieve_query_mode=='similar':
             instruction_shot_example = self.retriever.retrieve_similar_queries(user_input, shot_k=5)
         else:
             sampled_shuffled = random.sample(self.retriever.shuffled_data, 5)
             instruction_shot_example = "".join(["\nInstruction: " + ex['query'] + "\nFunction: " + ex['gold'] for ex in sampled_shuffled])
+            similar_queries = ""
+            shot_k=5 # 5 examples
+            idx = 0
+            for iii in sampled_shuffled:
+                instruction = iii['query']
+                tmp_retrieved_api_list = self.retriever.retrieving(instruction, top_k=top_k)
+                # ensure the order won't affect performance
+                tmp_retrieved_api_list = random.sample(tmp_retrieved_api_list, len(tmp_retrieved_api_list))
+                # ensure the example is correct
+                if iii['gold'] in tmp_retrieved_api_list:
+                    if idx
             print('==>Ask GPT: %s\n==>GPT response: %s', api_predict_prompt, response)
             # hack for if GPT answers this or that
-            response = response.split(',')[0].split("(")[0].split(' or ')[0]
+            """response = response.split(',')[0].split("(")[0].split(' or ')[0]
             response = response.replace('{','').replace('}','').replace('"','').replace("'",'')
-            response = response.split(':')[0]# for robustness, sometimes gpt will return api:description
+            response = response.split(':')[0]# for robustness, sometimes gpt will return api:description"""
+            response = correct_pred(response, self.LIB)
+            response = response.strip()
+            self.all_apis_json[response]
             self.predicted_api_name = response
             success = True
             break
         except Exception as e:
+            print('error during api prediction: ', e)
             pass
         #return
         if not success:
             [callback.on_tool_start() for callback in self.callbacks]
             [callback.on_tool_end() for callback in self.callbacks]
-            [callback.on_agent_action(block_id="log-" + str(self.indexxxx),task=f"GPT can not return valid API name prediction, please redesign your prompt.",task_title="GPT predict Error ",) for callback in self.callbacks]
+            [callback.on_agent_action(block_id="log-" + str(self.indexxxx),task=f"GPT can not return valid API name prediction, please redesign your prompt.",task_title="GPT predict Error",) for callback in self.callbacks]
             self.indexxxx += 1
             return
         print(f'length of ambiguous api list: {len(self.ambiguous_api)}')
diff --git a/src/gpt/utils.py b/src/gpt/utils.py
index a03d47d..b571327 100644
--- a/src/gpt/utils.py
+++ b/src/gpt/utils.py
@@ -168,6 +168,7 @@ def correct_pred(pred, lib_name):
         ans = pred[pred.find(lib_name):]
     else:
         ans = pred
+    ans = ans.strip()
     return ans
 
 def generate_custom_val_indices(api_ranges):
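For context, this refactor replaces the inline `description_json` construction in `reset_lib()` with the shared `get_all_api_json` helper imported from `gpt/utils.py`. The helper's body is not part of this diff; the sketch below is only an assumed shape, inferred from how `self.all_apis_json[response]` is used as a name-to-description lookup in `run_pipeline()` and from the removed code that read each entry's `description` field. The `API_init.json` layout shown (a dict keyed by API name with a `description` field) is likewise an assumption.

```python
import json

def get_all_api_json(api_init_path):
    """Hypothetical sketch: load API_init.json and return
    (all_apis, all_apis_json) -- a prompt-ready listing of every API
    and a name -> description mapping used to validate GPT predictions."""
    with open(api_init_path, 'r') as f:
        api_init = json.load(f)
    # assume each entry carries a short 'description' field, mirroring the
    # removed inline code that read API_composite[x]['description']
    all_apis_json = {name: entry.get('description', '') for name, entry in api_init.items()}
    all_apis = "\n".join(f"{name}: {desc}" for name, desc in all_apis_json.items())
    return all_apis, all_apis_json
```

Under this reading, the bare lookup `self.all_apis_json[response]` after `correct_pred(response, self.LIB)` raises a `KeyError` for any name GPT hallucinates, which is what trips the surrounding `try`/`except`, keeps `success` false, and triggers the "GPT predict Error" callback instead of accepting an invalid prediction.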