diff --git a/src/deploy/model.py b/src/deploy/model.py
index b440cca..fa1ea9f 100644
--- a/src/deploy/model.py
+++ b/src/deploy/model.py
@@ -20,6 +20,46 @@ from ..deploy.utils import basic_types, generate_api_calling, download_file_from_google_drive, download_data, save_decoded_file, correct_bool_values, convert_bool_values, infer, dataframe_to_markdown, convert_image_to_base64, change_format, special_types, io_types, io_param_names
 from ..models.dialog_classifier import Dialog_Gaussian_classification
 from ..inference.param_count_acc import predict_parameters
+from sentence_transformers import SentenceTransformer, util
+from nltk.corpus import stopwords
+import nltk
+
+nltk.download('stopwords')
+stop_words = set(stopwords.words('english'))
+
+def remove_consecutive_duplicates(code: str) -> str:
+    lines = code.split('\n')
+    unique_lines = []
+    for i in range(len(lines)):
+        if i == 0 or lines[i] != lines[i-1]:
+            unique_lines.append(lines[i])
+    return '\n'.join(unique_lines)
+
+def highlight_keywords(user_query, api_descriptions, threshold=0.6):
+    # Load pre-trained Sentence-BERT model
+    model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
+    # Tokenize the query and filter out stop words
+    query_words = [word for word in user_query.split() if word.lower() not in stop_words]
+    # Compute embeddings for the query words
+    query_embeddings = model.encode(query_words, convert_to_tensor=True)
+    highlighted_descriptions = []
+    for description in api_descriptions:
+        words = description.split()
+        keywords = []
+        for word in words:
+            if word.lower() in stop_words:
+                keywords.append(word)
+                continue
+            word_embedding = model.encode(word, convert_to_tensor=True)
+            similarities = [util.pytorch_cos_sim(word_embedding, query_embedding).item() for query_embedding in query_embeddings]
+            max_similarity = max(similarities)
+            if max_similarity > threshold: # Threshold for highlighting
+                #keywords.append(f"**{word}**")
+                keywords.append(f'{word}')
+            else:
+                keywords.append(word)
+        highlighted_descriptions.append(" ".join(keywords))
+    return highlighted_descriptions
 
 def get_all_api_calls(tutorials):
     all_api_calls = {}
@@ -81,29 +121,24 @@ def color_text(text, color):
 def label_sentence(sentence, parameters_dict):
     import re
-
     colors = ['red'] # , 'purple', 'blue', 'green', 'orange'
     color_map = {}
     color_index = 0
-
     def get_color(term):
         nonlocal color_index
         if term not in color_map:
             color_map[term] = colors[color_index % len(colors)]
             color_index += 1
         return color_map[term]
-
     def replace_match(match):
         term = match.group(0)
         color = get_color(term)
         return f'{term}'
-
     for key, value in parameters_dict.items():
         pattern_key = re.compile(r'\b' + re.escape(key) + r'\b')
         pattern_value = re.compile(r'\b' + re.escape(str(value)) + r'\b')
         sentence = re.sub(pattern_key, replace_match, sentence)
         sentence = re.sub(pattern_value, replace_match, sentence)
-
     return sentence
 
 def remove_duplicates(lst):
@@ -156,7 +191,7 @@ def extract_last_error_sentence_from_list(log):
 basic_types.append('Any')
 
 class Model:
-    def __init__(self, logger, device, model_llm_type="gpt-3.5-turbo-0125"): # llama3, # # gpt-4-turbo
+    def __init__(self, logger, device, model_llm_type="gpt-4o-mini-2024-07-18"): # llama3, # # gpt-4-turbo # # gpt-3.5-turbo-0125
         # IO
         self.image_folder = "./tmp/images/"
         os.makedirs(self.image_folder, exist_ok=True)
@@ -166,6 +201,7 @@ def __init__(self, logger, device, model_llm_type="gpt-3.5-turbo-0125"): # llama
         self.LIB = "scanpy"
         with open(f'./data/standard_process/{self.LIB}/centroids.pkl', 'rb') as f:
             self.centroids = pickle.load(f)
+        self.debugging_mode=False
         self.execution_visualize = True
         self.keywords = ["dca", "magic", "phate", "palantir", "trimap", "sam", "phenograph", "wishbone", "sandbag", "cyclone", "spring_project", "cellbrowser"] # "harmony",
         self.success_history_API = []
