Skip to content

Commit

Permalink
implemented test tree depth and number of trees
Browse files Browse the repository at this point in the history
  • Loading branch information
spanoamara committed Sep 20, 2020
1 parent 05406cd commit dd2a6b9
Showing 1 changed file with 29 additions and 1 deletion.
30 changes: 29 additions & 1 deletion scripts/paper.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,34 @@ def pretty_confusion_matrix(cm):
return table


def test_impact_nb_trees(tabnum):
nb_trees = [1, 5, 10]
print("\nImpact of number of trees per forest")
for n, p in enumerate(database.PROTOCOLS):
for m, nb_tree_per_forest in enumerate(nb_trees):
print("\nTable {table_number}: Confusion matrix with {nb_trees} tree(s) for Protocol `{protocol}`".format(
table_number=(n * len(nb_trees)) + m + tabnum,
protocol=p,
nb_trees=nb_tree_per_forest)
)
cm = base_test(p, database.VARIABLES, nb_tree_per_forest=nb_tree_per_forest)
print(pretty_confusion_matrix(cm))

def test_impact_tree_depth(tabnum):
depths = [1, 5, 10]
print("\nImpact of trees maximum depth")
for n, p in enumerate(database.PROTOCOLS):
for m, max_depth in enumerate(depths):
print("\nTable {table_number}: Confusion matrix with trees maximum depth of {max_depth} for Protocol `{protocol}`".format(
table_number=(n * len(depths)) + m + tabnum,
protocol=p,
max_depth=max_depth)
)
cm = base_test(p, database.VARIABLES, max_depth=max_depth, nb_tree_per_forest=10)
print(pretty_confusion_matrix(cm))


if __name__ == '__main__':
print("Main script for Human Activity Recognition with Random Forest classifier")
print(pretty_confusion_matrix(base_test('proto1', database.VARIABLES)))
test_impact_nb_trees(1)
test_impact_tree_depth(7)

0 comments on commit dd2a6b9

Please sign in to comment.