-
Notifications
You must be signed in to change notification settings - Fork 0
/
plotResults.py
111 lines (93 loc) · 3.85 KB
/
plotResults.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
#!/usr/bin/env python
"""
Usage:
%s <csvfile>... [--outputFile=<outputFile>] [--xlim=<xlimit>] [--ylim=<ylimit>] [--mdr=<mdr>] [--telescope=<telescope>]
%s (-h | --help)
%s --version
Options:
-h --help Show this screen.
--version Show version.
--outputFile=<outputFile> Place to store the outputfile.
--xlim=<xlimit> Plot x limit [default: 1.0]
--ylim=<ylimit> Set y limit [default: 1.0]
--mdr=<mdr> Missed detection rate [default: 0.04]
--telescope=<telescope> Telescope
Example:
%s output_results.csv
"""
import sys
__doc__ = __doc__ % (sys.argv[0], sys.argv[0], sys.argv[0], sys.argv[0])
from docopt import docopt
from gkutils import Struct, cleanOptions
import numpy as np
import pandas as pd
#from sklearn.metrics import roc_curve, auc
from sklearn.metrics import auc
from plotROC import roc_curve
import optparse
import matplotlib.pyplot as plt
# Font sizes (points) used for the different text elements of the figure.
SMALL_SIZE = 14
MEDIUM_SIZE = 18
BIGGER_SIZE = 25
TINY_SIZE = 12
# Global matplotlib styling, applied once at import time. Each entry is set
# in order through rcParams' validating item assignment, exactly as the
# equivalent plt.rc(group, key=value) calls would do.
plt.rcParams.update({
    'font.size': SMALL_SIZE,            # default text size
    'axes.titlesize': MEDIUM_SIZE,      # axes title
    'axes.labelsize': MEDIUM_SIZE,      # x and y axis labels
    'xtick.labelsize': TINY_SIZE,       # x tick labels
    'ytick.labelsize': TINY_SIZE,       # y tick labels
    'legend.fontsize': SMALL_SIZE - 1,  # legend text
    'figure.titlesize': BIGGER_SIZE,    # figure (suptitle) title
    'font.family': 'serif',
    'mathtext.fontset': 'dejavuserif',
})
def plot_roc(fpr, tpr, roc_auc, roc):
    """Draw a ROC curve plus the chance diagonal on the axes ``roc``.

    fpr, tpr -- false / true positive rates, passed straight to ``roc.plot``.
    roc_auc  -- area under the curve; shown in the legend with 2 decimals.
    roc      -- axes object providing ``plot``, ``set_xlabel``,
                ``set_ylabel``, ``set_title`` and ``legend`` (e.g. a
                matplotlib Axes).
    """
    curve_label = 'ROC curve (area = %0.2f)' % roc_auc
    roc.plot(fpr, tpr, lw=2, label=curve_label)

    # Dashed reference line from (0, 0) to (1, 1).
    chance_x, chance_y = [0, 1], [0, 1]
    roc.plot(chance_x, chance_y, color='navy', lw=2, linestyle='--')

    axis_text = (
        (roc.set_xlabel, 'False Positive Rate'),
        (roc.set_ylabel, 'True Positive Rate'),
        (roc.set_title, 'ROC curve'),
    )
    for apply_text, text in axis_text:
        apply_text(text)
    roc.legend(loc="lower right")
def plot_tradeoff(mdr, fpr, tradeoff, intercept=None, xlim = 1.0, ylim = 1.0, title = 'Detection Error Tradeoff'):
    """Draw a detection error tradeoff curve on the axes ``tradeoff``.

    Parameters
    ----------
    mdr, fpr : array-like
        Missed detection rates (x axis) and false positive rates (y axis).
    tradeoff : axes object
        Provides ``plot``, ``set_xlabel``, ``set_ylabel``, ``set_xlim`` and
        ``set_ylim`` (e.g. a matplotlib Axes).
    intercept : sequence of two numbers, optional
        ``(mdr, fpr)`` operating point to highlight with black dashed guide
        lines running to both axes. ``None`` (the default) or an empty
        sequence draws no guides.
    xlim, ylim : float
        Upper limits of the x and y axes; the lower limits are fixed at 0.
    title : str
        Currently unused: setting the axes title is disabled. The parameter
        is kept so existing callers can continue to pass it.
    """
    tradeoff.plot(mdr, fpr, lw=2)
    # ``None`` default instead of a shared mutable ``[]``; ``len`` (rather
    # than truthiness) keeps NumPy arrays working as intercepts.
    if intercept is not None and len(intercept) > 0:
        x_point, y_point = intercept[0], intercept[1]
        # Horizontal guide from the y axis to the point, then straight down
        # to the x axis.
        tradeoff.plot([0, x_point, x_point], [y_point, y_point, 0],
                      color='black', linestyle='--')
    tradeoff.set_xlabel('Missed detection rate')
    tradeoff.set_ylabel('False positive rate')
    tradeoff.set_xlim(0, xlim)
    tradeoff.set_ylim(0, ylim)
