# Recovered from git patch a42a7f0 ("Added code for metric calculation"),
# whose single hunk appends the two functions below to script/utils.py.
# NOTE(review): both functions use the module-level names `np` and `metrics`,
# whose imports are not shown in the patch -- presumably NumPy and
# `sklearn.metrics`; confirm against the top of script/utils.py.


def calculate_purity(true_col, pred_col, df):
    '''
    Calculate purity, inverse purity and harmonic purity between two clusterings.

    df: a Pandas data frame containing two columns (true_col and pred_col)
    true_col: column containing a ground-truth label for each document
    pred_col: column containing a predicted label for each document

    Returns a tuple (purity, inverse_purity, harmonic_purity).
    Raises ValueError if there are no documents to score.
    '''
    # Contingency table built from (true labels, predicted labels); the axis
    # sums below treat axis 1 sums as per-true-label sizes and axis 0 sums as
    # per-predicted-label sizes.
    counts = metrics.cluster.contingency_matrix(df[true_col], df[pred_col])
    total = counts.sum()
    if total == 0:
        raise ValueError("cannot compute purity for an empty set of documents")

    pred_sizes = counts.sum(axis=0)
    true_sizes = counts.sum(axis=1)
    precision = counts / pred_sizes.reshape(1, -1)
    recall = counts / true_sizes.reshape(-1, 1)

    # Empty cells have precision == recall == 0; a plain division would be
    # 0/0 (NaN plus a RuntimeWarning). Divide only where the denominator is
    # positive and leave those cells at 0.
    denom = precision + recall
    f1 = np.divide(
        2 * (precision * recall),
        denom,
        out=np.zeros_like(denom),
        where=denom > 0,
    )

    # Each score is a size-weighted average of the best match per cluster.
    purity = (np.amax(precision, axis=0) * pred_sizes).sum() / total
    inverse_purity = (np.amax(recall, axis=1) * true_sizes).sum() / total
    harmonic_purity = (np.amax(f1, axis=1) * true_sizes).sum() / total
    return (purity, inverse_purity, harmonic_purity)


def calculate_metrics(true_col, pred_col, df):
    '''
    Calculate topic alignment between the true and predicted labels
    (harmonic purity, ARI, NMI).

    df: a Pandas data frame containing two columns (true_col and pred_col)
    true_col: column containing a ground-truth label for each document
    pred_col: column containing a predicted label for each document

    Returns a tuple (harmonic_purity, ari, nmi).
    '''
    # Only the harmonic purity is reported; the other two scores are dropped.
    _, _, harmonic_purity = calculate_purity(true_col, pred_col, df)
    ari = metrics.adjusted_rand_score(df[true_col], df[pred_col])
    nmi = metrics.normalized_mutual_info_score(df[true_col], df[pred_col])
    return (harmonic_purity, ari, nmi)