Skip to content

Commit

Permalink
docstring for database.py
Browse files Browse the repository at this point in the history
  • Loading branch information
spanoamara committed Sep 20, 2020
1 parent dd2a6b9 commit f08a0fa
Showing 1 changed file with 33 additions and 0 deletions.
33 changes: 33 additions & 0 deletions scripts/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,15 @@


def load(setname='csh101'):
"""Loads the dataset
Args:
setname (str): name of the dataset to load
Returns:
dict of str : 2d-array: a dictionary mapping the classes names to their corresponding samples (1 row = 1 sample)
Raises:
None
"""
data = dict([(k, []) for k in CLASSES])
with open(os.path.join('../data', setname, '{}.ann.features.csv'.format(setname)), 'rt') as f:
reader = csv.reader(f, delimiter=',')
Expand All @@ -105,11 +114,35 @@ def load(setname='csh101'):


def split_data(data, subset, splits):
    """Extract one named subset from every class of a dataset.

    For each class, the rows whose indices fall in
    ``[int(start * n), int(stop * n))`` are selected, where ``n`` is the
    number of rows of that class and ``(start, stop)`` is ``splits[subset]``.

    Args:
        data (dict of str : 2d-array): dataset to split, mapping each class
            name to its samples (1 row = 1 sample)
        subset (str): key of the subset to extract; must be present in
            ``splits`` (e.g. train, validation or test)
        splits (dict of str : tuple): mapping from subset name to its
            ``(start, stop)`` range, expressed as fractions of the number
            of rows (from 0.0 to 1.0)

    Returns:
        dict of str : 2d-array: a dictionary mapping the class names to the
        selected samples (1 row = 1 sample)

    Raises:
        KeyError: if ``subset`` is not a key of ``splits`` (only triggered
            when ``data`` holds at least one class)
    """
    selected = {}
    for class_name in data:
        samples = data[class_name]
        # Both bounds are truncated towards zero by int(); each one is
        # scaled by the row count of the class being processed.
        first_row = int(splits[subset][0] * samples.shape[0])
        last_row = int(splits[subset][1] * samples.shape[0])
        # Indexing with a range of row numbers (rather than a slice) keeps
        # the original's integer-sequence indexing of the sample array.
        selected[class_name] = samples[range(first_row, last_row)]
    return selected


def get(protocol, subset, classes=CLASSES, variables=VARIABLES, setname='csh101'):
"""Get the desired subset
Args:
protocol (str): protocol to use
subset (str): subset to extract (train, validation or test)
classes (1d-array): list of desired classes
variables (1d-array): list of desired variables (features)
setname (str): name of the dataset to load
Returns:
numpy.ndarray: array of arrays containing the samples corresponding to 1 class in order
Raises:
None
"""
retval = split_data(load(setname), subset, PROTOCOLS[protocol])
varindex = [VARIABLES.index(k) for k in variables]
retval = dict([(k, retval[k][:, varindex]) for k in classes])
Expand Down

0 comments on commit f08a0fa

Please sign in to comment.