
Commit

include API_base when retrieving.
DoraDong-2023 committed Dec 28, 2023
1 parent 4173948 commit 4ab6043
Showing 8 changed files with 63 additions and 20 deletions.
10 changes: 8 additions & 2 deletions README.md
@@ -117,15 +117,21 @@ We provide a script for downloading models and datas from Google Drive for scanp
sh download_data_model.sh
```

Organize the downloaded files at `src/data` or `src/hugging_models` as follows:
Organize the downloaded files at `src/data` or `src/hugging_models` as follows (the `base` entries are required):
```
data
├── conversations
├── others-data
└── standard_process
├── base
│   ├── API_composite.json
    │   └── ...
├── scanpy
│   ├── API_composite.json
│   └── ...
├── {LIB}
│   ├── API_composite.json
├── └── ...
│   └── ...
└── ...
hugging_models
Binary file modified chatbot_ui_biomania/.DS_Store
Binary file not shown.
Binary file modified chatbot_ui_biomania/components/Chat/.DS_Store
Binary file not shown.
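After downloading, a quick way to confirm that the files match the tree in the README change above is to check for the `base` entry explicitly. A minimal sketch, assuming the README's paths and that it is run from `src/`:

```python
# Minimal sketch (assumed paths): confirm the shared `base` entry and a
# library-specific entry exist before starting the server.
import os

required = [
    "data/standard_process/base/API_composite.json",
    "data/standard_process/scanpy/API_composite.json",
]
for path in required:
    print(("ok     " if os.path.exists(path) else "MISSING"), path)
```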
4 changes: 1 addition & 3 deletions src/configs/Base_cheatsheet.json
@@ -8,7 +8,7 @@
"pandas.DataFrame.to_csv",
"pandas.DataFrame",
"pandas.concat",
"pandas.groupby"
"pandas.DataFrame.groupby"
],
"matplotlib": [
"matplotlib.pyplot.plot",
@@ -40,7 +40,6 @@
"sklearn.decomposition.NMF.fit"
],
"scipy": [
"scipy.sparse_matrix",
"scipy.sparse.csr_matrix"
],
"torch": [
@@ -54,7 +53,6 @@
"statsmodels.formula.api.ols"
],
"pynndescent": [
"pynndescent.NNDescent.fit",
"pynndescent.NNDescent.query"
],
"anndata": [
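The cheatsheet entries above are dotted attribute paths, so invalid ones such as `pandas.groupby` or `scipy.sparse_matrix` cannot be resolved when the APIs are collected, which is presumably why they are corrected or dropped here. A minimal sketch of that resolution; the `resolve_api` helper is hypothetical and only for illustration:

```python
# Minimal sketch: resolve a dotted cheatsheet entry to the underlying object.
# Invalid entries raise AttributeError, which is why they were removed above.
import importlib

def resolve_api(dotted_path):  # hypothetical helper, not part of the repository
    module_name, *attrs = dotted_path.split(".")
    obj = importlib.import_module(module_name)
    for attr in attrs:
        obj = getattr(obj, attr)
    return obj

resolve_api("pandas.DataFrame.groupby")  # resolves
# resolve_api("pandas.groupby")          # would raise AttributeError
```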
8 changes: 6 additions & 2 deletions src/dataloader/get_API_init_from_sourcecode.py
@@ -518,7 +518,7 @@ def filter_optional_parameters(api_data):
def generate_api_callings(results, basic_types=['str', 'int', 'float', 'bool', 'list', 'dict', 'tuple', 'set', 'any', 'List', 'Dict']):
updated_results = {}
for api_name, api_info in results.items():
if api_info["api_type"] in ['function', 'method', 'class', 'functools.partial']:
if api_info["api_type"]: # in ['function', 'method', 'class', 'functools.partial']
# Update the optional_value key for each parameter
for param_name, param_details in api_info["Parameters"].items():
param_type = param_details.get('type')
@@ -684,7 +684,8 @@ def main_get_API_init(lib_name,lib_alias,analysis_path,api_html_path=None,api_tx
def main_get_API_basic(analysis_path,cheatsheet):
# STEP1: get API from cheatsheet, save to basic_API.json
#output_file = os.path.join(analysis_path,"API_base.json")
output_file = os.path.join('data','standard_process',"API_base.json")
os.makedirs(os.path.join('data','standard_process',"base"), exist_ok=True)
output_file = os.path.join('data','standard_process',"base", "API_init.json")
result_list = []
print('Start getting docparam from source')
for api in cheatsheet:
@@ -696,6 +697,9 @@ def main_get_API_basic(analysis_path,cheatsheet):
results = {r: results[r] for r in results if r in cheatsheet[api]}
result_list.append(results)
outputs = merge_jsons(result_list)
for api_name, api_info in outputs.items():
api_info['relevant APIs'] = []
api_info['type'] = 'singleAPI'
with open(output_file, 'w') as file:
file.write(json.dumps(outputs, indent=4))

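For orientation, the loop added above tags every merged base entry before it is written to `data/standard_process/base/API_init.json`. A minimal sketch of that post-processing with a toy entry; the field names come from the diff, the example entry itself is assumed:

```python
# Minimal sketch of the added post-processing: each base API is marked as a
# standalone API with no composite dependencies, then written out as JSON.
import json
import os

outputs = {
    "pandas.DataFrame.to_csv": {"api_type": "method", "Parameters": {}},  # assumed toy entry
}
for api_name, api_info in outputs.items():
    api_info["relevant APIs"] = []
    api_info["type"] = "singleAPI"

os.makedirs(os.path.join("data", "standard_process", "base"), exist_ok=True)
with open(os.path.join("data", "standard_process", "base", "API_init.json"), "w") as f:
    f.write(json.dumps(outputs, indent=4))
```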
6 changes: 5 additions & 1 deletion src/deploy/inference_dialog_server.py
@@ -296,7 +296,7 @@ def reset_lib(self, lib_name):
new_path = '/'.join(parts)
retrieval_model_path = new_path
logging.info('load retrieval_model_path in: %s', retrieval_model_path)
self.retriever = ToolRetriever(LIB=lib_name,corpus_tsv_path=f"./data/standard_process/{lib_name}/retriever_train_data/corpus.tsv", model_path=retrieval_model_path)
self.retriever = ToolRetriever(LIB=lib_name,corpus_tsv_path=f"./data/standard_process/{lib_name}/retriever_train_data/corpus.tsv", model_path=retrieval_model_path, add_base=True)
logging.info('loaded retriever!')
#self.executor.execute_api_call(f"from data.standard_process.{self.LIB}.Composite_API import *", "import")
self.executor.execute_api_call(f"import {lib_name}", "import")
@@ -554,9 +554,13 @@ def clear_globals_with_prefix(self, prefix):
def load_llm_model(self):
self.llm, self.tokenizer = LLM_model()
def load_data(self, API_file):
# fix 231227, add API_base.json
with open(API_file, 'r') as json_file:
data = json.load(json_file)
with open("./data/standard_process/base/API_composite.json", 'r') as json_file:
base_data = json.load(json_file)
self.API_composite = data
self.API_composite.update(base_data)
def generate_file_loading_code(self, file_path, file_type):
# Define the loading code for each file type
file_loading_templates = {
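The updated `load_data` merges the base composite file into the library-specific one with a plain dict update, so on any duplicated key the base entry wins. A minimal sketch of that behaviour with toy data; the file contents are assumed:

```python
# Minimal sketch of the merge in load_data: base entries are applied after the
# library file, so both sets of APIs end up in self.API_composite.
lib_data = {"scanpy.pp.filter_cells": {"type": "singleAPI"}}    # assumed toy content
base_data = {"pandas.DataFrame.to_csv": {"type": "singleAPI"}}  # assumed toy content

API_composite = dict(lib_data)
API_composite.update(base_data)  # mirrors self.API_composite.update(base_data)
print(sorted(API_composite))     # library and base APIs side by side
```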
4 changes: 4 additions & 0 deletions src/download_data_model.sh
@@ -8,14 +8,18 @@ gdown --id 1PVsrsKnOdv0VrZw9sjkjsmQTm018MOG5 -O hugging_models/retriever_model_f
gdown --id 1X-3xnTba9Mxb8SZ8oIEdxhVrGMdFVP-i -O data/your_second_file
gdown --id 1NRKLDijLENR1vyQHFNT_vk5lLY4lL1CL -O data/your_third_file
gdown --id 1wgYY9CD1hPfqlUUFFqDIeh9l4nEu124t -O data/your_fourth_file
gdown --id 15vNIPK8ut8Hudbjwg_G1w0zJjs2gmmRb -O data/your_fifth_file

cd data
unzip your_second_file
unzip your_third_file
unzip your_fourth_file
unzip your_fifth_file
rm -rf your_second_file
rm -rf your_third_file
rm -rf your_fourth_file
rm -rf your_fifth_file
mv base ./standard_process/base
cd ..

cd hugging_models/retriever_model_finetuned
51 changes: 39 additions & 12 deletions src/inference/retriever_finetune_inference.py
@@ -12,31 +12,58 @@
import seaborn as sns

class ToolRetriever:
def __init__(self, LIB, corpus_tsv_path = "", model_path=""):
def __init__(self, LIB, corpus_tsv_path = "", model_path="", base_corpus_tsv_path="./data/standard_process/base/retriever_train_data/corpus.tsv",add_base=False):
#self.model_path = os.path.join(model_path,f"{LIB}","assigned")
self.model_path = model_path
self.build_retrieval_corpus(corpus_tsv_path)
self.shuffled_data = self.build_shuffle_data(LIB)
self.corpus_tsv_path = corpus_tsv_path
self.base_corpus_tsv_path = base_corpus_tsv_path
#self.build_retrieval_corpus(corpus_tsv_path)
self.build_and_merge_corpus(add_base=add_base)
self.shuffled_data = self.build_shuffle_data(LIB, add_base=add_base)
self.shuffled_queries = [item['query'] for item in self.shuffled_data]
self.shuffled_query_embeddings = self.embedder.encode(self.shuffled_queries, convert_to_tensor=True)
def build_shuffle_data(self,LIB):

def build_shuffle_data(self,LIB, add_base=True):
# add API_base, fix 231227
import json
import random
with open(f'./data/standard_process/{LIB}/API_inquiry_annotate.json', 'r') as f:
data = json.load(f)
def process_data(path, files_ids):
with open(f'{path}/API_inquiry_annotate.json', 'r') as f:
data = json.load(f)
return [dict(query=row['query'], gold=row['api_name']) for row in data if row['query_id'] not in files_ids['val'] and row['query_id'] not in files_ids['test']]
with open(f"./data/standard_process/{LIB}/API_instruction_testval_query_ids.json", 'r') as file:
files_ids = json.load(file)
shuffled = [dict(query=row['query'], gold=row['api_name']) for row in [i for i in data if i['query_id'] not in files_ids['val'] and i['query_id'] not in files_ids['test']]]
random.Random(0).shuffle(shuffled)
return shuffled
lib_files_ids = json.load(file)
lib_data = process_data(f'./data/standard_process/{LIB}', lib_files_ids)
with open(f"./data/standard_process/base/API_instruction_testval_query_ids.json", 'r') as base_file_ids:
base_files_ids = json.load(base_file_ids)
base_data = process_data('./data/standard_process/base', base_files_ids)
if add_base:
lib_data = lib_data + base_data
random.Random(0).shuffle(lib_data)
return lib_data
def build_retrieval_corpus(self, corpus_tsv_path):
self.corpus_tsv_path = corpus_tsv_path
documents_df = pd.read_csv(self.corpus_tsv_path, sep='\t')
corpus, self.corpus2tool = process_retrieval_document_query_version(documents_df)
corpus_ids = list(corpus.keys())
corpus = [corpus[cid] for cid in corpus_ids]
self.corpus = corpus
self.embedder = SentenceTransformer(self.model_path, device=device)
self.corpus_embeddings = self.embedder.encode(self.corpus, convert_to_tensor=True)
def build_and_merge_corpus(self, add_base=True):
# based on build_retrieval_corpus, add API_base.json, fix 231227
original_corpus_df = pd.read_csv(self.corpus_tsv_path, sep='\t')
additional_corpus_df = pd.read_csv(self.base_corpus_tsv_path, sep='\t')
if add_base:
combined_corpus_df = pd.concat([original_corpus_df, additional_corpus_df], ignore_index=True)
combined_corpus_df.reset_index(drop=True, inplace=True)
else:
combined_corpus_df = original_corpus_df
corpus, self.corpus2tool = process_retrieval_document_query_version(combined_corpus_df)
corpus_ids = list(corpus.keys())
corpus = [corpus[cid] for cid in corpus_ids]
self.corpus = corpus
self.embedder = SentenceTransformer(self.model_path, device=device)
self.corpus_embeddings = self.embedder.encode(self.corpus, convert_to_tensor=True)
def retrieving(self, query, top_k):
query_embedding = self.embedder.encode(query, convert_to_tensor=True)
hits = util.semantic_search(query_embedding, self.corpus_embeddings, top_k=top_k, score_function=util.cos_sim) #170*
@@ -113,7 +140,7 @@ def compute_accuracy(retriever, data, args,name='train'):
val_ids = index_data['val']

# Step 2: Create a ToolRetriever instance
retriever = ToolRetriever(LIB = args.LIB, corpus_tsv_path=args.corpus_tsv_path, model_path=args.retrieval_model_path)
retriever = ToolRetriever(LIB = args.LIB, corpus_tsv_path=args.corpus_tsv_path, model_path=args.retrieval_model_path, add_base=False)
print(retriever.corpus[0])

total_queries = 0
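For reference, a hedged usage sketch of the updated constructor: with `add_base=True` the base corpus and base annotated queries are merged before embedding, while `add_base=False` (as in the evaluation call above) keeps the old single-library behaviour. The paths and model directory below are assumptions following the README layout, and the script is assumed to run from `src/` so the module is importable:

```python
# Minimal usage sketch (assumed paths/model dir): build a retriever that also
# indexes the shared base APIs, then query it.
from inference.retriever_finetune_inference import ToolRetriever

retriever = ToolRetriever(
    LIB="scanpy",
    corpus_tsv_path="./data/standard_process/scanpy/retriever_train_data/corpus.tsv",
    model_path="./hugging_models/retriever_model_finetuned/scanpy/assigned",
    add_base=True,  # merge ./data/standard_process/base/ corpus and queries
)
hits = retriever.retrieving("filter out low-count cells", top_k=5)
print(hits)
```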
