Skip to content

Commit

Permalink
Remove manually added conditioning variables from candidate set
Browse files Browse the repository at this point in the history
Remove manually added conditioning variables from candidate set. If
variables are added manually to the conditioning set via the
'add_conditionals' setting, they have to be removed from the candidate
set if both sets are not disjoint. Add unit tests.
  • Loading branch information
pwollstadt committed Aug 19, 2018
1 parent 634a769 commit 4b1b4a4
Show file tree
Hide file tree
Showing 7 changed files with 160 additions and 19 deletions.
37 changes: 35 additions & 2 deletions idtxl/network_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,11 @@ def _separate_realisations(self, idx_full, idx_single):
def _define_candidates(self, processes, samples):
"""Build a list of candidate indices.
Build a list of candidate indices. Note that variables that were
manually added to the conditioning set via the 'add_conditionals'
setting are removed from the candidate set if both sets are not
disjoint.
Args:
processes : list of int
process indices
Expand All @@ -268,9 +273,37 @@ def _define_candidates(self, processes, samples):
candidate and has the form (process index, sample index), indices
are absolute values with respect to some data array.
"""
candidate_set = []
candidate_set = self._build_variable_list(processes, samples)
# Remove candidates that were already manullay added to the
# conditioning set via the 'add_conditionals' setting. Otherwise the
# candidates get tested in the inclusion step.
candidate_set = self._remove_forced_conditionals(candidate_set)
return candidate_set

def _build_variable_list(self, processes, samples):
"""Build a list of variable tuples with (process index, sample index).
Args:
processes : list of int
process indices
samples: list of int
sample indices
Returns:
a list of variable tuples
"""
var_list = []
for idx in it.product(processes, samples):
candidate_set.append(idx)
var_list.append(idx)
return var_list

def _remove_forced_conditionals(self, candidate_set):
"""Remove enforced conditioning variables from candidate set."""
if self.settings['add_conditionals'] is not None:
cond = self.settings['add_conditionals']
if type(cond) is tuple: # easily add single variable
cond = [cond]
candidate_set = list(set(candidate_set).difference(set(cond)))
return candidate_set

def _append_selected_vars_idx(self, idx):
Expand Down
4 changes: 2 additions & 2 deletions idtxl/network_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,8 @@ def _force_conditionals(self, cond, data):
# that _define_candidates returns tuples with absolute indices and
# not lags.
if cond == 'faes':
cond = self._define_candidates(self.source_set,
[self.current_value[1]])
cond = self._build_variable_list(self.source_set,
[self.current_value[1]])
self._append_selected_vars(
cond,
data.get_realisations(self.current_value, cond)[0])
Expand Down
36 changes: 36 additions & 0 deletions test/test_active_information_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,43 @@ def test_discrete_input():
nw.analyse_single_process(settings=settings, data=data, process=0)


@jpype_missing
def test_define_candidates():
"""Test candidate definition from a list of procs and a list of samples."""
target = 1
tau_target = 3
max_lag_target = 10
current_val = (target, 10)
procs = [target]
samples = np.arange(current_val[1] - 1, current_val[1] - max_lag_target,
-tau_target)
# Test if candidates that are added manually to the conditioning set are
# removed from the candidate set.
nw = ActiveInformationStorage()
settings = [
{'add_conditionals': None},
{'add_conditionals': (2, 3)},
{'add_conditionals': [(2, 3), (4, 1)]}]
for s in settings:
nw.settings = s
candidates = nw._define_candidates(procs, samples)
assert (1, 9) in candidates, 'Sample missing from candidates: (1, 9).'
assert (1, 6) in candidates, 'Sample missing from candidates: (1, 6).'
assert (1, 3) in candidates, 'Sample missing from candidates: (1, 3).'