def plotResults(files, outputfile, options = None):
    """Plot a detection error tradeoff curve for each CSV results file.

    Each CSV is read without a header using the column names
    ``file, tag, prediction``; ``tag`` and ``prediction`` are passed (as
    NumPy arrays) to ``roc_curve``. For every file the false positive rate
    at the requested missed detection rate (MDR) is printed together with
    the corresponding threshold, and that operating point is marked on the
    shared axes.

    Parameters
    ----------
    files : iterable of str
        Paths of the CSV files; all curves are drawn on the same axes.
    outputfile : str or None
        If not ``None`` the figure is saved to this path, otherwise it is
        shown with ``plt.show()``.
    options : object, optional
        Attribute bag providing ``mdr``, ``xlim``, ``ylim`` (converted with
        ``float``) and ``telescope``. When ``None`` or when an attribute is
        missing, the command-line defaults are used: mdr 0.04, xlim 1.0,
        ylim 1.0, no telescope.

    Raises
    ------
    ValueError
        If no point of a file's curve has an MDR at or below the target.
    """
    # Read the settings once, outside the loop. Falling back to the docopt
    # defaults makes the ``options=None`` default actually usable.
    telescope = getattr(options, 'telescope', None)
    mdr_set = float(getattr(options, 'mdr', 0.04))
    xlim = float(getattr(options, 'xlim', 1.0))
    ylim = float(getattr(options, 'ylim', 1.0))

    plot_title = 'Detection Error Tradeoff'
    if telescope is not None:
        plot_title += ' (' + telescope + ')'

    # Single panel; the side-by-side ROC plot is currently disabled.
    _, tradeoff = plt.subplots(1, 1, sharey=False)
    for csv_file in files:
        data = pd.read_csv(csv_file, names=['file', 'tag', 'prediction'])
        fpr, tpr, thresholds = roc_curve(np.array(data['tag']), np.array(data['prediction']))
        mdr = 1 - tpr

        # Last index (in roc_curve's output order) whose MDR is within the
        # target; computed once instead of three times.
        within_target = np.where(mdr <= mdr_set)[0]
        if len(within_target) == 0:
            raise ValueError("no point of the curve for %s reaches a missed "
                             "detection rate of %s" % (csv_file, mdr_set))
        idx = within_target[-1]
        fpr_at_mdr_set = fpr[idx]

        print("[+]%.3lf%% mdr gives " % (mdr_set * 100) + str(fpr_at_mdr_set * 100) + "% fpr")
        print(" [+] threshold : %.3lf" % (thresholds[idx]))

        plot_tradeoff(mdr, fpr, tradeoff, intercept=[mdr_set, fpr_at_mdr_set],
                      xlim=xlim, ylim=ylim, title=plot_title)
    plt.tight_layout()
    # Decide on the path we were actually given, not on options.outputFile,
    # so programmatic callers passing ``outputfile`` get a saved figure.
    if outputfile is not None:
        plt.savefig(outputfile)
    else:
        plt.show()
def main():
    """Command-line entry point: parse the arguments and plot the CSV files."""
    # docopt parses sys.argv against the usage text in the module docstring.
    # NOTE(review): cleanOptions presumably strips docopt's '--'/'<>' key
    # decorations (csvfile and outputFile are read as plain attributes
    # below) -- confirm in gkutils.
    raw_opts = docopt(__doc__, version='0.1')
    # Struct wraps the dict in an attribute-style object, for compatibility
    # with the older optparse-based code.
    options = Struct(**cleanOptions(raw_opts))
    print(options.csvfile)
    plotResults(options.csvfile, options.outputFile, options=options)


if __name__ == '__main__':
    main()