Skip to content

Commit

Permalink
Merge pull request #17 from sdevenes/feature/code_documentation
Browse files Browse the repository at this point in the history
Feature/code documentation
  • Loading branch information
sdevenes authored Sep 21, 2020
2 parents f3bf608 + 1066175 commit 42bbf0f
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 0 deletions.
34 changes: 34 additions & 0 deletions scripts/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,15 @@


def load(setname='csh101'):
"""Loads the dataset
Args:
setname (str): name of the dataset to load
Returns:
dict of str : 2d-array: a dictionary mapping the classes names to their corresponding samples (1 row = 1 sample)
Raises:
None
"""
data = dict([(k, []) for k in CLASSES])
with open(os.path.join('../data', setname, '{}.ann.features.csv'.format(setname)), 'rt') as f:
reader = csv.reader(f, delimiter=',')
Expand All @@ -105,11 +114,36 @@ def load(setname='csh101'):


def split_data(data, subset, splits):
    """Extract one subset (train, validation or test) from every class of the data set.

    Args:
        data (dict of str : 2d-array): dataset to split, mapping each class name to its
            samples (1 row = 1 sample)
        subset (str): subset to extract (train, validation or test); must be a key of ``splits``
        splits (dict of str : tuple): a dictionary mapping the subsets to their (start, stop)
            range, expressed as fractions (from 0.0 to 1.0) of the rows of each class

    Returns:
        dict of str : 2d-array: a dictionary mapping the classes names to the rows that fall
        inside the requested range (1 row = 1 sample). The arrays are copies, so modifying
        them does not affect ``data``.

    Raises:
        KeyError: if ``subset`` is not a key of ``splits``
    """
    # Look the bounds up once: they are the same for every class.
    lower = splits[subset][0]
    upper = splits[subset][1]
    # The fractions are scaled by each class's own row count, so every class is split
    # in the same proportions. A slice selects the contiguous window directly; the
    # explicit copy keeps the result independent from the input arrays.
    return {
        name: samples[int(lower * samples.shape[0]):int(upper * samples.shape[0])].copy()
        for name, samples in data.items()
    }


def get(protocol, subset, classes=CLASSES, variables=VARIABLES, setname='csh101'):
"""Get the desired subset
Args:
protocol (str): protocol to use
subset (str): subset to extract (train, validation or test)
classes (1d-array): list of desired classes
variables (1d-array): list of desired variables (features)
setname (str): name of the dataset to load
Returns:
numpy.ndarray: array of ordered arrays (of size n_sample x n_features) containing the samples corresponding to
1 class
Raises:
None
"""
retval = split_data(load(setname), subset, PROTOCOLS[protocol])
varindex = [VARIABLES.index(k) for k in variables]
retval = dict([(k, retval[k][:, varindex]) for k in classes])
Expand Down
40 changes: 40 additions & 0 deletions scripts/paper.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,19 @@


def base_test(protocol, variables, setname='csh101', nb_tree_per_forest=50, max_depth=10):
"""Basic test for the random forest classifier
Args:
protocol (str): protocol to use
variables (1d-array): list of desired variables (features)
setname (str): name of the dataset to load
nb_tree_per_forest: number of decision trees in the forest
max_depth: max depth of the trees
Returns:
numpy.ndarray: A 2D array (with a dtype of int) containing the confusion matrix.
Raises:
None
"""
# get train data
train = database.get(protocol, 'train', database.CLASSES, variables, setname)
# make and train model
Expand All @@ -23,13 +36,31 @@ def base_test(protocol, variables, setname='csh101', nb_tree_per_forest=50, max_


def pretty_confusion_matrix(cm):
    """Render a confusion matrix as a table labelled with the class names.

    Args:
        cm (numpy.ndarray): A 2D array (with a dtype of int) containing the confusion matrix.

    Returns:
        str: the table produced by ``tabulate``, with the names from ``database.CLASSES``
        as first row and first column

    Raises:
        None
    """
    labels = np.array([database.CLASSES])
    # First row: an empty corner cell followed by the class names.
    header_row = np.hstack(([[""]], labels))
    # Remaining rows: each class name followed by its row of the confusion matrix.
    labelled_rows = np.hstack((labels.T, cm))
    return tabulate(np.vstack((header_row, labelled_rows)))


def test_impact_nb_trees(tabnum):
"""Evaluates and print the impact of the number of trees per forest on the classifiers performance
Args:
tabnum (int): first confusion matrix numbering
Returns:
None
Raises:
None
"""
nb_trees = [1, 5, 10]
print("\nImpact of number of trees per forest")
for n, p in enumerate(database.PROTOCOLS):
Expand All @@ -43,6 +74,15 @@ def test_impact_nb_trees(tabnum):
print(pretty_confusion_matrix(cm))

def test_impact_tree_depth(tabnum):
"""Evaluates and print the impact of the trees depth on the classifiers performance
Args:
tabnum (int): first confusion matrix numbering
Returns:
None
Raises:
None
"""
depths = [1, 5, 10]
print("\nImpact of trees maximum depth")
for n, p in enumerate(database.PROTOCOLS):
Expand Down

0 comments on commit 42bbf0f

Please sign in to comment.