-
Notifications
You must be signed in to change notification settings - Fork 0
/
mainTestClassifier.py
76 lines (63 loc) · 3.12 KB
/
mainTestClassifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import os
import numpy as np
from utils.DatasetOptions import DatasetOptions
from utils.Dataset import Dataset
from utils.Results import Results
from learning.ClassifierRF import ClassifierRF
from learning.ClassifierRF import OptionsRF
from learning.ClassifierLogisticRegression import ClassifierLogisticRegression
from learning.ClassifierLogisticRegression import OptionsLogisticRegression
import helpers.constants as constantsPATREC
def _patrec_dataset_options(dir_data, dataset_name):
    """Build the option dict handed to ``DatasetOptions`` for one dataset.

    The training and testing configurations of this script share every
    setting except the value stored under ``'dataset'``.
    """
    return {
        'dir_data': dir_data,
        'data_prefix': 'patrec',
        'dataset': dataset_name,
        'encoding': 'categorical',
        'newfeatures': {'names': constantsPATREC.NEW_FEATURES},
        'featurereduction': None,
        'grouping': 'verylightgrouping',
        'filtering': 'EntlassBereich_Gyn',
    }


if __name__ == '__main__':
    # Evaluate stored classifiers on the '20162017' dataset, using the
    # option set of the '20122015' dataset to locate the models.

    # Project directory layout.
    dirProject = '/home/thomas/fusessh/scicore/projects/patrec'
    dirData = os.path.join(dirProject, 'data')
    dirResultsBase = os.path.join(dirProject, 'results')
    dirModelsBase = os.path.join(dirProject, 'classifiers')

    dict_options_dataset_training = _patrec_dataset_options(dirData, '20122015')
    dict_options_dataset_testing = _patrec_dataset_options(dirData, '20162017')

    options_training = DatasetOptions(dict_options_dataset_training)
    dataset_training = Dataset(dataset_options=options_training)

    # Random forest configuration; built but not selected below.
    dict_opt_rf = {'n_estimators': 500, 'max_depth': 50}
    options_rf = OptionsRF(
        dirModelsBase,
        options_training.getFilenameOptions(filteroptions=True),
        options_clf=dict_opt_rf,
    )
    clf_rf = ClassifierRF(options_rf)

    # Logistic regression configuration.
    dict_opt_lr = {'penalty': 'l1', 'C': 0.5}
    options_lr = OptionsLogisticRegression(
        dirModelsBase,
        options_training.getFilenameOptions(filteroptions=True),
        options_clf=dict_opt_lr,
    )
    clf_lr = ClassifierLogisticRegression(options_lr)

    # The classifier that is actually evaluated; swap both names together
    # to evaluate the random forest instead.
    options_clf, clf = options_lr, clf_lr

    options_testing = DatasetOptions(dict_options_dataset_testing)
    dataset_testing = Dataset(dataset_options=options_testing)

    results_all_runs_test = Results(
        dirResultsBase, options_training, options_clf, 'test', options_testing
    )
    early_readmission_flagname = options_testing.getEarlyReadmissionFlagname()

    # NOTE(review): the indentation of this script was not preserved in the
    # copy under review. The loop is assumed to end after
    # addResultsSingleRun(), with the results written once afterwards -
    # confirm against the repository version.
    auc_per_run = []
    num_runs = 10
    for run_idx in range(num_runs):
        # Request a balanced subset of the test data for this run.
        balanced_test_df = dataset_testing.getBalancedSubSet()

        # presumably restores the model stored for this run index - confirm
        clf.loadFromFile(run_idx)
        run_results = clf.predict(balanced_test_df, early_readmission_flagname)

        run_auc = run_results.getAUC()
        auc_per_run.append(run_auc)
        print(f'test auc: {run_auc!s}')
        results_all_runs_test.addResultsSingleRun(run_results)

    results_all_runs_test.writeResultsToFileDataset()

    mean_auc = np.mean(np.array(auc_per_run))
    print()
    print(f'mean test auc: {mean_auc!s}')
    print()