# Extracted from a unified diff of script/utils.py (index 154fe95..80a59e5,
# hunk @@ -332,3 +332,34 @@). The two functions below are the lines that hunk
# adds directly after construct_sentences(); the hunk's context lines (the
# tail of construct_sentences) are not repeated here.
# Both functions rely on the module-level names `metrics` and `np`
# (presumably sklearn.metrics and numpy -- confirm against the file's imports).


def calculate_purity(true_col, pred_col, df):
    '''
    Calculate purity, inverse purity and harmonic purity between two clusterings
    df: a Pandas data frame containing two columns (true_col and pred_col)
    true_col: column containing a ground-truth label for each document
    pred_col: column containing a predicted label for each document
    Returns a tuple (purity, inverse_purity, harmonic_purity)
    '''
    # One row per distinct value of df[true_col], one column per distinct
    # value of df[pred_col]; each cell counts the documents with that pair.
    contingency_matrix = metrics.cluster.contingency_matrix(df[true_col], df[pred_col])
    pred_sizes = contingency_matrix.sum(axis=0)
    true_sizes = contingency_matrix.sum(axis=1)
    total = contingency_matrix.sum()

    # precision: share of a predicted cluster that belongs to a true class.
    # recall: share of a true class that landed in a predicted cluster.
    precision = contingency_matrix / pred_sizes.reshape(1, -1)
    recall = contingency_matrix / true_sizes.reshape(-1, 1)

    # Empty cells have precision == recall == 0, so F1 is 0/0 there. Silence
    # the resulting floating-point warning; nan_to_num maps those NaNs to 0.
    with np.errstate(divide='ignore', invalid='ignore'):
        f1 = 2 * (precision * recall) / (precision + recall)
    f1 = np.nan_to_num(f1)

    # Each score is a size-weighted average of a per-cluster (purity) or
    # per-class (inverse / harmonic purity) best match.
    purity = (np.amax(precision, axis=0) * pred_sizes).sum() / total
    inverse_purity = (np.amax(recall, axis=1) * true_sizes).sum() / total
    harmonic_purity = (np.amax(f1, axis=1) * true_sizes).sum() / total
    return (purity, inverse_purity, harmonic_purity)


def calculate_metrics(true_col, pred_col, df):
    '''
    Calculate topic alignment between the true_col and pred_col labelings of
    df (harmonic purity, ARI, NMI)
    df: a Pandas data frame containing two columns (true_col and pred_col)
    true_col: column containing a ground-truth label for each document
    pred_col: column containing a predicted label for each document
    Returns a tuple (harmonic_purity, ari, nmi)
    '''
    # Only the harmonic purity is reported; purity and inverse purity are
    # computed by calculate_purity but intentionally dropped here.
    _, _, harmonic_purity = calculate_purity(true_col, pred_col, df)
    ari = metrics.adjusted_rand_score(df[true_col], df[pred_col])
    nmi = metrics.normalized_mutual_info_score(df[true_col], df[pred_col])
    return (harmonic_purity, ari, nmi)