update chitchat model
DoraDong-2023 committed Jan 21, 2024
1 parent 08bbd8f commit 116a223
Showing 2 changed files with 56 additions and 48 deletions.
3 changes: 2 additions & 1 deletion docs/PyPI2APP.md
@@ -99,7 +99,7 @@ Tips:

5. Train the api/non-api classification model.
```bash
python models/chitchat_classification.py --LIB ${LIB}
python models/chitchat_classification.py --LIB ${LIB} --ratio_1_to_3 1.0 --ratio_2_to_3 1.0 --embed_method st_trained
# or train a classification model on multicorpus of 12 bio-tools.
# python models/chitchat_classification_multicorpus.py
```
@@ -137,6 +137,7 @@ export HUGGINGPATH=./hugging_models
CUDA_VISIBLE_DEVICES=0 # if you use gpu
python inference/retriever_finetune_inference.py \
--retrieval_model_path ./hugging_models/retriever_model_finetuned/${LIB}/assigned \
--max_seq_length 256 \
--corpus_tsv_path ./data/standard_process/${LIB}/retriever_train_data/corpus.tsv \
--input_query_file ./data/standard_process/${LIB}/API_inquiry_annotate.json \
--idx_file ./data/standard_process/${LIB}/API_instruction_testval_query_ids.json \
101 changes: 54 additions & 47 deletions src/models/chitchat_classification.py
@@ -1,10 +1,36 @@
import argparse
import os
import argparse, os, json, torch
import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
from tqdm import tqdm
from sentence_transformers import SentenceTransformer, models
from transformers import BertModel, BertTokenizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'using device: {device}')

# load args
parser = argparse.ArgumentParser(description="Process data with a specified library.")
parser.add_argument("--LIB", type=str, default="scanpy", required=True, help="Library to use for data processing.")
parser.add_argument("--ratio_1_to_3", type=float, default=1.0, help="Ratio of data1 to data3.")
parser.add_argument("--ratio_2_to_3", type=float, default=1.0, help="Ratio of data2 to data3.")
parser.add_argument("--embed_method", type=str, choices=["st_untrained", "st_trained"], default="st_trained", help="The method for embeddings: st_untrained, or st_trained")
args = parser.parse_args()

# build an untrained sentence-transformer (plain bert-base-uncased encoder + mean pooling)
word_embedding_model = models.Transformer('bert-base-uncased', max_seq_length=256)
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
unpretrained_model = SentenceTransformer(modules=[word_embedding_model, pooling_model], device=device)
# load the fine-tuned retriever as the trained sentence embedder
pretrained_model = SentenceTransformer(f"./hugging_models/retriever_model_finetuned/{args.LIB}/assigned", device=device)

def bert_embed(model,tokenizer,text):
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)
outputs = model(**inputs)
return outputs.last_hidden_state.mean(1).squeeze().detach().cpu().numpy()
def sentence_transformer_embed(model, texts):
embeddings = model.encode(texts, convert_to_tensor=True)
return embeddings
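
For orientation, a minimal usage sketch of the two embedding helpers above, assuming the `device` and `pretrained_model` defined earlier; the raw-BERT instance and the sample query are illustrative only, not part of this commit:

```python
from transformers import BertModel, BertTokenizer

# Hypothetical raw-BERT pair for bert_embed (the commit only imports these classes).
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
bert_model = BertModel.from_pretrained("bert-base-uncased").to(device)

query = "How do I normalize counts in scanpy?"
vec_bert = bert_embed(bert_model, tokenizer, query)            # mean-pooled numpy vector, shape (768,)
vec_st = sentence_transformer_embed(pretrained_model, query)   # torch tensor from the fine-tuned retriever
print(vec_bert.shape, vec_st.shape)
```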
def process_topicalchat():
import json
import pandas as pd
@@ -18,6 +44,7 @@ def process_topicalchat():
questions.append(message)
df = pd.DataFrame({"Question": questions, "Source": "topical-chat"})
df.to_csv("./data/others-data/dialogue_questions.csv", index=False)

def process_chitchat():
import pandas as pd
import glob
@@ -32,8 +59,6 @@ def process_chitchat():
combined_data.to_csv("./data/others-data/combined_data.csv", sep=',', index=False)
print("Data has been combined and saved as combined_data.csv")
def process_apiquery(lib_name, filename="API_inquiry.json", start_id=0, index_save=True):
import json
import pandas as pd
with open(f'./data/standard_process/{lib_name}/{filename}', 'r') as file:
json_data = json.load(file)
filtered_data = [entry for entry in json_data if entry['query_id'] >= start_id]
@@ -62,8 +87,17 @@ def sampledata_combine(data1, data2, data3, test_data3, train_count_data1=1000,
test_data.to_csv('./data/others-data/test_data.csv', index=False)
print("Train data and test data have been saved.")

def calculate_centroid(tfidf_matrix):
return np.mean(tfidf_matrix, axis=0).A[0] # Here we convert the matrix to array
def calculate_centroid(data, embed_method):
if embed_method == "st_untrained":
print('Using untrained model!!!')
embeddings = np.array([sentence_transformer_embed(unpretrained_model, text).cpu() for text in tqdm(data, desc="Processing with unpretrained sentencetransformer BERT")])
elif embed_method == "st_trained":
embeddings = np.array([sentence_transformer_embed(pretrained_model, text).cpu() for text in tqdm(data, desc="Processing with pretrained sentencetransformer BERT")])
else:
raise NotImplementedError
if torch.is_tensor(embeddings):
embeddings = embeddings.cpu().numpy()
return np.mean(embeddings, axis=0)

def predict_by_similarity(user_query_vector, centroids, labels):
similarities = [cosine_similarity(user_query_vector, centroid.reshape(1, -1)) for centroid in centroids]
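
Since the tail of predict_by_similarity is collapsed in this view, here is a self-contained toy sketch of the centroid/cosine decision rule it implements, assuming the highest-similarity label wins (toy 2-D vectors stand in for the 768-d embeddings):

```python
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

def toy_predict(query_vec, centroids, labels):
    # Cosine similarity of the query against each class centroid; highest wins.
    sims = [cosine_similarity(query_vec, c.reshape(1, -1))[0][0] for c in centroids]
    return labels[int(np.argmax(sims))]

centroids = [np.array([1.0, 0.0]), np.array([0.7, 0.7]), np.array([0.0, 1.0])]
labels = ['chitchat-data', 'topical-chat', 'api-query']
query_vec = np.array([[0.1, 0.9]])  # closest in angle to the api-query centroid
print(toy_predict(query_vec, centroids, labels))  # -> 'api-query'
```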
@@ -117,11 +151,6 @@ def plot_tsne_distribution_modified(lib_name, train_data, test_data, vectorizer,
plt.savefig(f'./plot/{lib_name}/chitchat_test_tsne_modified.png')

def main():
parser = argparse.ArgumentParser(description="Process data with a specified library.")
parser.add_argument("--LIB", type=str, default="scanpy", required=True, help="Library to use for data processing.")
parser.add_argument("--ratio_1_to_3", type=float, default=4.0, help="Ratio of data1 to data3.")
parser.add_argument("--ratio_2_to_3", type=float, default=4.0, help="Ratio of data2 to data3.")
args = parser.parse_args()
process_topicalchat()
process_chitchat()
tmp_data = process_apiquery(args.LIB)
@@ -137,13 +166,9 @@ def main():
data1 = pd.read_csv('./data/others-data/dialogue_questions.csv')
data2 = pd.read_csv('./data/others-data/combined_data.csv')
data3 = pd.read_csv(f'./data/standard_process/{args.LIB}/api_data.csv')
#min_length = min(len(data1), len(data2), len(data3))
#train_sample_num = int(min_length * 0.8)
#test_sample_num = int(min_length * 0.2)


min_length_train = len(data3)
min_length_test = len(test_data3)
#train_ratio = 0.8 # no need to split now
train_sample_num1 = int(min_length_train*args.ratio_1_to_3)
train_sample_num2 = int(min_length_train*args.ratio_2_to_3)
test_sample_num1 = int(min_length_test*args.ratio_1_to_3)
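# e.g., with len(data3) == 1000 API queries and the new default ratios of 1.0, each chit-chat
# corpus contributes at most 1000 training questions, balancing the three classes.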
@@ -155,23 +180,20 @@ def main():
train_data1 = train_data[train_data['Source'] == 'chitchat-data']
train_data2 = train_data[train_data['Source'] == 'topical-chat']
train_data3 = train_data[train_data['Source'] == 'api-query']

print('length of train_data1, train_data2, train_data3: ', len(train_data1), len(train_data2), len(train_data3))
print('The real ratio for data1, data2 based on API data is: ', len(train_data1)/len(train_data3), len(train_data2)/len(train_data3))
vectorizer = TfidfVectorizer()
all_data = pd.concat([train_data1, train_data2, train_data3], ignore_index=True)
vectorizer.fit(all_data['Question'])
tfidf_matrix1 = vectorizer.transform(train_data1['Question'])
tfidf_matrix2 = vectorizer.transform(train_data2['Question'])
tfidf_matrix3 = vectorizer.transform(train_data3['Question'])
centroid1 = calculate_centroid(tfidf_matrix1)
centroid2 = calculate_centroid(tfidf_matrix2)
centroid3 = calculate_centroid(tfidf_matrix3)
centroid1 = calculate_centroid(train_data1['Question'],args.embed_method)
centroid2 = calculate_centroid(train_data2['Question'],args.embed_method)
centroid3 = calculate_centroid(train_data3['Question'],args.embed_method)
centroids = [centroid1, centroid2, centroid3]
labels = ['chitchat-data', 'topical-chat', 'api-query']
def calculate_accuracy(test_data, vectorizer, centroids, labels):
def calculate_accuracy(test_data, centroids, labels):
correct_predictions = 0
for index, row in test_data.iterrows():
user_query_vector = vectorizer.transform([row['Question']])
#user_query_vector = vectorizer.transform([row['Question']])
user_query_vector = bert_embed(bert_trans_model, tokenizer,row['Question']).reshape(1, -1)
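# NOTE: bert_trans_model and tokenizer are assumed to be defined elsewhere in this module;
# for the cosine comparison to be meaningful, queries must be embedded in the same space
# as the sentence-transformer centroids computed above.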
predicted_label = predict_by_similarity(user_query_vector, centroids, labels)
actual_label = row['Source']
if predicted_label == actual_label:
@@ -180,23 +202,14 @@ def calculate_accuracy(test_data, vectorizer, centroids, labels):
test_data_api = test_data[test_data['Source'] == 'api-query']
test_data_chitchat = test_data[test_data['Source'] == 'chitchat-data']
test_data_topical = test_data[test_data['Source'] == 'topical-chat']
#test_data_chitchat = test_data[(test_data['Source'] == 'chitchat-data') | (test_data['Source'] == 'topical-chat')]
c_api_accuracy, correct_predictions_api, total_predictions_api = calculate_accuracy(test_data_api, vectorizer, centroids, labels)
print(f"Accuracy on test data (API queries): {c_api_accuracy:.2f}%")
# Calculate accuracy for Chitchat data
c_chitchat_accuracy, correct_predictions_chitchat, total_predictions_chitchat = calculate_accuracy(test_data_chitchat, vectorizer, centroids, labels)
print(f"Accuracy on test data (Chitchat queries): {c_chitchat_accuracy:.2f}%")
c_topical_accuracy, correct_predictions_topical, total_predictions_topical = calculate_accuracy(test_data_topical, vectorizer, centroids, labels)
print(f"Accuracy on test data (Topical queries): {c_topical_accuracy:.2f}%")
c3_accuracy, correct_predictions_c3, total_predictions = calculate_accuracy(test_data, vectorizer, centroids, labels)

c3_accuracy, correct_predictions_c3, total_predictions = calculate_accuracy(test_data, centroids, labels)
print(f"Accuracy on test data on 3 clusters: {c3_accuracy:.2f}")

assert correct_predictions_api + correct_predictions_chitchat + correct_predictions_topical == correct_predictions_c3, "Sum of correct predictions in subsets does not equal total correct predictions"
assert total_predictions_api + total_predictions_chitchat + total_predictions_topical == total_predictions, "Sum of item counts in subsets does not equal total item count in test data"

correct_predictions = 0
for index, row in test_data.iterrows():
user_query_vector = vectorizer.transform([row['Question']])
#user_query_vector = vectorizer.transform([row['Question']])
user_query_vector = bert_embed(bert_trans_model, tokenizer,row['Question']).reshape(1, -1)
predicted_label = predict_by_similarity(user_query_vector, centroids, labels)
actual_label = row['Source']
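# Binary view: count a hit whenever the prediction falls on the correct side of the
# api-query vs. non-api split, whichever chit-chat class was predicted.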
if (actual_label=='api-query' and predicted_label=='api-query') or (actual_label!='api-query' and predicted_label!='api-query'):
@@ -207,19 +220,13 @@ def calculate_accuracy(test_data, vectorizer, centroids, labels):
import time
start_time = time.time()
import pickle
with open(f'./data/standard_process/{args.LIB}/vectorizer.pkl', 'wb') as f:
pickle.dump(vectorizer, f)
print(f"Vectorizer saved. Time taken: {time.time() - start_time:.2f} seconds")


with open(f'./data/standard_process/{args.LIB}/centroids.pkl', 'wb') as f:
pickle.dump(centroids, f)
start_time = time.time()
os.makedirs(f"./plot/{args.LIB}", exist_ok=True)
print(f"Centroids saved. Time taken: {time.time() - start_time:.2f} seconds")
start_time = time.time()
# Call the modified function to plot
#plot_tsne_distribution_modified(args.LIB, train_data, test_data, vectorizer, labels, c2_accuracy)
#print(f"t-SNE plot created. Time taken: {time.time() - start_time:.2f} seconds")


if __name__=='__main__':
main()
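
Downstream consumers can reload the pickled centroids to route incoming queries without retraining. A minimal sketch, assuming inference uses the same fine-tuned SentenceTransformer and the label order saved above; none of this code is part of the commit, and `LIB` is illustrative:

```python
import pickle
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

LIB = "scanpy"  # hypothetical; matches the argparse default above
model = SentenceTransformer(f"./hugging_models/retriever_model_finetuned/{LIB}/assigned")

with open(f"./data/standard_process/{LIB}/centroids.pkl", "rb") as f:
    centroids = pickle.load(f)
labels = ['chitchat-data', 'topical-chat', 'api-query']

# Embed a user query and route it: API call vs. chit-chat.
query_vec = model.encode("How do I filter cells by gene count?").reshape(1, -1)
sims = [cosine_similarity(query_vec, c.reshape(1, -1))[0][0] for c in centroids]
print("api call" if labels[int(np.argmax(sims))] == 'api-query' else "chit-chat")
```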