@@ -606,7 +642,7 @@ def run_pipeline(self, user_input, lib, top_k=3, files=[],conversation_started=T
                 pass
                 #self.callback_func('log', "\n - " + "\n - ".join(steps_list), "Multi step Task Planning") # "LLM return valid steps list, start executing...\n" +
                 self.add_query(steps_list)
-                self.callback_func('log', "Ongoing subtask and remaining subtasks: \n → "+ '\n - '.join(self.user_query_list), "Task Planning")
+                self.callback_func('log', "Current and remaining tasks: \n → "+ '\n - '.join(self.user_query_list), "Task Planning")
                 sub_task = self.get_query()
                 if not sub_task:
                     raise ValueError("sub_task is empty!")
@@ -616,21 +652,25 @@ def run_pipeline(self, user_input, lib, top_k=3, files=[],conversation_started=T
                     sub_task = user_input
             else:
                 sub_task = user_input
-        # we correct the subtask description before retrieving API
-        if len([i['code'] for i in self.executor.execute_code if i['success']=='True'])>0: # for non-first subtasks
+        # we correct the task description before retrieving API
+        if len([i['code'] for i in self.executor.execute_code if i['success']=='True'])>0: # for non-first tasks
             retrieved_apis = self.retriever.retrieving(sub_task, top_k=23)
-            retrieved_apis = [i for i in retrieved_apis if not ((any(keyword in i for keyword in self.keywords)) and ('external' in i))]
+            # remove external API, as it is deprecated in scanpy
+            if self.LIB=='scanpy':
+                retrieved_apis = [i for i in retrieved_apis if not ((any(keyword in i for keyword in self.keywords)) and ('external' in i))]
+                retrieved_apis = [i for i in retrieved_apis if i not in ['scanpy.pl.dpt_groups_pseudotime', 'scanpy.pl.dpt_timeseries']]# these two are deprecated according to https://github.com/scverse/scanpy/issues/3086
+            #retrieved_apis = [i for i in retrieved_apis if not self.validate_class_attr_api(i)]
             retrieved_apis = retrieved_apis[:3]
-            prompt = self.prompt_factory.create_prompt("modify_subtask_correction", self.initial_goal_description, sub_task,
+            prompt = self.prompt_factory.create_prompt("modify_task_correction", self.initial_goal_description, sub_task,
                 '\n'.join([i['code'] for i in self.executor.execute_code if i['success']=='True']),
                 json.dumps({str(key): str(value) for key, value in self.executor.variables.items() if value['type'] not in ['function', 'module', 'NoneType']}),
                 "\n".join(['def '+generate_function_signature(api, self.API_composite[api]['Parameters'])+':\n"""'+self.API_composite[api]['Docstring'] + '"""' for api in retrieved_apis])
-            )
-            self.logger.info('modified sub_task prompt: {}', prompt)
+            )
+            self.logger.info('modified task prompt: {}', prompt)
             sub_task, _ = LLM_response(prompt, self.model_llm_type, history=[], kwargs={})
-            self.logger.info('modified sub_task: {}', sub_task)
-            #self.callback_func('log', 'we modify the subtask as '+sub_task, 'Modify subtask description')
-            self.callback_func('log', sub_task, 'Polished subtask')
+            self.logger.info('modified task: {}', sub_task)
+            #self.callback_func('log', 'we modify the task as '+sub_task, 'Modify task description')
+            self.callback_func('log', sub_task, 'Polished task')
         else:
             pass
         # get sub_task after dialog prediction
@@ -638,10 +678,24 @@ def run_pipeline(self, user_input, lib, top_k=3, files=[],conversation_started=T
             self.logger.info('we filter those API with IO parameters!')
             #self.logger.info('self.user_query: {}', self.user_query)
             retrieved_names = self.retriever.retrieving(self.user_query, top_k=self.args_top_k+20)
+            if self.LIB=='scanpy':
+                retrieved_names = [i for i in retrieved_names if i not in ['scanpy.pl.dpt_groups_pseudotime', 'scanpy.pl.dpt_timeseries']]# these two are deprecated according to https://github.com/scverse/scanpy/issues/3086
+            # get scores dictionary
+            query_embedding = self.retriever.embedder.encode(self.user_query, convert_to_tensor=True)
+            hits = util.semantic_search(query_embedding, self.retriever.corpus_embeddings, top_k=self.args_top_k+20, score_function=util.cos_sim)
+            api_score_mapping = {}
+            for hit in hits[0]:
+                api = self.retriever.corpus2tool[hit['corpus_id']]
+                score = hit['score']
+                api_score_mapping[api] = score
             # Here are external API which need to be removed, from https://github.com/scverse/scanpy/issues/2717
-            retrieved_names = [i for i in retrieved_names if not ((any(keyword in i for keyword in self.keywords)) and ('external' in i))]
+            if self.LIB=='scanpy':
+                retrieved_names = [i for i in retrieved_names if not ((any(keyword in i for keyword in self.keywords)) and ('external' in i))]
+            #retrieved_names = [i for i in retrieved_names if not self.validate_class_attr_api(i)]
             # Filter out the executed API
             retrieved_names = [i for i in retrieved_names if i not in self.success_history_API]
+            if self.LIB=='scanpy':
+                retrieved_names = [i for i in retrieved_names if i not in ['scanpy.pl.dpt_groups_pseudotime', 'scanpy.pl.dpt_timeseries']]# these two are deprecated according to https://github.com/scverse/scanpy/issues/3086
             #self.logger.info('retrieved_names: {}', retrieved_names)
             # filter out APIs
             #self.logger.info('first_task_start: {}, self.loaded_files: {}', self.first_task_start, self.loaded_files)
@@ -663,6 +717,13 @@ def run_pipeline(self, user_input, lib, top_k=3, files=[],conversation_started=T
                 #retrieved_names = [api_name for api_name in retrieved_names if all((not any(special_type in str(param['type']) for special_type in special_types)) for param_name, param in self.API_composite[api_name]['Parameters'].items())]
                 self.logger.info('there exist files or we have already load some dataset, retrieved_names are: {}', retrieved_names)
             retrieved_names = retrieved_names[:self.args_top_k]
+            # send information card to frontend
+            api_descriptions = [self.all_apis_json[api].replace('.', '. ') for api in retrieved_names]
+            highlighted_descriptions = highlight_keywords(self.user_query, api_descriptions)
+            highlighted_descriptions = [api+' : '+desc + ' Similarity score: ' + str(api_score_mapping[api]) for api,desc in zip(retrieved_names, highlighted_descriptions)]
+            del api_score_mapping
+            prepared_desc = "Here are the retrieved API candidates with their similarity scores and keywords highlighted as evidence:\n - " + '\n - '.join(highlighted_descriptions)
+            self.callback_func('log', prepared_desc, 'API information retrieval')
             self.first_task_start = False
             self.logger.info("retrieved names: {}!", retrieved_names)
             # start retrieving names
@@ -797,10 +858,10 @@ def extract_api_calls(text, library):
                     idx_api+=1
                 self.filtered_api = [self.predicted_api_name] + self.filtered_api
                 next_str += "Enter [-1]: No appropriate API, input inquiry manually\n"
-                #next_str += "Enter [-2]: Skip to the next subtask"
+                #next_str += "Enter [-2]: Skip to the next task"
                 # for ambiguous API, we think that it might be executed more than once as ambiguous API sometimes work together
                 # user can exit by entering -1
-                # so we add it back to the subtask list to execute it again
+                # so we add it back to the task list to execute it again
                 #self.add_query([self.user_query], mode='pre') # 240625: deprecate
                 self.update_user_state("run_pipeline_after_ambiguous")
                 self.initialize_tool()
@@ -884,7 +945,7 @@ def run_pipeline_after_ambiguous(self,user_input):
        """if user_index==-2:
            sub_task = self.get_query()
            if self.user_query_list:
-                self.callback_func('log', "Ongoing subtask and remaining subtasks: \n → "+ '\n - '.join(self.user_query_list), "Task Planning")
+                self.callback_func('log', "Current and remaining tasks: \n → "+ '\n - '.join(self.user_query_list), "Task Planning")
            sub_task = self.get_query()
            self.user_query = sub_task
            self.update_user_state("run_pipeline")
@@ -962,7 +1023,11 @@ def run_pipeline_after_fixing_API_selection(self,user_input):
             self.callback_func('log', "Could you confirm whether this API should be called?\nEnter [y]: Go on please.\nEnter [n]: Restart another turn", "User Confirmation")
         self.update_user_state("run_pipeline_after_doublechecking_API_selection")
         self.save_state_enviro()
-
+    def validate_class_attr_api(self, api):
+        if '.'.join(api.split('.')[:-1]) in self.API_composite:
+            if self.API_composite['.'.join(api.split('.')[:-1])]['api_type']=='class':
+                return True
+        return False
     def run_pipeline_after_doublechecking_API_selection(self, user_input):
         self.initialize_tool()
         user_input = str(user_input).strip().lower()
@@ -972,24 +1037,29 @@ def run_pipeline_after_doublechecking_API_selection(self, user_input):
                self.callback_func('log', "We will start another round. Could you re-enter your inquiry?", "Start another round")
                self.retry_modify_count = 0
                self.save_state_enviro()
-            else: # if there is task planning, we just update this subtask
+            else: # if there is task planning, we just update this task
                self.retry_modify_count += 1
-                self.callback_func('log', "As this subtask is not exactly what you want, we polish the subtask and re-run the code generation pipeline", "Continue to the same subtask")
+                self.callback_func('log', "As this task is not exactly what you want, we polish the task and re-run the code generation pipeline", "Continue to the same task")
                #sub_task = self.get_query()
                # polish and modify the sub_task
-                retrieved_apis = self.retriever.retrieving(user_input, top_k=23)
-                retrieved_apis = [i for i in retrieved_apis if (not any(keyword in i for keyword in self.keywords)) and ('external' in i)]
+                """retrieved_apis = self.retriever.retrieving(user_input, top_k=23)
+                if self.LIB=='scanpy':
+                    retrieved_apis = [i for i in retrieved_apis if (not any(keyword in i for keyword in self.keywords)) and ('external' in i)]
+                    retrieved_apis = [i for i in retrieved_apis if i not in ['scanpy.pl.dpt_groups_pseudotime', 'scanpy.pl.dpt_timeseries']] # these two are deprecated according to https://github.com/scverse/scanpy/issues/3086
+                # remove class attribute API
+                #retrieved_apis = [i for i in retrieved_apis if not self.validate_class_attr_api(i)]
+                # filter out the executed API
                retrieved_apis = [i for i in retrieved_apis if i not in self.success_history_API]
                retrieved_apis = retrieved_apis[:3]
-                prompt = self.prompt_factory.create_prompt("modify_subtask_correction",
+                prompt = self.prompt_factory.create_prompt("modify_task_correction",
                    self.initial_goal_description, self.user_query,
                    '\n'.join([i['code'] for i in self.executor.execute_code if i['success']=='True']),
                    json.dumps({str(key): str(value) for key, value in self.executor.variables.items() if value['type'] not in ['function', 'module', 'NoneType']}),
-                    "\n".join(['def '+generate_function_signature(api, self.API_composite[api]['Parameters'])+':\n"""'+self.API_composite[api]['Docstring'] + '"""' for api in retrieved_apis])
-                )
-                self.user_query, _ = LLM_response(prompt, self.model_llm_type, history=[], kwargs={})
-                #self.logger.info('Polished subtask: {}', self.user_query)
-                self.callback_func('log', self.user_query, 'Polished subtask')
+                    "\n".join(['def '+generate_function_signature(api, self.API_composite[api]['Parameters'])+':\n"""'+self.API_composite[api]["Docstring"] + '"""' for api in retrieved_apis])
+                )"""
+                #self.user_query, _ = LLM_response(prompt, self.model_llm_type, history=[], kwargs={})
+                #self.logger.info('Polished task: {}', self.user_query)
+                #self.callback_func('log', self.user_query, 'Polished task')
                self.update_user_state("run_pipeline")
                self.save_state_enviro()
                self.run_pipeline(self.user_query, self.LIB, top_k=3, files=[],conversation_started=False,session_id=self.session_id)
@@ -1002,8 +1072,8 @@ def run_pipeline_after_doublechecking_API_selection(self, user_input):
                # user_states didn't change
                return
        self.logger.info('self.predicted_api_name: {}', self.predicted_api_name)
-        if len([i['code'] for i in self.executor.execute_code if i['success']=='True'])>0: # for non-first subtasks
-            prompt = self.prompt_factory.create_prompt("modify_subtask_parameters", self.initial_goal_description, self.user_query,
+        if len([i['code'] for i in self.executor.execute_code if i['success']=='True'])>0: # for non-first tasks
+            prompt = self.prompt_factory.create_prompt("modify_task_parameters", self.initial_goal_description, self.user_query,
            '\n'.join([i['code'] for i in self.executor.execute_code if i['success']=='True']),
            json.dumps({str(key): str(value) for key, value in self.executor.variables.items() if value['type'] not in ['function', 'module', 'NoneType']}),
            "\n"+ 'def '+generate_function_signature(self.predicted_api_name, self.API_composite[self.predicted_api_name]['Parameters'])+':\n"""'+self.API_composite[self.predicted_api_name]['Docstring'] + '"""')
@@ -1105,7 +1175,7 @@ def run_pipeline_after_doublechecking_API_selection(self, user_input):
            predicted_parameters = {key: value for key, value in predicted_parameters.items() if value not in [None, "None", "null"] or key in required_param_list}
        self.logger.info('after filtering, predicted_parameters: {}', predicted_parameters)
        colored_sentence = label_sentence(self.user_query, predicted_parameters)
-        self.callback_func('log', colored_sentence, 'Highlight parameters value in polished subtask description')
+        self.callback_func('log', 'Here is the task description with keywords highlighted as evidence: \n' + colored_sentence, 'Polished task description')
        #self.logger.info('colored_sentence: {}', colored_sentence)
        # generate api_calling
        self.logger.info('self.API_composite[self.predicted_api_name]: {}', self.API_composite[self.predicted_api_name])
@@ -1388,7 +1458,10 @@ def run_pipeline_after_entering_params(self, user_input):
        if 'encoded=False' in self.execution_code:
            self.execution_code += '\n'+self.execution_code.replace('encoded=False', 'encoded=True').replace('result_1', 'result_2')
        self.logger.info('==>api_params_list: {}, execution_code: {}', api_params_list, self.execution_code)
-        self.callback_func('code', self.execution_code, "Executed code")
+        #if not self.debugging_mode: # avoid repeating showing the code
+        if True:
+            self.callback_func('code', self.execution_code, "Executed code")
+            self.debugging_mode=False
        # LLM response
        api_data_single = self.API_composite[self.predicted_api_name]
        api_docstring = 'def '+generate_function_signature(self.predicted_api_name, self.API_composite[self.predicted_api_name]['Parameters'])+':\n"""'+self.API_composite[self.predicted_api_name]['Docstring'] + '"""'
@@ -1471,8 +1544,10 @@ def run_pipeline_after_doublechecking_execution_code(self, user_input):
        self.logger.info('content: {}', content)
        # show the new variable
        if self.last_execute_code['code'] and self.last_execute_code['success']=='True':
-            if self.retry_execution_count>0:
-                self.callback_func('code', self.execution_code, "Executed code", enhance_indexxxx=False)
+            if self.retry_execution_count>0 and (self.retry_execution_countno Class type API')
diff --git a/src/prompt/promptgenerator.py b/src/prompt/promptgenerator.py
index dad4f72..bf40db6 100644
--- a/src/prompt/promptgenerator.py
+++ b/src/prompt/promptgenerator.py
@@ -144,11 +144,10 @@ def build_prompt(self, LIB, goal_description, data_list=[]):
         prompt = f"""
 Create step-by-step task plan with subtasks to achieve the goal.
 The tone should vary among queries: polite, straightforward, casual.
-Each subtask has 15-20 words, be clear and concise for the scope of one single API's functionality from PyPI library {LIB}. Omit API name from subtask.
+Each subtask has 10-20 words, be clear and concise for the scope of one single API's functionality from PyPI library {LIB}. Omit API name from subtask.
 Split the subtask into two or more subtasks if it contains more than one action.
 Using `Filtering ...` together with `Normalize ...` instead of `Filtering and Normalizing.`
-Include only essential subtasks, with a range of 4 to 7 tasks. When arranging tasks, consider the logical order and dependencies.
-Integrate visualization subtasks between each step. The last two subtasks MUST be visualization subtasks.
+Integrate visualization tasks after each analytical step. Ensure the plan has around 5 tasks (maximum 7, minimum 3), with the last one exclusively for visualization. Focus on essential actions only.
 Include Data description only in data loading subtask.
 Ensure Goal-Oriented Task Structuring, place the goal description at the beginning of each subtask.
 Only respond in JSON format strictly enclosed in double quotes, adhering to the Response Format.
@@ -159,12 +158,12 @@ def build_prompt(self, LIB, goal_description, data_list=[]):
 Response:
 {{"plan": [
 "step 1: Load pre-processed Imaging Mass Cytometry data.",
-"step 2: Visualize cell type clusters in spatial context to identify distributions of apoptotic and tumor cells among others.",
-"step 3: Calculate co-occurrence of cell types across spatial dimensions, focusing on interactions between basal CK tumor cells and T cells.",
-"step 4: Visualize co-occurrence results to understand cell type interactions and their spatial patterns.",
-"step 5: Compute neighborhood enrichment to assess spatial proximity and interactions of cell clusters.",
-"step 6: Visualize neighborhood enrichment results to highlight enriched or depleted interactions among cell clusters.",
-"step 7: Visualize the distribution and interaction of all identified cell types",
+"step 2: Could you show cell type clusters in spatial context?",
+"step 3: Please calculate co-occurrence of cell types across spatial dimensions.",
+"step 4: I want you to visualize co-occurrence results.",
+"step 5: Compute neighborhood enrichment.",
+"step 6: Can you plot neighborhood enrichment results?",
+"step 7: How to display the distribution and interaction of all identified cell types?"
 ]}}
 ---
 Now finish the goal with the following information:
@@ -188,10 +187,11 @@ def build_prompt(self, api_docstring, namespace_variables, error_code, possible_
             possible_solution_info = f"\nPossible solution from similar issues from Github Issue Discussion:\n{possible_solution}"
         else:
             possible_solution_info = ""
-        if api_examples and api_examples != "{}":
-            api_examples_info = f"\nAPI Usage examples: {api_examples}."
-        else:
-            api_examples_info = ""
+        # remove api_examples as it is already included in the api docstring
+        #if api_examples and api_examples != "{}":
+        #    api_examples_info = f"\nAPI Usage examples: {api_examples}."
+        #else:
+        #    api_examples_info = ""
         prompt = f"""
 Task: Analyze and correct the newest failed attempt Python script based on provided traceback information.
 Correct the latest failed attempt without repeating previous mistakes.
@@ -218,10 +218,10 @@ def build_prompt(self, api_docstring, namespace_variables, error_code, possible_
 Success execution History: {success_history_code}
 Existing Namespace variables: {namespace_variables}
 Current Goal: {goal_description}
-History Failed Attempts with their tracebacks: {error_code}
+History Failed Attempts with their tracebacks:\n {error_code}
 {possible_solution_info}{api_examples_info}
 API Docstring: {api_docstring}.
-Response Format: {{"analysis": "Explain how to correct the bug.", "code": "Corrected code"}} +Response Format: {{"analysis": "Explain how to correct the bug in 2 sentences, including the reason of the bug, and how to correct.", "code": "Contain the Corrected failed attempt code, exclude code from `success execution history`, exclude `save_plot_with_timestamp()`"}} """ # You only need to keep required parameters from previous trial codes, only keep minimum optional parameters necessary for task. Remove optional parameters from error code which cause the problem. Please ensure that required parameters are passed in their proper positional order, as keyword arguments should only be used for optional parameters. You only need to include the task related correct code in your response, do not repeat other API from the success execution history in your response. For parameters starting with 'result_', use only those that exist in the namespace. Do not generate inexist variables. return prompt @@ -245,8 +245,8 @@ def build_prompt(self, main_goal, current_subtask, execution_history, namespace Example: main goal: Use Scanpy to finish trajectory inference using the PAGA method. -original subtask description: Please perform trajectory inference in this step. -refined subtask description: Please perform trajectory inference using the PAGA method in this step. +original subtask description: Please perform trajectory inference. +refined subtask description: Please perform trajectory inference using the PAGA method. Example: Main goal: Use Scanpy to conduct gene annotation on dataset 3k PBMCs. @@ -364,9 +364,9 @@ def create_prompt(self, prompt_type, *args): return ExecutorPromptBuilder().build_prompt(*args) #elif prompt_type == 'subtask_code': # return SubtaskCodePromptBuilder().build_prompt(*args) - elif prompt_type == 'modify_subtask_parameters': + elif prompt_type == 'modify_task_parameters': return ModifySubtaskPromptBuilder().build_prompt(*args) - elif prompt_type == 'modify_subtask_correction': + elif prompt_type == 'modify_task_correction': return ModifySubtaskCorrectionPromptBuilder().build_prompt(*args) else: raise ValueError("Unknown prompt type")