settings = [
{'add_conditionals': [(1, 9)]},
{'add_conditionals': [(1, 9), (2, 3), (4, 1)]}]
for s in settings:
nw.settings = s
candidates = nw._define_candidates(procs, samples)
assert (1, 9) not in candidates, 'Sample missing from candidates: (1, 9).'
assert (1, 6) in candidates, 'Sample missing from candidates: (1, 6).'
assert (1, 3) in candidates, 'Sample missing from candidates: (1, 3).'


if __name__ == '__main__':
test_define_candidates()
test_return_local_values()
test_discrete_input()
test_analyse_network()
Expand Down
25 changes: 21 additions & 4 deletions test/test_bivariate_mi.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,13 +368,30 @@ def test_define_candidates():
procs = [target]
samples = np.arange(current_val[1] - 1, current_val[1] - max_lag_sources,
-tau_sources)
# Test if candidates that are added manually to the conditioning set are
# removed from the candidate set.
nw = BivariateMI()
candidates = nw._define_candidates(procs, samples)
assert (1, 9) in candidates, 'Sample missing from candidates: (1, 9).'
settings = [
{'add_conditionals': None},
{'add_conditionals': (2, 3)},
{'add_conditionals': [(2, 3), (4, 1)]}]
for s in settings:
nw.settings = s
candidates = nw._define_candidates(procs, samples)
assert (1, 9) in candidates, 'Sample missing from candidates: (1, 9).'
assert (1, 6) in candidates, 'Sample missing from candidates: (1, 6).'
assert (1, 3) in candidates, 'Sample missing from candidates: (1, 3).'

settings = [
{'add_conditionals': [(1, 9)]},
{'add_conditionals': [(1, 9), (2, 3), (4, 1)]}]
for s in settings:
nw.settings = s
candidates = nw._define_candidates(procs, samples)
assert (1, 9) not in candidates, 'Sample missing from candidates: (1, 9).'
assert (1, 6) in candidates, 'Sample missing from candidates: (1, 6).'
assert (1, 3) in candidates, 'Sample missing from candidates: (1, 3).'


@jpype_missing
def test_analyse_network():
"""Test method for full network analysis."""
Expand Down Expand Up @@ -495,6 +512,7 @@ def test_indices_to_lags():


if __name__ == '__main__':
test_define_candidates()
test_zero_lag()
test_gauss_data()
test_return_local_values()
Expand All @@ -506,4 +524,3 @@ def test_indices_to_lags():
test_faes_method()
test_add_conditional_manually()
test_check_source_set()
test_define_candidates()
33 changes: 26 additions & 7 deletions test/test_bivariate_te.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,15 +363,33 @@ def test_check_source_set():
def test_define_candidates():
"""Test candidate definition from a list of procs and a list of samples."""
target = 1
tau_target = 3
max_lag_target = 10
tau_sources = 3
max_lag_sources = 10
current_val = (target, 10)
procs = [target]
samples = np.arange(current_val[1] - 1, current_val[1] - max_lag_target,
-tau_target)
samples = np.arange(current_val[1] - 1, current_val[1] - max_lag_sources,
-tau_sources)
# Test if candidates that are added manually to the conditioning set are
# removed from the candidate set.
nw = BivariateTE()
candidates = nw._define_candidates(procs, samples)
assert (1, 9) in candidates, 'Sample missing from candidates: (1, 9).'
settings = [
{'add_conditionals': None},
{'add_conditionals': (2, 3)},
{'add_conditionals': [(2, 3), (4, 1)]}]
for s in settings:
nw.settings = s
candidates = nw._define_candidates(procs, samples)
assert (1, 9) in candidates, 'Sample missing from candidates: (1, 9).'
assert (1, 6) in candidates, 'Sample missing from candidates: (1, 6).'
assert (1, 3) in candidates, 'Sample missing from candidates: (1, 3).'

