From f08a0fa6ee59a19f7045aca99c85662db2d2490d Mon Sep 17 00:00:00 2001
From: Amara Spano
Date: Mon, 21 Sep 2020 00:52:47 +0200
Subject: [PATCH 1/3] docstring for database.py

---
 scripts/database.py | 33 +++++++++++++++++++++++++++++++++
 1 file changed, 33 insertions(+)

diff --git a/scripts/database.py b/scripts/database.py
index 85a2cae..0c068e5 100644
--- a/scripts/database.py
+++ b/scripts/database.py
@@ -93,6 +93,15 @@
 def load(setname='csh101'):
+  """Loads the dataset
+
+    Args:
+      setname (str): name of the dataset to load
+    Returns:
+      dict of str : 2d-array: a dictionary mapping the class names to their corresponding samples (1 row = 1 sample)
+    Raises:
+      None
+  """
   data = dict([(k, []) for k in CLASSES])
   with open(os.path.join('../data', setname, '{}.ann.features.csv'.format(setname)), 'rt') as f:
     reader = csv.reader(f, delimiter=',')
@@ -105,11 +114,35 @@
 def split_data(data, subset, splits):
+  """Splits the dataset
+
+    Args:
+      data (dict of str : 2d-array): dataset to split
+      subset (str): subset to extract (train, validation or test)
+      splits (dict of str : tuple): a dictionary mapping each subset to its range (from 0.0 to 1.0)
+    Returns:
+      dict of str : 2d-array: a dictionary mapping the class names to their corresponding samples (1 row = 1 sample)
+    Raises:
+      None
+  """
   return dict([(k, data[k][range(int(splits[subset][0] * data[k].shape[0]), int(splits[subset][1] * data[k].shape[0]))]) for k in data])


 def get(protocol, subset, classes=CLASSES, variables=VARIABLES, setname='csh101'):
+  """Gets the desired subset
+
+    Args:
+      protocol (str): protocol to use
+      subset (str): subset to extract (train, validation or test)
+      classes (1d-array): list of desired classes
+      variables (1d-array): list of desired variables (features)
+      setname (str): name of the dataset to load
+    Returns:
+      numpy.ndarray: array of arrays containing the samples corresponding to 1 class in order
+    Raises:
+      None
+  """
   retval = split_data(load(setname), subset, PROTOCOLS[protocol])
   varindex = [VARIABLES.index(k) for k in variables]
   retval = dict([(k, retval[k][:, varindex]) for k in classes])

From 9bf55bfb852a457f478e35d52afc462b4e6c0363 Mon Sep 17 00:00:00 2001
From: Amara Spano
Date: Mon, 21 Sep 2020 01:03:58 +0200
Subject: [PATCH 2/3] docstring for paper.py

---
 scripts/paper.py | 40 ++++++++++++++++++++++++++++++++++++++++
 1 file changed, 40 insertions(+)

diff --git a/scripts/paper.py b/scripts/paper.py
index 2a98379..39d8a14 100644
--- a/scripts/paper.py
+++ b/scripts/paper.py
@@ -7,6 +7,19 @@
 def base_test(protocol, variables, setname='csh101', nb_tree_per_forest=50, max_depth=10):
+  """Basic test for the random forest classifier
+
+    Args:
+      protocol (str): protocol to use
+      variables (1d-array): list of desired variables (features)
+      setname (str): name of the dataset to load
+      nb_tree_per_forest (int): number of decision trees in the forest
+      max_depth (int): maximum depth of the trees
+    Returns:
+      numpy.ndarray: a 2D array (with a dtype of int) containing the confusion matrix
+    Raises:
+      None
+  """
   # get train data
   train = database.get(protocol, 'train', database.CLASSES, variables, setname)
   # make and train model
@@ -23,6 +36,15 @@
 def pretty_confusion_matrix(cm):
+  """Adds labels to a confusion matrix
+
+    Args:
+      cm (numpy.ndarray): a 2D array (with a dtype of int) containing the confusion matrix
+    Returns:
+      str: nicely formatted confusion matrix for printing
+    Raises:
+      None
+  """
   classes = np.array([database.CLASSES])
   table = tabulate(np.vstack((np.hstack(([[""]], classes)), np.hstack((classes.T, cm)))))
@@ -30,6 +52,15 @@
 def test_impact_nb_trees(tabnum):
+  """Evaluates and prints the impact of the number of trees per forest on the classifier's performance
+
+    Args:
+      tabnum (int): number of the first confusion matrix table
+    Returns:
+      None
+    Raises:
+      None
+  """
   nb_trees = [1, 5, 10]
   print("\nImpact of number of trees per forest")
   for n, p in enumerate(database.PROTOCOLS):
@@ -43,6 +74,15 @@
       print(pretty_confusion_matrix(cm))

 def test_impact_tree_depth(tabnum):
+  """Evaluates and prints the impact of the maximum tree depth on the classifier's performance
+
+    Args:
+      tabnum (int): number of the first confusion matrix table
+    Returns:
+      None
+    Raises:
+      None
+  """
   depths = [1, 5, 10]
   print("\nImpact of trees maximum depth")
   for n, p in enumerate(database.PROTOCOLS):

From 1066175c4c4670d38cb548eda5395c9d77e619b0 Mon Sep 17 00:00:00 2001
From: Amara Spano
Date: Mon, 21 Sep 2020 01:05:23 +0200
Subject: [PATCH 3/3] clarified docstring for database.py

---
 scripts/database.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/scripts/database.py b/scripts/database.py
index 0c068e5..04e89b7 100644
--- a/scripts/database.py
+++ b/scripts/database.py
@@ -139,7 +139,8 @@ def get(protocol, subset, classes=CLASSES, variables=VARIABLES, setname='csh101'
       variables (1d-array): list of desired variables (features)
       setname (str): name of the dataset to load
     Returns:
-      numpy.ndarray: array of arrays containing the samples corresponding to 1 class in order
+      numpy.ndarray: array of ordered arrays (each of size n_samples x n_features) containing the samples
+      corresponding to one class
     Raises:
       None
   """
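
Appended usage note (not part of the patches themselves): the sketch below is one way the functions documented above could be exercised together. It assumes the scripts are run from inside the scripts/ directory with the csh101 data files available under ../data, and it picks a protocol key from database.PROTOCOLS at run time because the actual protocol names are not listed in this series.

    import database  # scripts/database.py, documented in PATCH 1/3 and 3/3
    import paper     # scripts/paper.py, documented in PATCH 2/3

    # PROTOCOLS is assumed to be a dict keyed by protocol name, as suggested by
    # PROTOCOLS[protocol] inside database.get(); pick any available key.
    protocol = sorted(database.PROTOCOLS)[0]

    # Train a random forest and obtain its confusion matrix (see the base_test
    # docstring added in PATCH 2/3).
    cm = paper.base_test(protocol, database.VARIABLES, setname='csh101',
                         nb_tree_per_forest=50, max_depth=10)

    # Format the matrix with class labels for printing.
    print(paper.pretty_confusion_matrix(cm))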