-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtests.py
66 lines (48 loc) · 2.4 KB
/
tests.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from random import seed
from unittest import TestCase, main
from numpy import dtype, float64
from sklearn.datasets import make_classification
from tabulate import tabulate
from grid_search_classifier import GridSearchClassifier, scikit_learn_classifiers
class TestGridSearchClassifierOutput(TestCase):
    """
    Tests that monitor changes to GridSearchClassifier. These tests allow changes to be made to the source code
    while checking whether its outputs are affected.

    The classifier comparison is run once for the whole class (in ``setUpClass``) and the
    resulting table is shared, read-only, by every test method through ``self.output``.
    """

    @classmethod
    def setUpClass(cls):
        """Fit GridSearchClassifier once on two synthetic datasets and store its output on the class.

        unittest creates a fresh TestCase instance for every test method, so performing this
        work in ``__init__`` would re-run the full grid search once per test.
        """
        super().setUpClass()
        # Two independently generated datasets (different random_state) for training and testing.
        train_set, train_targets = make_classification(n_samples=1000, n_features=4,
                                                       n_informative=2, n_redundant=0,
                                                       random_state=0, shuffle=False)
        test_set, test_targets = make_classification(n_samples=1000, n_features=4,
                                                     n_informative=2, n_redundant=0,
                                                     random_state=6, shuffle=False)
        classifiers = scikit_learn_classifiers()
        # NOTE(review): this seeds only Python's `random` module; if any estimator draws from
        # NumPy's global RNG instead, it is not seeded by this call — confirm what is needed.
        seed(123)
        # NOTE(review): the meaning of the [1, 0, 0] argument is defined by GridSearchClassifier,
        # which is not visible here. `.fit()` presumably returns a pandas DataFrame.
        cls.output = GridSearchClassifier(train_set, test_set, train_targets,
                                          test_targets, classifiers, [1, 0, 0]).fit()
        print('Full output')
        print(list(cls.output.columns))
        print(tabulate(cls.output))

    def test_column_names(self):
        """The output exposes exactly these columns, in this order."""
        self.assertListEqual(list(self.output.columns),
                             ['accuracy', 'train_time', 'test_time', 'ranks'])

    def test_shape(self):
        """The output has 8 rows and 4 columns."""
        # NOTE(review): 8 presumably equals the number of classifiers returned by
        # scikit_learn_classifiers(); update this if that list changes.
        self.assertTupleEqual(self.output.shape, (8, 4))

    def test_dtypes(self):
        """Every column is float64; each failing column is reported as its own sub-test."""
        for column in self.output.columns:
            with self.subTest(column=column):
                self.assertEqual(self.output[column].dtype, float64)

    def test_min_of_columns(self):
        """No column contains a negative value; each failing column is reported separately."""
        for column in self.output.columns:
            with self.subTest(column=column):
                self.assertGreaterEqual(self.output[column].min(), 0)

    def test_max_accuracy(self):
        """Accuracy never exceeds 1."""
        self.assertLessEqual(self.output['accuracy'].max(), 1)

    def test_index_dtype(self):
        """The row index has object dtype (presumably classifier names — confirm)."""
        self.assertEqual(self.output.index.dtype, dtype('O'))
# Run the test suite only when this file is executed directly, not when it is imported.
if __name__ == '__main__':
    main()