Skip to content

Commit

Permalink
added papers.py and implemented basic test + confusion matrix printing
Browse files Browse the repository at this point in the history
  • Loading branch information
spanoamara committed Sep 20, 2020
1 parent b8aab31 commit 05406cd
Showing 1 changed file with 34 additions and 0 deletions.
34 changes: 34 additions & 0 deletions scripts/paper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#!/usr/bin/env python
from tabulate import tabulate
import algorithm
import database
import analysis
import numpy as np


def base_test(protocol, variables, setname='csh101', nb_tree_per_forest=50, max_depth=10):
# get train data
train = database.get(protocol, 'train', database.CLASSES, variables, setname)
# make and train model
model = algorithm.Model(nb_tree_per_forest, max_depth)
model.train(train)
# get test data
test = database.get(protocol, 'test', database.CLASSES, variables, setname)
test_labels = algorithm.make_labels(test).astype(int)
# make prediction on test
test_predictions = model.predict(test)
# get and return confusion matrix
cm = analysis.get_confusion_matrix(test_predictions, test_labels)
return cm


def pretty_confusion_matrix(cm):
classes = np.array([database.CLASSES])
table = tabulate(np.vstack((np.hstack(([[""]], classes)),
np.hstack((classes.T, cm)))))
return table


if __name__ == '__main__':
print("Main script for Human Activity Recognition with Random Forest classifier")
print(pretty_confusion_matrix(base_test('proto1', database.VARIABLES)))

0 comments on commit 05406cd

Please sign in to comment.