settings = [
{'add_conditionals': [(1, 9)]},
{'add_conditionals': [(1, 9), (2, 3), (4, 1)]}]
for s in settings:
nw.settings = s
candidates = nw._define_candidates(procs, samples)
assert (1, 9) not in candidates, 'Sample missing from candidates: (1, 9).'
assert (1, 6) in candidates, 'Sample missing from candidates: (1, 6).'
assert (1, 3) in candidates, 'Sample missing from candidates: (1, 3).'

Expand Down Expand Up @@ -476,7 +494,7 @@ def test_discrete_input():
@jpype_missing
def test_mute_data():
"""Test estimation from MuTE data."""
max_lag = 3
max_lag = 5
data = Data()
data.generate_mute_data(200, 5)
settings = {
Expand All @@ -487,6 +505,7 @@ def test_mute_data():
'n_perm_omnibus': 21,
'max_lag_sources': max_lag,
'min_lag_sources': 1,
'add_conditionals': [(1, 3), (1, 2)],
'max_lag_target': max_lag}
target = 2
te = BivariateTE()
Expand Down
22 changes: 20 additions & 2 deletions test/test_multivariate_mi.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,9 +359,27 @@ def test_define_candidates():
procs = [target]
samples = np.arange(current_val[1] - 1, current_val[1] - max_lag_sources,
-tau_sources)
# Test if candidates that are added manually to the conditioning set are
# removed from the candidate set.
nw = MultivariateMI()
candidates = nw._define_candidates(procs, samples)
assert (1, 9) in candidates, 'Sample missing from candidates: (1, 9).'
settings = [
{'add_conditionals': None},
{'add_conditionals': (2, 3)},
{'add_conditionals': [(2, 3), (4, 1)]}]
for s in settings:
nw.settings = s
candidates = nw._define_candidates(procs, samples)
assert (1, 9) in candidates, 'Sample missing from candidates: (1, 9).'
assert (1, 6) in candidates, 'Sample missing from candidates: (1, 6).'
assert (1, 3) in candidates, 'Sample missing from candidates: (1, 3).'

settings = [
{'add_conditionals': [(1, 9)]},
{'add_conditionals': [(1, 9), (2, 3), (4, 1)]}]
for s in settings:
nw.settings = s
candidates = nw._define_candidates(procs, samples)
assert (1, 9) not in candidates, 'Sample missing from candidates: (1, 9).'
assert (1, 6) in candidates, 'Sample missing from candidates: (1, 6).'
assert (1, 3) in candidates, 'Sample missing from candidates: (1, 3).'

Expand Down
22 changes: 20 additions & 2 deletions test/test_multivariate_te.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,9 +365,27 @@ def test_define_candidates():
procs = [target]
samples = np.arange(current_val[1] - 1, current_val[1] - max_lag_target,
-tau_target)
# Test if candidates that are added manually to the conditioning set are
# removed from the candidate set.
nw = MultivariateTE()
candidates = nw._define_candidates(procs, samples)
assert (1, 9) in candidates, 'Sample missing from candidates: (1, 9).'
settings = [
{'add_conditionals': None},
{'add_conditionals': (2, 3)},
{'add_conditionals': [(2, 3), (4, 1)]}]
for s in settings:
nw.settings = s
candidates = nw._define_candidates(procs, samples)
assert (1, 9) in candidates, 'Sample missing from candidates: (1, 9).'
assert (1, 6) in candidates, 'Sample missing from candidates: (1, 6).'
assert (1, 3) in candidates, 'Sample missing from candidates: (1, 3).'

settings = [
{'add_conditionals': [(1, 9)]},
{'add_conditionals': [(1, 9), (2, 3), (4, 1)]}]
for s in settings:
nw.settings = s
candidates = nw._define_candidates(procs, samples)
assert (1, 9) not in candidates, 'Sample missing from candidates: (1, 9).'
assert (1, 6) in candidates, 'Sample missing from candidates: (1, 6).'
assert (1, 3) in candidates, 'Sample missing from candidates: (1, 3).'

Expand Down

0 comments on commit 4b1b4a4

Please sign in to comment.