diff --git a/docs/PyPI2APP.md b/docs/PyPI2APP.md
index 97213bd..5dcca38 100644
--- a/docs/PyPI2APP.md
+++ b/docs/PyPI2APP.md
@@ -99,7 +99,7 @@ Tips:
 5. Train the api/non-api classification model.
 
 ```bash
-python models/chitchat_classification.py --LIB ${LIB}
+python models/chitchat_classification.py --LIB ${LIB} --ratio_1_to_3 1.0 --ratio_2_to_3 1.0 --embed_method st_trained
 # or train a classification model on multicorpus of 12 bio-tools.
 # python models/chitchat_classification_multicorpus.py
 ```
@@ -137,6 +137,7 @@ export HUGGINGPATH=./hugging_models
 CUDA_VISIBLE_DEVICES=0 # if you use gpu
 python inference/retriever_finetune_inference.py \
     --retrieval_model_path ./hugging_models/retriever_model_finetuned/${LIB}/assigned \
+    --max_seq_length 256 \
     --corpus_tsv_path ./data/standard_process/${LIB}/retriever_train_data/corpus.tsv \
     --input_query_file ./data/standard_process/${LIB}/API_inquiry_annotate.json \
    --idx_file ./data/standard_process/${LIB}/API_instruction_testval_query_ids.json \
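Aside on `--max_seq_length 256`: the same 256-token cap appears below where the classification script builds its untrained sentence-transformer, so the flag keeps inference consistent with training. A quick hedged check that a saved retriever actually carries that cap (assumes `LIB=scanpy` and the save path used throughout these docs):

```python
from sentence_transformers import SentenceTransformer

# Assumption: the fine-tuned retriever was saved here by the earlier training step.
model = SentenceTransformer("./hugging_models/retriever_model_finetuned/scanpy/assigned")
print(model.max_seq_length)  # should report 256, matching --max_seq_length above
```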
diff --git a/src/models/chitchat_classification.py b/src/models/chitchat_classification.py
index 2af8b75..68415e9 100644
--- a/src/models/chitchat_classification.py
+++ b/src/models/chitchat_classification.py
@@ -1,10 +1,36 @@
-import argparse
-import os
+import argparse, os, json, torch
 import pandas as pd
 from sklearn.feature_extraction.text import TfidfVectorizer
 from sklearn.metrics.pairwise import cosine_similarity
 import numpy as np
+from tqdm import tqdm
+from sentence_transformers import SentenceTransformer, models
+from transformers import BertModel, BertTokenizer
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+print(f'using device: {device}')
+# load args
+parser = argparse.ArgumentParser(description="Process data with a specified library.")
+parser.add_argument("--LIB", type=str, default="scanpy", required=True, help="Library to use for data processing.")
+parser.add_argument("--ratio_1_to_3", type=float, default=1.0, help="Ratio of data1 to data3.")
+parser.add_argument("--ratio_2_to_3", type=float, default=1.0, help="Ratio of data2 to data3.")
+parser.add_argument("--embed_method", type=str, choices=["st_untrained", "st_trained"], default="st_trained", help="The method for embeddings: st_untrained or st_trained.")
+args = parser.parse_args()
+
+# load untrained model (bert-base-uncased + mean pooling)
+word_embedding_model = models.Transformer('bert-base-uncased', max_seq_length=256)
+pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
+unpretrained_model = SentenceTransformer(modules=[word_embedding_model, pooling_model], device=device)
+# load pretrained (fine-tuned retriever) model
+pretrained_model = SentenceTransformer(f"./hugging_models/retriever_model_finetuned/{args.LIB}/assigned", device=device)
+# embedding model used for query vectors at evaluation time, selected by --embed_method
+embed_model = pretrained_model if args.embed_method == "st_trained" else unpretrained_model
+
+def bert_embed(model, tokenizer, text):
+    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)
+    outputs = model(**inputs)
+    return outputs.last_hidden_state.mean(1).squeeze().detach().cpu().numpy()
+def sentence_transformer_embed(model, texts):
+    embeddings = model.encode(texts, convert_to_tensor=True)
+    return embeddings
 
 def process_topicalchat():
     import json
     import pandas as pd
@@ -18,6 +44,7 @@ def process_topicalchat():
             questions.append(message)
     df = pd.DataFrame({"Question": questions, "Source": "topical-chat"})
     df.to_csv("./data/others-data/dialogue_questions.csv", index=False)
+
 def process_chitchat():
     import pandas as pd
     import glob
@@ -32,8 +59,6 @@ def process_chitchat():
     combined_data.to_csv("./data/others-data/combined_data.csv", sep=',', index=False)
     print("Data has been combined and saved as combined_data.csv")
 def process_apiquery(lib_name, filename="API_inquiry.json", start_id=0, index_save=True):
-    import json
-    import pandas as pd
     with open(f'./data/standard_process/{lib_name}/{filename}', 'r') as file:
         json_data = json.load(file)
     filtered_data = [entry for entry in json_data if entry['query_id'] >= start_id]
@@ -62,8 +87,17 @@ def sampledata_combine(data1, data2, data3, test_data3, train_count_data1=1000,
     test_data.to_csv('./data/others-data/test_data.csv', index=False)
     print("Train data and test data have been saved.")
 
-def calculate_centroid(tfidf_matrix):
-    return np.mean(tfidf_matrix, axis=0).A[0] # Here we convert the matrix to array
+def calculate_centroid(data, embed_method):
+    if embed_method == "st_untrained":
+        print('Using untrained sentence-transformer model')
+        embeddings = np.array([sentence_transformer_embed(unpretrained_model, text).cpu() for text in tqdm(data, desc="Processing with untrained sentence-transformer BERT")])
+    elif embed_method == "st_trained":
+        embeddings = np.array([sentence_transformer_embed(pretrained_model, text).cpu() for text in tqdm(data, desc="Processing with fine-tuned sentence-transformer BERT")])
+    else:
+        raise NotImplementedError
+    if torch.is_tensor(embeddings):
+        embeddings = embeddings.cpu().numpy()
+    return np.mean(embeddings, axis=0)
 
 def predict_by_similarity(user_query_vector, centroids, labels):
     similarities = [cosine_similarity(user_query_vector, centroid.reshape(1, -1)) for centroid in centroids]
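The new `calculate_centroid` replaces TF-IDF centroids with the mean of sentence-transformer embeddings, and `predict_by_similarity` then picks the label whose centroid is most cosine-similar to the query. A self-contained sketch of that scheme, using the public `all-MiniLM-L6-v2` checkpoint as a stand-in (an assumption; the script itself uses the untrained or fine-tuned models loaded above):

```python
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

model = SentenceTransformer("all-MiniLM-L6-v2")  # stand-in checkpoint (assumption)

api_queries = ["normalize the count matrix", "plot a UMAP of the clusters"]
chitchat = ["how was your weekend", "tell me a joke"]

# Centroid = mean of a class's sentence embeddings, as in calculate_centroid.
centroids = np.stack([
    model.encode(api_queries).mean(axis=0),
    model.encode(chitchat).mean(axis=0),
])
labels = ["api-query", "chitchat-data"]

# Classify by cosine similarity to the centroids, as in predict_by_similarity.
query_vec = model.encode(["filter out low-quality cells"])
sims = cosine_similarity(query_vec, centroids)[0]
print(labels[int(np.argmax(sims))])  # expected: api-query
```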
@@ -117,11 +151,6 @@ def plot_tsne_distribution_modified(lib_name, train_data, test_data, vectorizer,
     plt.savefig(f'./plot/{lib_name}/chitchat_test_tsne_modified.png')
 
 def main():
-    parser = argparse.ArgumentParser(description="Process data with a specified library.")
-    parser.add_argument("--LIB", type=str, default="scanpy", required=True, help="Library to use for data processing.")
-    parser.add_argument("--ratio_1_to_3", type=float, default=4.0, help="Ratio of data1 to data3.")
-    parser.add_argument("--ratio_2_to_3", type=float, default=4.0, help="Ratio of data2 to data3.")
-    args = parser.parse_args()
     process_topicalchat()
     process_chitchat()
     tmp_data = process_apiquery(args.LIB)
@@ -137,13 +166,9 @@ def main():
     data1 = pd.read_csv('./data/others-data/dialogue_questions.csv')
     data2 = pd.read_csv('./data/others-data/combined_data.csv')
     data3 = pd.read_csv(f'./data/standard_process/{args.LIB}/api_data.csv')
-    #min_length = min(len(data1), len(data2), len(data3))
-    #train_sample_num = int(min_length * 0.8)
-    #test_sample_num = int(min_length * 0.2)
-
+
     min_length_train = len(data3)
     min_length_test = len(test_data3)
-    #train_ratio = 0.8 # no need to split now
     train_sample_num1 = int(min_length_train*args.ratio_1_to_3)
     train_sample_num2 = int(min_length_train*args.ratio_2_to_3)
     test_sample_num1 = int(min_length_test*args.ratio_1_to_3)
@@ -155,23 +180,20 @@ def main():
     train_data1 = train_data[train_data['Source'] == 'chitchat-data']
     train_data2 = train_data[train_data['Source'] == 'topical-chat']
     train_data3 = train_data[train_data['Source'] == 'api-query']
+    print('length of train_data1, train_data2, train_data3: ', len(train_data1), len(train_data2), len(train_data3))
     print('The real ratio for data1, data2 based on API data is: ', len(train_data1)/len(train_data3), len(train_data2)/len(train_data3))
-    vectorizer = TfidfVectorizer()
     all_data = pd.concat([train_data1, train_data2, train_data3], ignore_index=True)
-    vectorizer.fit(all_data['Question'])
-    tfidf_matrix1 = vectorizer.transform(train_data1['Question'])
-    tfidf_matrix2 = vectorizer.transform(train_data2['Question'])
-    tfidf_matrix3 = vectorizer.transform(train_data3['Question'])
-    centroid1 = calculate_centroid(tfidf_matrix1)
-    centroid2 = calculate_centroid(tfidf_matrix2)
-    centroid3 = calculate_centroid(tfidf_matrix3)
+    centroid1 = calculate_centroid(train_data1['Question'], args.embed_method)
+    centroid2 = calculate_centroid(train_data2['Question'], args.embed_method)
+    centroid3 = calculate_centroid(train_data3['Question'], args.embed_method)
     centroids = [centroid1, centroid2, centroid3]
     labels = ['chitchat-data', 'topical-chat', 'api-query']
-    def calculate_accuracy(test_data, vectorizer, centroids, labels):
+    def calculate_accuracy(test_data, centroids, labels):
         correct_predictions = 0
         for index, row in test_data.iterrows():
-            user_query_vector = vectorizer.transform([row['Question']])
+            user_query_vector = sentence_transformer_embed(embed_model, [row['Question']]).cpu().numpy().reshape(1, -1)
             predicted_label = predict_by_similarity(user_query_vector, centroids, labels)
             actual_label = row['Source']
             if predicted_label == actual_label:
@@ -180,23 +202,14 @@ def calculate_accuracy(test_data, vectorizer, centroids, labels):
     test_data_api = test_data[test_data['Source'] == 'api-query']
     test_data_chitchat = test_data[test_data['Source'] == 'chitchat-data']
     test_data_topical = test_data[test_data['Source'] == 'topical-chat']
-    #test_data_chitchat = test_data[(test_data['Source'] == 'chitchat-data') | (test_data['Source'] == 'topical-chat')]
-    c_api_accuracy, correct_predictions_api, total_predictions_api = calculate_accuracy(test_data_api, vectorizer, centroids, labels)
-    print(f"Accuracy on test data (API queries): {c_api_accuracy:.2f}%")
-    # Calculate accuracy for Chitchat data
-    c_chitchat_accuracy, correct_predictions_chitchat, total_predictions_chitchat = calculate_accuracy(test_data_chitchat, vectorizer, centroids, labels)
-    print(f"Accuracy on test data (Chitchat queries): {c_chitchat_accuracy:.2f}%")
-    c_topical_accuracy, correct_predictions_topical, total_predictions_topical = calculate_accuracy(test_data_topical, vectorizer, centroids, labels)
-    print(f"Accuracy on test data (Topical queries): {c_topical_accuracy:.2f}%")
-    c3_accuracy, correct_predictions_c3, total_predictions = calculate_accuracy(test_data, vectorizer, centroids, labels)
+
+    c3_accuracy, correct_predictions_c3, total_predictions = calculate_accuracy(test_data, centroids, labels)
     print(f"Accuracy on test data on 3 clusters: {c3_accuracy:.2f}%")
-    assert correct_predictions_api + correct_predictions_chitchat + correct_predictions_topical == correct_predictions_c3, "Sum of correct predictions in subsets does not equal total correct predictions"
-    assert total_predictions_api + total_predictions_chitchat + total_predictions_topical == total_predictions, "Sum of item counts in subsets does not equal total item count in test data"
-
     correct_predictions = 0
     for index, row in test_data.iterrows():
-        user_query_vector = vectorizer.transform([row['Question']])
+        user_query_vector = sentence_transformer_embed(embed_model, [row['Question']]).cpu().numpy().reshape(1, -1)
         predicted_label = predict_by_similarity(user_query_vector, centroids, labels)
         actual_label = row['Source']
         if (actual_label=='api-query' and predicted_label=='api-query') or (actual_label!='api-query' and predicted_label!='api-query'):
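The second loop above collapses the 3-way prediction into a binary api vs. non-api accuracy (chitchat-data and topical-chat both count as non-api). A compact equivalent, assuming a hypothetical `predict(question)` helper that returns one of the three source labels:

```python
import numpy as np

def binary_api_accuracy(test_data, predict):
    # predict() is a hypothetical helper returning one of the three source labels.
    pred_is_api = np.array([predict(q) == "api-query" for q in test_data["Question"]])
    true_is_api = (test_data["Source"] == "api-query").to_numpy()
    return 100.0 * (pred_is_api == true_is_api).mean()
```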
@@ -207,19 +220,13 @@ def calculate_accuracy(test_data, vectorizer, centroids, labels):
     import time
     start_time = time.time()
     import pickle
-    with open(f'./data/standard_process/{args.LIB}/vectorizer.pkl', 'wb') as f:
-        pickle.dump(vectorizer, f)
-    print(f"Vectorizer saved. Time taken: {time.time() - start_time:.2f} seconds")
-
+
     with open(f'./data/standard_process/{args.LIB}/centroids.pkl', 'wb') as f:
         pickle.dump(centroids, f)
     start_time = time.time()
     os.makedirs(f"./plot/{args.LIB}", exist_ok=True)
     print(f"Centroids saved. Time taken: {time.time() - start_time:.2f} seconds")
     start_time = time.time()
-    # Call the modified function to plot
-    #plot_tsne_distribution_modified(args.LIB, train_data, test_data, vectorizer, labels, c2_accuracy)
-    #print(f"t-SNE plot created. Time taken: {time.time() - start_time:.2f} seconds")
-
+
 if __name__=='__main__':
     main()
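Since the TF-IDF vectorizer is no longer pickled, downstream inference only needs `centroids.pkl` plus the same embedding model used at training time. A sketch of that consumer, assuming `LIB=scanpy` and the default `--embed_method st_trained`:

```python
import pickle
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

LIB = "scanpy"  # assumption: same library name as at training time
model = SentenceTransformer(f"./hugging_models/retriever_model_finetuned/{LIB}/assigned")
with open(f"./data/standard_process/{LIB}/centroids.pkl", "rb") as f:
    centroids = pickle.load(f)  # three mean-embedding vectors, in training order
labels = ["chitchat-data", "topical-chat", "api-query"]

query_vec = model.encode(["how do I cluster my single-cell data"])
sims = cosine_similarity(query_vec, np.stack(centroids))[0]
print("api-query" if labels[int(np.argmax(sims))] == "api-query" else "non-api")
```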