-
Notifications
You must be signed in to change notification settings - Fork 0
/
voting_model.py
74 lines (56 loc) · 1.85 KB
/
voting_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
from sklearn_utils import load_trained_sklearn_models, prepare_data, evaluate
from keras_utils import get_keras_scores
import numpy as np
import glob
import joblib
import ast
import re
from voting_ensemble import SoftVoteClassifier
from sklearn.metrics import f1_score
# Number of consecutive entries in each score list that are averaged into one
# voting weight.  Presumably one chunk per trained model (e.g. repeated
# runs/folds) -- confirm against load_trained_sklearn_models/get_keras_scores.
SKLEARN_SCORES_PER_MODEL = 100
KERAS_SCORES_PER_MODEL = 10


def _chunk_means(scores, chunk_size):
    """Return the mean of each consecutive, complete ``chunk_size`` chunk of *scores*.

    Trailing entries that do not fill a whole chunk are ignored (the chunk
    count is ``len(scores) // chunk_size``).  Each chunk is summed left to
    right starting from ``0.0`` so results match a plain accumulation loop.
    """
    n_chunks = len(scores) // chunk_size
    return [sum(scores[start:start + chunk_size], 0.0) / float(chunk_size)
            for start in range(0, n_chunks * chunk_size, chunk_size)]


# NOTE: these run at import time.  `clfs` is not used by the ensemble
# construction below, which passes None instead of the classifiers.
clfs, sklearn_scores = load_trained_sklearn_models()
keras_scores = get_keras_scores(normalize_scores=False)

# One raw weight per chunk: sklearn chunks first, then keras chunks.  This
# order must match the order in which the ensemble consumes its members.
clf_weights = (_chunk_means(sklearn_scores, SKLEARN_SCORES_PER_MODEL)
               + _chunk_means(keras_scores, KERAS_SCORES_PER_MODEL))

# Normalise so the weights sum to 1.  A zero total would otherwise raise
# ZeroDivisionError (Python floats) or silently produce nan/inf (numpy scalars).
score_sum_ = sum(clf_weights)
if clf_weights and score_sum_ == 0:
    raise ValueError('Model scores sum to zero; cannot normalise voting weights')
clf_weights = [s / score_sum_ for s in clf_weights]

model = SoftVoteClassifier(None, weights=clf_weights)
# Prediction directory pattern handed to the ensemble for each supported split.
_PRED_DIRS = {
    'full': 'test/*/',
    'obama': 'obama/*/',
    'romney': 'romney/*/',
}


def fit_voting_classifier(dataset='full'):
    """Evaluate the weighted soft-vote ensemble on one test split.

    Despite its name, nothing is fitted here.  The steps are:

    1. Load the test labels for *dataset*.
    2. Obtain class-probability predictions from the module-level ``model``
       via ``predict_proba_dir``, presumably by combining prediction files
       found under the split's directory pattern (confirm in voting_ensemble).
    3. Pass the argmax class of each row, together with the labels, to
       ``evaluate``.

    Args:
        dataset: Split name; one of ``'full'``, ``'obama'`` or ``'romney'``.

    Raises:
        ValueError: If *dataset* is not one of the supported split names.
    """
    # Validate up front: an unknown name used to fall through to the romney
    # predictions, silently scoring them against another split's labels.
    if dataset not in _PRED_DIRS:
        raise ValueError('Unknown dataset %r; expected one of %s'
                         % (dataset, sorted(_PRED_DIRS)))
    pred_dir = _PRED_DIRS[dataset]

    # Fixed seed, presumably so any randomness in data preparation is
    # reproducible across runs -- confirm in prepare_data.
    np.random.seed(1000)
    _, labels = prepare_data(mode='test', dataset=dataset)

    preds = model.predict_proba_dir(pred_dir)
    evaluate(labels, np.argmax(preds, axis=1))
if __name__ == '__main__':
    # Evaluate the 'obama' split, then the 'romney' split.
    # (fit_voting_classifier() with no argument would use the 'full' split.)
    for split_name in ('obama', 'romney'):
        fit_voting_classifier(dataset=split_name)