diff --git a/scripts/paper.py b/scripts/paper.py index de1d9ad..2a98379 100644 --- a/scripts/paper.py +++ b/scripts/paper.py @@ -29,6 +29,34 @@ def pretty_confusion_matrix(cm): return table +def test_impact_nb_trees(tabnum): + nb_trees = [1, 5, 10] + print("\nImpact of number of trees per forest") + for n, p in enumerate(database.PROTOCOLS): + for m, nb_tree_per_forest in enumerate(nb_trees): + print("\nTable {table_number}: Confusion matrix with {nb_trees} tree(s) for Protocol `{protocol}`".format( + table_number=(n * len(nb_trees)) + m + tabnum, + protocol=p, + nb_trees=nb_tree_per_forest) + ) + cm = base_test(p, database.VARIABLES, nb_tree_per_forest=nb_tree_per_forest) + print(pretty_confusion_matrix(cm)) + +def test_impact_tree_depth(tabnum): + depths = [1, 5, 10] + print("\nImpact of trees maximum depth") + for n, p in enumerate(database.PROTOCOLS): + for m, max_depth in enumerate(depths): + print("\nTable {table_number}: Confusion matrix with trees maximum depth of {max_depth} for Protocol `{protocol}`".format( + table_number=(n * len(depths)) + m + tabnum, + protocol=p, + max_depth=max_depth) + ) + cm = base_test(p, database.VARIABLES, max_depth=max_depth, nb_tree_per_forest=10) + print(pretty_confusion_matrix(cm)) + + if __name__ == '__main__': print("Main script for Human Activity Recognition with Random Forest classifier") - print(pretty_confusion_matrix(base_test('proto1', database.VARIABLES))) + test_impact_nb_trees(1) + test_impact_tree_depth(7)