From b8aab31f55db7691dd916d12feefe71c37f47c12 Mon Sep 17 00:00:00 2001 From: Amara Spano Date: Sun, 20 Sep 2020 23:12:33 +0200 Subject: [PATCH 1/3] fixed deprecated warning in database.py --- scripts/database.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/database.py b/scripts/database.py index d3876f1..85a2cae 100644 --- a/scripts/database.py +++ b/scripts/database.py @@ -113,4 +113,4 @@ def get(protocol, subset, classes=CLASSES, variables=VARIABLES, setname='csh101' retval = split_data(load(setname), subset, PROTOCOLS[protocol]) varindex = [VARIABLES.index(k) for k in variables] retval = dict([(k, retval[k][:, varindex]) for k in classes]) - return np.array([retval[k] for k in classes]) + return np.array([retval[k] for k in classes], dtype=object) From 05406cd9cd17134b8def004505568348afa8b91a Mon Sep 17 00:00:00 2001 From: Amara Spano Date: Sun, 20 Sep 2020 23:54:06 +0200 Subject: [PATCH 2/3] added papers.py and implemented basic test + confusion matrix printing --- scripts/paper.py | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) create mode 100644 scripts/paper.py diff --git a/scripts/paper.py b/scripts/paper.py new file mode 100644 index 0000000..de1d9ad --- /dev/null +++ b/scripts/paper.py @@ -0,0 +1,34 @@ +#!/usr/bin/env python +from tabulate import tabulate +import algorithm +import database +import analysis +import numpy as np + + +def base_test(protocol, variables, setname='csh101', nb_tree_per_forest=50, max_depth=10): + # get train data + train = database.get(protocol, 'train', database.CLASSES, variables, setname) + # make and train model + model = algorithm.Model(nb_tree_per_forest, max_depth) + model.train(train) + # get test data + test = database.get(protocol, 'test', database.CLASSES, variables, setname) + test_labels = algorithm.make_labels(test).astype(int) + # make prediction on test + test_predictions = model.predict(test) + # get and return confusion matrix + cm = analysis.get_confusion_matrix(test_predictions, test_labels) + return cm + + +def pretty_confusion_matrix(cm): + classes = np.array([database.CLASSES]) + table = tabulate(np.vstack((np.hstack(([[""]], classes)), + np.hstack((classes.T, cm))))) + return table + + +if __name__ == '__main__': + print("Main script for Human Activity Recognition with Random Forest classifier") + print(pretty_confusion_matrix(base_test('proto1', database.VARIABLES))) From dd2a6b9c440d86928f83a172294f246bf8288886 Mon Sep 17 00:00:00 2001 From: Amara Spano Date: Mon, 21 Sep 2020 00:11:52 +0200 Subject: [PATCH 3/3] implemented test tree depth and number of trees --- scripts/paper.py | 30 +++++++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/scripts/paper.py b/scripts/paper.py index de1d9ad..2a98379 100644 --- a/scripts/paper.py +++ b/scripts/paper.py @@ -29,6 +29,34 @@ def pretty_confusion_matrix(cm): return table +def test_impact_nb_trees(tabnum): + nb_trees = [1, 5, 10] + print("\nImpact of number of trees per forest") + for n, p in enumerate(database.PROTOCOLS): + for m, nb_tree_per_forest in enumerate(nb_trees): + print("\nTable {table_number}: Confusion matrix with {nb_trees} tree(s) for Protocol `{protocol}`".format( + table_number=(n * len(nb_trees)) + m + tabnum, + protocol=p, + nb_trees=nb_tree_per_forest) + ) + cm = base_test(p, database.VARIABLES, nb_tree_per_forest=nb_tree_per_forest) + print(pretty_confusion_matrix(cm)) + +def test_impact_tree_depth(tabnum): + depths = [1, 5, 10] + print("\nImpact of trees maximum depth") + for n, p in enumerate(database.PROTOCOLS): + for m, max_depth in enumerate(depths): + print("\nTable {table_number}: Confusion matrix with trees maximum depth of {max_depth} for Protocol `{protocol}`".format( + table_number=(n * len(depths)) + m + tabnum, + protocol=p, + max_depth=max_depth) + ) + cm = base_test(p, database.VARIABLES, max_depth=max_depth, nb_tree_per_forest=10) + print(pretty_confusion_matrix(cm)) + + if __name__ == '__main__': print("Main script for Human Activity Recognition with Random Forest classifier") - print(pretty_confusion_matrix(base_test('proto1', database.VARIABLES))) + test_impact_nb_trees(1) + test_impact_tree_depth(7)