Skip to content

Commit

Permalink
Added code for metric calculation
Browse files Browse the repository at this point in the history
  • Loading branch information
chtmp223 authored Dec 15, 2023
1 parent e0df264 commit a42a7f0
Showing 1 changed file with 31 additions and 0 deletions.
31 changes: 31 additions & 0 deletions script/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,3 +332,34 @@ def construct_sentences(p2_root, removed):
if not removed_branches:
sentences.append(branch_str)
return sentences


def calculate_purity(true_col, pred_col, df):
'''
Calculate harmonic purity between two set of clusterings
df: a Pandas data frame containing two columns (true_col and pred_col)
true_col: column containing a ground-truth label for each document
pred_col: column containing a predicted label for each document
'''
contingency_matrix = metrics.cluster.contingency_matrix(df[true_col], df[pred_col])
precision = contingency_matrix / contingency_matrix.sum(axis=0).reshape(1, -1)
recall = contingency_matrix / contingency_matrix.sum(axis=1).reshape(-1, 1)
f1 = 2 * (precision * recall) / (precision + recall)
f1 = np.nan_to_num(f1)
purity = (np.amax(precision, axis=0) * contingency_matrix.sum(axis=0)).sum() / contingency_matrix.sum()
inverse_purity = (np.amax(recall, axis=1) * contingency_matrix.sum(axis=1)).sum() / contingency_matrix.sum()
harmonic_purity = (np.amax(f1, axis=1) * contingency_matrix.sum(axis=1)).sum() / contingency_matrix.sum()
return (purity, inverse_purity, harmonic_purity)


def calculate_metrics(true_col, pred_col, df):
'''
Calculate topic alignment between df1 and df2 (harmonic purity, ARI, NMI)
df: a Pandas data frame containing two columns (true_col and pred_col)
true_col: column containing a ground-truth label for each document
pred_col: column containing a predicted label for each document
'''
purity, inverse_purity, harmonic_purity = calculate_purity(true_col, pred_col, df)
ari = metrics.adjusted_rand_score(df[true_col], df[pred_col])
mis = metrics.normalized_mutual_info_score(df[true_col], df[pred_col])
return (harmonic_purity, ari, mis)

0 comments on commit a42a7f0

Please sign in to comment.