Skip to content

Commit

Permalink
docstring for paper.py
Browse files Browse the repository at this point in the history
  • Loading branch information
spanoamara committed Sep 20, 2020
1 parent f08a0fa commit 9bf55bf
Showing 1 changed file with 40 additions and 0 deletions.
40 changes: 40 additions & 0 deletions scripts/paper.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,19 @@


def base_test(protocol, variables, setname='csh101', nb_tree_per_forest=50, max_depth=10):
"""Basic test for the random forest classifier
Args:
protocol (str): protocol to use
variables (1d-array): list of desired variables (features)
setname (str): name of the dataset to load
nb_tree_per_forest: number of decision trees in the forest
max_depth: max depth of the trees
Returns:
numpy.ndarray: A 2D array (with a dtype of int) containing the confusion matrix.
Raises:
None
"""
# get train data
train = database.get(protocol, 'train', database.CLASSES, variables, setname)
# make and train model
Expand All @@ -23,13 +36,31 @@ def base_test(protocol, variables, setname='csh101', nb_tree_per_forest=50, max_


def pretty_confusion_matrix(cm):
"""Adds labels to confusion matrix
Args:
cm (numpy.ndarray): A 2D array (with a dtype of int) containing the confusion matrix.
Returns:
str: nicely formatted confusion matrix for printing
Raises:
None
"""
classes = np.array([database.CLASSES])
table = tabulate(np.vstack((np.hstack(([[""]], classes)),
np.hstack((classes.T, cm)))))
return table


def test_impact_nb_trees(tabnum):
"""Evaluates and print the impact of the number of trees per forest on the classifiers performance
Args:
tabnum (int): first confusion matrix numbering
Returns:
None
Raises:
None
"""
nb_trees = [1, 5, 10]
print("\nImpact of number of trees per forest")
for n, p in enumerate(database.PROTOCOLS):
Expand All @@ -43,6 +74,15 @@ def test_impact_nb_trees(tabnum):
print(pretty_confusion_matrix(cm))

def test_impact_tree_depth(tabnum):
"""Evaluates and print the impact of the trees depth on the classifiers performance
Args:
tabnum (int): first confusion matrix numbering
Returns:
None
Raises:
None
"""
depths = [1, 5, 10]
print("\nImpact of trees maximum depth")
for n, p in enumerate(database.PROTOCOLS):
Expand Down

0 comments on commit 9bf55bf

Please sign in to comment.