Skip to content

Commit

Permalink
Merge pull request #16 from sdevenes/feature/paper
Browse files Browse the repository at this point in the history
Feature/paper
  • Loading branch information
sdevenes authored Sep 21, 2020
2 parents d6480b7 + dd2a6b9 commit f3bf608
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 1 deletion.
2 changes: 1 addition & 1 deletion scripts/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,4 +113,4 @@ def get(protocol, subset, classes=CLASSES, variables=VARIABLES, setname='csh101'
retval = split_data(load(setname), subset, PROTOCOLS[protocol])
varindex = [VARIABLES.index(k) for k in variables]
retval = dict([(k, retval[k][:, varindex]) for k in classes])
return np.array([retval[k] for k in classes])
return np.array([retval[k] for k in classes], dtype=object)
62 changes: 62 additions & 0 deletions scripts/paper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
#!/usr/bin/env python
from tabulate import tabulate
import algorithm
import database
import analysis
import numpy as np


def base_test(protocol, variables, setname='csh101', nb_tree_per_forest=50, max_depth=10):
# get train data
train = database.get(protocol, 'train', database.CLASSES, variables, setname)
# make and train model
model = algorithm.Model(nb_tree_per_forest, max_depth)
model.train(train)
# get test data
test = database.get(protocol, 'test', database.CLASSES, variables, setname)
test_labels = algorithm.make_labels(test).astype(int)
# make prediction on test
test_predictions = model.predict(test)
# get and return confusion matrix
cm = analysis.get_confusion_matrix(test_predictions, test_labels)
return cm


def pretty_confusion_matrix(cm):
classes = np.array([database.CLASSES])
table = tabulate(np.vstack((np.hstack(([[""]], classes)),
np.hstack((classes.T, cm)))))
return table


def test_impact_nb_trees(tabnum):
nb_trees = [1, 5, 10]
print("\nImpact of number of trees per forest")
for n, p in enumerate(database.PROTOCOLS):
for m, nb_tree_per_forest in enumerate(nb_trees):
print("\nTable {table_number}: Confusion matrix with {nb_trees} tree(s) for Protocol `{protocol}`".format(
table_number=(n * len(nb_trees)) + m + tabnum,
protocol=p,
nb_trees=nb_tree_per_forest)
)
cm = base_test(p, database.VARIABLES, nb_tree_per_forest=nb_tree_per_forest)
print(pretty_confusion_matrix(cm))

def test_impact_tree_depth(tabnum):
depths = [1, 5, 10]
print("\nImpact of trees maximum depth")
for n, p in enumerate(database.PROTOCOLS):
for m, max_depth in enumerate(depths):
print("\nTable {table_number}: Confusion matrix with trees maximum depth of {max_depth} for Protocol `{protocol}`".format(
table_number=(n * len(depths)) + m + tabnum,
protocol=p,
max_depth=max_depth)
)
cm = base_test(p, database.VARIABLES, max_depth=max_depth, nb_tree_per_forest=10)
print(pretty_confusion_matrix(cm))


if __name__ == '__main__':
print("Main script for Human Activity Recognition with Random Forest classifier")
test_impact_nb_trees(1)
test_impact_tree_depth(7)

0 comments on commit f3bf608

Please sign in to comment.