forked from easezyc/Multitask-Recommendation-Library
-
Notifications
You must be signed in to change notification settings - Fork 1
/
draw_metrics.py
101 lines (88 loc) · 4.28 KB
/
draw_metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
from matplotlib import pyplot as plt
from sklearn.metrics import *
import matplotlib.pyplot as plt
from sklearn.metrics import RocCurveDisplay
from sklearn.metrics import roc_curve
from sklearn.metrics import auc
from tqdm import tqdm
from sklearn.metrics import roc_auc_score, accuracy_score, balanced_accuracy_score, matthews_corrcoef, f1_score, precision_score, recall_score
from torchmetrics.classification import BinaryAUROC, BinaryAccuracy, BinaryAveragePrecision, BinaryMatthewsCorrCoef, BinaryF1Score, BinaryPrecision, BinaryRecall
import torch
def tensor_wrapper(func):
    """Adapt a torch-based metric to the ``metric(y_test, y_pred)`` call style.

    The returned callable converts both inputs with ``torch.Tensor(...)``,
    moves them to a device, calls ``func(y_pred, y_test)`` -- note the swapped
    order: predictions first, targets second -- and returns the result as a
    Python ``float``.

    Args:
        func: Callable taking ``(preds, target)`` tensors and returning a
            value ``float()`` can convert. When it exposes a ``device``
            attribute the inputs are moved to that device; otherwise CUDA is
            used when available, with the CPU as fallback.

    Returns:
        A function ``wrapped(y_test, y_pred) -> float``.
    """
    def wrapped(y_test, y_pred):
        # Follow the metric's own device instead of hard-coding CUDA, so the
        # wrapper also works on hosts without a GPU.
        device = getattr(func, "device", None)
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        device = torch.device(device)

        # NOTE(review): torch.Tensor(...) yields floating-point tensors; some
        # metrics may expect integer targets -- confirm against the installed
        # metric library.
        y_test = torch.Tensor(y_test).to(device)
        y_pred = torch.Tensor(y_pred).to(device)
        score = float(func(y_pred, y_test))

        # torchmetrics metrics called this way also fold the batch into their
        # accumulated state (see the torchmetrics Metric docs). Nothing here
        # reads that state, so clear it to stop it growing across calls.
        reset = getattr(func, "reset", None)
        if callable(reset):
            reset()
        return score
    return wrapped
# Put every metric on the GPU when one is present and on the CPU otherwise, so
# importing this module does not require CUDA.
_METRIC_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The bindings below deliberately replace the scikit-learn functions of the
# same names imported at the top of the file with torchmetrics-backed versions
# that keep the ``metric(y_test, y_pred)`` calling convention.
roc_auc_score = tensor_wrapper(BinaryAUROC(thresholds=None).to(_METRIC_DEVICE))
accuracy_score = tensor_wrapper(BinaryAccuracy().to(_METRIC_DEVICE))
# balanced_accuracy_score is not rebound here, so it remains the scikit-learn
# implementation imported at the top of the file.
matthews_corrcoef = tensor_wrapper(BinaryMatthewsCorrCoef().to(_METRIC_DEVICE))
f1_score = tensor_wrapper(BinaryF1Score().to(_METRIC_DEVICE))
precision_score = tensor_wrapper(BinaryPrecision().to(_METRIC_DEVICE))
recall_score = tensor_wrapper(BinaryRecall().to(_METRIC_DEVICE))
average_precision_score = tensor_wrapper(BinaryAveragePrecision(thresholds=None).to(_METRIC_DEVICE))
def best_threshold(y_test, y_pred_prob, metric=f1_score):
    """Search the predicted scores for the cut-off that maximises ``metric``.

    Every element of ``y_pred_prob`` is tried as a threshold: predictions are
    binarised with ``y_pred_prob >= threshold`` and scored with
    ``metric(y_test, y_pred)``. On ties the candidate seen later wins.

    Args:
        y_test: Ground-truth labels, passed to ``metric`` unchanged.
        y_pred_prob: Predicted scores. Must be iterable and support an
            element-wise ``>=`` against one of its own elements (presumably a
            NumPy array -- confirm against callers).
        metric: Callable ``metric(y_test, y_pred)`` returning a comparable
            score. Defaults to the module-level ``f1_score`` as bound when
            this function is defined.

    Returns:
        ``(best_threshold_value, best_score)``. If no candidate is accepted
        (empty input, or no score compares ``>=`` such as all-NaN scores) the
        result is ``(0, 0)``.
    """
    # Start below every real score so that metrics which can be negative
    # (e.g. MCC) still report a threshold that was actually evaluated.
    best_threshold_value, best_score = 0, float("-inf")
    found = False
    candidates = tqdm(y_pred_prob, desc="Trying different thresholds...")
    for threshold in candidates:
        # TODO: this does not cover every possible decision. Each candidate
        # comes from y_pred_prob itself, so with ">=" at least that sample is
        # predicted positive and "predict nothing positive" is never tried.
        y_pred = y_pred_prob >= threshold
        new_score = metric(y_test, y_pred)
        if new_score >= best_score:
            best_score = new_score
            best_threshold_value = threshold
            found = True
    if not found:
        return best_threshold_value, 0
    return best_threshold_value, best_score
