From 9bf55bfb852a457f478e35d52afc462b4e6c0363 Mon Sep 17 00:00:00 2001 From: Amara Spano Date: Mon, 21 Sep 2020 01:03:58 +0200 Subject: [PATCH] docstring for paper.py --- scripts/paper.py | 40 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/scripts/paper.py b/scripts/paper.py index 2a98379..39d8a14 100644 --- a/scripts/paper.py +++ b/scripts/paper.py @@ -7,6 +7,19 @@ def base_test(protocol, variables, setname='csh101', nb_tree_per_forest=50, max_depth=10): + """Basic test for the random forest classifier + + Args: + protocol (str): protocol to use + variables (1d-array): list of desired variables (features) + setname (str): name of the dataset to load + nb_tree_per_forest: number of decision trees in the forest + max_depth: max depth of the trees + Returns: + numpy.ndarray: A 2D array (with a dtype of int) containing the confusion matrix. + Raises: + None + """ # get train data train = database.get(protocol, 'train', database.CLASSES, variables, setname) # make and train model @@ -23,6 +36,15 @@ def base_test(protocol, variables, setname='csh101', nb_tree_per_forest=50, max_ def pretty_confusion_matrix(cm): + """Adds labels to confusion matrix + + Args: + cm (numpy.ndarray): A 2D array (with a dtype of int) containing the confusion matrix. + Returns: + str: nicely formatted confusion matrix for printing + Raises: + None + """ classes = np.array([database.CLASSES]) table = tabulate(np.vstack((np.hstack(([[""]], classes)), np.hstack((classes.T, cm))))) @@ -30,6 +52,15 @@ def pretty_confusion_matrix(cm): def test_impact_nb_trees(tabnum): + """Evaluates and print the impact of the number of trees per forest on the classifiers performance + + Args: + tabnum (int): first confusion matrix numbering + Returns: + None + Raises: + None + """ nb_trees = [1, 5, 10] print("\nImpact of number of trees per forest") for n, p in enumerate(database.PROTOCOLS): @@ -43,6 +74,15 @@ def test_impact_nb_trees(tabnum): print(pretty_confusion_matrix(cm)) def test_impact_tree_depth(tabnum): + """Evaluates and print the impact of the trees depth on the classifiers performance + + Args: + tabnum (int): first confusion matrix numbering + Returns: + None + Raises: + None + """ depths = [1, 5, 10] print("\nImpact of trees maximum depth") for n, p in enumerate(database.PROTOCOLS):