-
Notifications
You must be signed in to change notification settings - Fork 0
/
eval_ood_detection.py
110 lines (80 loc) · 4.2 KB
/
eval_ood_detection.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import argparse
import torch
import numpy as np
import os
import warnings
from detector.detector import Detector
from model.metric import compute_traditional_odd
from sklearn.linear_model import LogisticRegressionCV
# Silence all library warnings globally to keep evaluation logs readable.
# NOTE(review): this also hides deprecation/numerical warnings from torch,
# numpy and sklearn -- narrow it if a warning needs to be investigated.
warnings.filterwarnings('ignore')

# Fix RNG seeds (torch CPU + current CUDA device, NumPy) for more repeatable runs.
torch.manual_seed(1)
torch.cuda.manual_seed(1)
np.random.seed(1)

parser = argparse.ArgumentParser(description='Pytorch Detecting Out-of-distribution examples in neural networks')
parser.add_argument('--in-dataset', default="CIFAR-10", type=str, help='in-distribution dataset')
parser.add_argument('--name', required=True, type=str, help='the name of the model trained')
parser.add_argument('--model-arch', default='densenet', type=str, help='model architecture')
parser.add_argument('--gpu', default='1', type=str, help='gpu index')
parser.add_argument('--adv', help='L_inf OOD', action='store_true')
parser.add_argument('--corrupt', help='corrupted OOD', action='store_true')
parser.add_argument('--adv-corrupt', help='comp. OOD', action='store_true')
parser.add_argument('--in-dist-only', help='only evaluate in-distribution', action='store_true')
parser.add_argument('--out-dist-only', help='only evaluate out-distribution', action='store_true')
# Restrict to the scoring functions the __main__ dispatch actually handles, so an
# unsupported value is rejected here with a usage error instead of failing later.
parser.add_argument('--method', default='energy', type=str, help='scoring function',
                    choices=('msp', 'odin', 'mahalanobis', 'energy'))
parser.add_argument('--cal-metric', help='calculate metric directly', action='store_true')
# Attack / corruption settings forwarded to the detector via adv_args.
parser.add_argument('--epsilon', default=8.0, type=float, help='epsilon')
parser.add_argument('--iters', default=40, type=int,
                    help='attack iterations')
parser.add_argument('--iter-size', default=1.0, type=float, help='attack step size')
parser.add_argument('--severity-level', default=5, type=int, help='severity level')
parser.add_argument('--epochs', default=100, type=int,
                    help='number of total epochs to run')
parser.add_argument('-b', '--batch-size', default=256, type=int,
                    help='mini-batch size')
parser.add_argument('--base-dir', default='output/ood_scores', type=str, help='result directory')
# Architecture hyper-parameters; not read in this file -- presumably consumed by
# Detector through the args namespace (TODO confirm).
parser.add_argument('--layers', default=100, type=int,
                    help='total number of layers (default: 100)')
parser.add_argument('--depth', default=40, type=int,
                    help='depth of resnet')
parser.add_argument('--width', default=4, type=int,
                    help='width of resnet')
# Adds args.argument = True. Nothing in this file reads it.
# NOTE(review): possibly used by Detector -- confirm before removing.
parser.set_defaults(argument=True)
# NOTE(review): parsed at import time, so importing this module reads sys.argv.
args = parser.parse_args()
def _build_method_args(method, in_dataset, name):
    """Return the scoring-function-specific arguments passed to ``Detector``.

    Args:
        method: Scoring function name (the ``--method`` option).
        in_dataset: In-distribution dataset name. It is only used to locate
            the Mahalanobis hyper-parameter file.
        name: Trained model name. It is only used to locate the Mahalanobis
            hyper-parameter file.

    Returns:
        A new dict of method arguments. It is empty for ``msp`` and
        ``energy``, holds ``temperature`` for ``odin``, and holds
        ``sample_mean``/``precision``/``magnitude``/``regressor`` for
        ``mahalanobis``.

    Raises:
        ValueError: If ``method`` is not a supported scoring function. A real
            exception is used instead of ``assert``, which ``python -O`` strips.
    """
    method_args = {}
    if method in ('msp', 'energy'):
        # These scores need no extra configuration.
        return method_args
    if method == 'odin':
        method_args['temperature'] = 1000.0
        return method_args
    if method == 'mahalanobis':
        # results.npy is an object array unpacked into five entries below;
        # presumably written by a separate hyper-parameter tuning step.
        # allow_pickle=True unpickles on load -- only point this at files you produced.
        hyperparams_path = os.path.join('output/mahalanobis_hyperparams/', in_dataset, name, 'results.npy')
        sample_mean, precision, lr_weights, lr_bias, magnitude = np.load(hyperparams_path, allow_pickle=True)
        # Fit on a trivial toy set only so the estimator carries sklearn's fitted
        # attributes (e.g. classes_ = [0, 1]); the stored logistic-regression
        # weights then replace the dummy coefficients.
        # NOTE(review): the toy set has 4 features; assumes lr_weights matches -- confirm.
        regressor = LogisticRegressionCV(cv=2).fit(
            [[0, 0, 0, 0], [0, 0, 0, 0], [1, 1, 1, 1], [1, 1, 1, 1]],
            [0, 0, 1, 1],
        )
        regressor.coef_ = lr_weights
        regressor.intercept_ = lr_bias
        method_args['sample_mean'] = sample_mean
        method_args['precision'] = precision
        method_args['magnitude'] = magnitude
        method_args['regressor'] = regressor
        return method_args
    raise ValueError(f'Not supported method: {method!r}')


if __name__ == "__main__":
    # Resolve the method configuration first so an unsupported --method fails
    # before any detection work starts and before metrics are computed.
    method_args = _build_method_args(args.method, args.in_dataset, args.name)
    adv_args = {
        'epsilon': args.epsilon,
        'iters': args.iters,
        'iter_size': args.iter_size,
        'severity_level': args.severity_level,
    }
    mode_args = {
        'in_dist_only': args.in_dist_only,
        'out_dist_only': args.out_dist_only,
    }
    # Full benchmark list: ['LSUN', 'LSUN_resize', 'iSUN', 'dtd', 'SVHN'];
    # only SVHN is evaluated in this configuration.
    out_datasets = ['SVHN']
    detector = Detector(args, out_datasets, method_args, adv_args, mode_args)
    detector.detect()
    compute_traditional_odd(args.base_dir, args.in_dataset, out_datasets, args.method, args.name)