# best_threshold_value, best_score = best_threshold(y_test, y_pred_prob, metric=precision_score)
# best_threshold_value, best_score = best_threshold(y_test, y_pred_prob, metric=balanced_accuracy_score)
# best_threshold_value, best_score = best_threshold(y_test, y_pred_prob, metric=matthews_corrcoef)
# best_threshold_value, best_score = best_threshold(y_test, y_pred_prob)
# best_threshold_value, best_score
def fast_evaluation(y_test, y_pred_prob, threshold=None, metric=balanced_accuracy_score, return_threshold=False):
    """Compute a dictionary of binary-classification metrics in one call.

    Args:
        y_test: Ground-truth labels.
        y_pred_prob: Predicted scores; must support ``>=`` against the
            threshold.
        threshold: Cut-off used to binarise ``y_pred_prob``. When ``None`` it
            is chosen by ``best_threshold`` using ``metric``.
        metric: Metric optimised by the threshold search; only used when
            ``threshold`` is ``None``.
        return_threshold: When true, return the threshold alongside the
            metrics.

    Returns:
        A dict with keys ``roc_auc``, ``accuracy``, ``balanced_accuracy``,
        ``mcc``, ``f1``, ``precision`` and ``recall``; or the tuple
        ``(dict, threshold)`` when ``return_threshold`` is true.
    """
    if threshold is None:
        threshold = best_threshold(y_test, y_pred_prob, metric=metric)[0]
    hard_labels = y_pred_prob >= threshold

    # ROC AUC is the only entry computed from the raw scores; every other
    # entry is computed from the thresholded labels.
    report = {"roc_auc": roc_auc_score(y_test, y_pred_prob)}
    label_metrics = (
        ("accuracy", accuracy_score),
        ("balanced_accuracy", balanced_accuracy_score),
        ("mcc", matthews_corrcoef),
        ("f1", f1_score),
        ("precision", precision_score),
        ("recall", recall_score),
    )
    for key, scorer in label_metrics:
        report[key] = scorer(y_test, hard_labels)

    return (report, threshold) if return_threshold else report
# fast_evaluation(y_test, y_pred_prob, threshold=best_threshold_value)
def plot_auc(y_test, y_pred_prob,
             curve_name=None, title="受试者工作特征曲线",
             xlabel="假正例率", ylabel="真正例率"):
    """Draw the ROC curve on a new figure and return its area.

    Args:
        y_test: Ground-truth labels, forwarded to ``roc_curve``.
        y_pred_prob: Predicted scores, forwarded to ``roc_curve``.
        curve_name: Legend label for the curve; defaults to ``auc=<area>``
            with two decimals.
        title: Axes title (the default is Chinese for "receiver operating
            characteristic curve").
        xlabel: X-axis label (default: Chinese for "false positive rate").
        ylabel: Y-axis label (default: Chinese for "true positive rate").

    Returns:
        ``(roc_auc, fig)``: the area under the curve and the new figure.
    """
    false_pos_rate, true_pos_rate, _ = roc_curve(y_test, y_pred_prob)
    area = auc(false_pos_rate, true_pos_rate)
    legend_label = f"auc={area:.2f}" if curve_name is None else curve_name

    figure, axes = plt.subplots()
    axes.plot(false_pos_rate, true_pos_rate, color='orange', label=legend_label)
    # Dashed diagonal from (0, 0) to (1, 1) drawn as a reference line.
    axes.plot([0, 1], [0, 1], color='darkblue', linestyle='--')
    axes.set_xlabel(xlabel)
    axes.set_ylabel(ylabel)
    axes.set_title(title)
    axes.legend()
    return area, figure
def plot_pr(y_test, y_pred_prob,
            curve_name=None, title="精确率-召回率曲线",
            xlabel="召回率", ylabel="精确率"):
    """Draw the precision-recall curve on a new figure and return its area.

    Args:
        y_test: Ground-truth labels, forwarded to ``precision_recall_curve``.
        y_pred_prob: Predicted scores, forwarded to
            ``precision_recall_curve``.
        curve_name: Legend label for the curve; defaults to ``auc=<area>``
            with two decimals.
        title: Axes title (the default is Chinese for "precision-recall
            curve").
        xlabel: X-axis label (default: Chinese for "recall").
        ylabel: Y-axis label (default: Chinese for "precision").

    Returns:
        ``(pr_auc, fig)``: the area under the precision-recall curve and the
        new figure.
    """
    precision, recall, _ = precision_recall_curve(y_test, y_pred_prob)
    # Compute the area once and reuse it for both the label and the return
    # value, matching how plot_auc handles its area.
    pr_auc = auc(recall, precision)
    if curve_name is None:
        curve_name = f"auc={pr_auc:.2f}"
    fig, ax = plt.subplots()
    ax.plot(recall, precision, color='orange', label=curve_name)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    return pr_auc, fig