diff --git a/idtxl/network_analysis.py b/idtxl/network_analysis.py index f2cef002..e4329807 100644 --- a/idtxl/network_analysis.py +++ b/idtxl/network_analysis.py @@ -303,7 +303,8 @@ def _remove_forced_conditionals(self, candidate_set): cond = self.settings['add_conditionals'] if type(cond) is tuple: # easily add single variable cond = [cond] - candidate_set = list(set(candidate_set).difference(set(cond))) + cond_idx = self._lag_to_idx(cond) + candidate_set = list(set(candidate_set).difference(set(cond_idx))) return candidate_set def _append_selected_vars_idx(self, idx): diff --git a/test/test_active_information_storage.py b/test/test_active_information_storage.py index 5aeaddf7..ebbcc2b2 100644 --- a/test/test_active_information_storage.py +++ b/test/test_active_information_storage.py @@ -255,26 +255,27 @@ def test_define_candidates(): # Test if candidates that are added manually to the conditioning set are # removed from the candidate set. nw = ActiveInformationStorage() + nw.current_value = current_val settings = [ {'add_conditionals': None}, {'add_conditionals': (2, 3)}, - {'add_conditionals': [(2, 3), (4, 1)]}] + {'add_conditionals': [(2, 3), (4, 1)]}, + {'add_conditionals': [(1, 9)]}, + {'add_conditionals': [(1, 9), (2, 3), (4, 1)]}] for s in settings: nw.settings = s candidates = nw._define_candidates(procs, samples) assert (1, 9) in candidates, 'Sample missing from candidates: (1, 9).' assert (1, 6) in candidates, 'Sample missing from candidates: (1, 6).' assert (1, 3) in candidates, 'Sample missing from candidates: (1, 3).' - - settings = [ - {'add_conditionals': [(1, 9)]}, - {'add_conditionals': [(1, 9), (2, 3), (4, 1)]}] - for s in settings: - nw.settings = s - candidates = nw._define_candidates(procs, samples) - assert (1, 9) not in candidates, 'Sample missing from candidates: (1, 9).' - assert (1, 6) in candidates, 'Sample missing from candidates: (1, 6).' - assert (1, 3) in candidates, 'Sample missing from candidates: (1, 3).' + if s['add_conditionals'] is not None: + if type(s['add_conditionals']) is tuple: + cond_ind = nw._lag_to_idx([s['add_conditionals']]) + else: + cond_ind = nw._lag_to_idx(s['add_conditionals']) + for c in cond_ind: + assert c not in candidates, ( + 'Sample added erronously to candidates: {}.'.format(c)) if __name__ == '__main__': diff --git a/test/test_bivariate_mi.py b/test/test_bivariate_mi.py index e3562a84..e7535f28 100644 --- a/test/test_bivariate_mi.py +++ b/test/test_bivariate_mi.py @@ -372,26 +372,27 @@ def test_define_candidates(): # Test if candidates that are added manually to the conditioning set are # removed from the candidate set. nw = BivariateMI() + nw.current_value = current_val settings = [ {'add_conditionals': None}, {'add_conditionals': (2, 3)}, - {'add_conditionals': [(2, 3), (4, 1)]}] + {'add_conditionals': [(2, 3), (4, 1)]}, + {'add_conditionals': [(1, 9)]}, + {'add_conditionals': [(1, 9), (2, 3), (4, 1)]}] for s in settings: nw.settings = s candidates = nw._define_candidates(procs, samples) assert (1, 9) in candidates, 'Sample missing from candidates: (1, 9).' assert (1, 6) in candidates, 'Sample missing from candidates: (1, 6).' assert (1, 3) in candidates, 'Sample missing from candidates: (1, 3).' - - settings = [ - {'add_conditionals': [(1, 9)]}, - {'add_conditionals': [(1, 9), (2, 3), (4, 1)]}] - for s in settings: - nw.settings = s - candidates = nw._define_candidates(procs, samples) - assert (1, 9) not in candidates, 'Sample missing from candidates: (1, 9).' - assert (1, 6) in candidates, 'Sample missing from candidates: (1, 6).' - assert (1, 3) in candidates, 'Sample missing from candidates: (1, 3).' + if s['add_conditionals'] is not None: + if type(s['add_conditionals']) is tuple: + cond_ind = nw._lag_to_idx([s['add_conditionals']]) + else: + cond_ind = nw._lag_to_idx(s['add_conditionals']) + for c in cond_ind: + assert c not in candidates, ( + 'Sample added erronously to candidates: {}.'.format(c)) @jpype_missing def test_analyse_network(): diff --git a/test/test_bivariate_te.py b/test/test_bivariate_te.py index 8cc22f9c..2b89599f 100644 --- a/test/test_bivariate_te.py +++ b/test/test_bivariate_te.py @@ -372,26 +372,27 @@ def test_define_candidates(): # Test if candidates that are added manually to the conditioning set are # removed from the candidate set. nw = BivariateTE() + nw.current_value = current_val settings = [ {'add_conditionals': None}, {'add_conditionals': (2, 3)}, - {'add_conditionals': [(2, 3), (4, 1)]}] + {'add_conditionals': [(2, 3), (4, 1)]}, + {'add_conditionals': [(1, 9)]}, + {'add_conditionals': [(1, 9), (2, 3), (4, 1)]}] for s in settings: nw.settings = s candidates = nw._define_candidates(procs, samples) assert (1, 9) in candidates, 'Sample missing from candidates: (1, 9).' assert (1, 6) in candidates, 'Sample missing from candidates: (1, 6).' assert (1, 3) in candidates, 'Sample missing from candidates: (1, 3).' - - settings = [ - {'add_conditionals': [(1, 9)]}, - {'add_conditionals': [(1, 9), (2, 3), (4, 1)]}] - for s in settings: - nw.settings = s - candidates = nw._define_candidates(procs, samples) - assert (1, 9) not in candidates, 'Sample missing from candidates: (1, 9).' - assert (1, 6) in candidates, 'Sample missing from candidates: (1, 6).' - assert (1, 3) in candidates, 'Sample missing from candidates: (1, 3).' + if s['add_conditionals'] is not None: + if type(s['add_conditionals']) is tuple: + cond_ind = nw._lag_to_idx([s['add_conditionals']]) + else: + cond_ind = nw._lag_to_idx(s['add_conditionals']) + for c in cond_ind: + assert c not in candidates, ( + 'Sample added erronously to candidates: {}.'.format(c)) @jpype_missing diff --git a/test/test_multivariate_mi.py b/test/test_multivariate_mi.py index 83e9e0b0..a7d037d3 100644 --- a/test/test_multivariate_mi.py +++ b/test/test_multivariate_mi.py @@ -363,26 +363,27 @@ def test_define_candidates(): # Test if candidates that are added manually to the conditioning set are # removed from the candidate set. nw = MultivariateMI() + nw.current_value = current_val settings = [ {'add_conditionals': None}, {'add_conditionals': (2, 3)}, - {'add_conditionals': [(2, 3), (4, 1)]}] + {'add_conditionals': [(2, 3), (4, 1)]}, + {'add_conditionals': [(1, 9)]}, + {'add_conditionals': [(1, 9), (2, 3), (4, 1)]}] for s in settings: nw.settings = s candidates = nw._define_candidates(procs, samples) assert (1, 9) in candidates, 'Sample missing from candidates: (1, 9).' assert (1, 6) in candidates, 'Sample missing from candidates: (1, 6).' assert (1, 3) in candidates, 'Sample missing from candidates: (1, 3).' - - settings = [ - {'add_conditionals': [(1, 9)]}, - {'add_conditionals': [(1, 9), (2, 3), (4, 1)]}] - for s in settings: - nw.settings = s - candidates = nw._define_candidates(procs, samples) - assert (1, 9) not in candidates, 'Sample missing from candidates: (1, 9).' - assert (1, 6) in candidates, 'Sample missing from candidates: (1, 6).' - assert (1, 3) in candidates, 'Sample missing from candidates: (1, 3).' + if s['add_conditionals'] is not None: + if type(s['add_conditionals']) is tuple: + cond_ind = nw._lag_to_idx([s['add_conditionals']]) + else: + cond_ind = nw._lag_to_idx(s['add_conditionals']) + for c in cond_ind: + assert c not in candidates, ( + 'Sample added erronously to candidates: {}.'.format(c)) @jpype_missing diff --git a/test/test_multivariate_te.py b/test/test_multivariate_te.py index 731a8a45..33a5d33f 100644 --- a/test/test_multivariate_te.py +++ b/test/test_multivariate_te.py @@ -368,26 +368,27 @@ def test_define_candidates(): # Test if candidates that are added manually to the conditioning set are # removed from the candidate set. nw = MultivariateTE() + nw.current_value = current_val settings = [ {'add_conditionals': None}, {'add_conditionals': (2, 3)}, - {'add_conditionals': [(2, 3), (4, 1)]}] + {'add_conditionals': [(2, 3), (4, 1)]}, + {'add_conditionals': [(1, 9)]}, + {'add_conditionals': [(1, 9), (2, 3), (4, 1)]}] for s in settings: nw.settings = s candidates = nw._define_candidates(procs, samples) assert (1, 9) in candidates, 'Sample missing from candidates: (1, 9).' assert (1, 6) in candidates, 'Sample missing from candidates: (1, 6).' assert (1, 3) in candidates, 'Sample missing from candidates: (1, 3).' - - settings = [ - {'add_conditionals': [(1, 9)]}, - {'add_conditionals': [(1, 9), (2, 3), (4, 1)]}] - for s in settings: - nw.settings = s - candidates = nw._define_candidates(procs, samples) - assert (1, 9) not in candidates, 'Sample missing from candidates: (1, 9).' - assert (1, 6) in candidates, 'Sample missing from candidates: (1, 6).' - assert (1, 3) in candidates, 'Sample missing from candidates: (1, 3).' + if s['add_conditionals'] is not None: + if type(s['add_conditionals']) is tuple: + cond_ind = nw._lag_to_idx([s['add_conditionals']]) + else: + cond_ind = nw._lag_to_idx(s['add_conditionals']) + for c in cond_ind: + assert c not in candidates, ( + 'Sample added erronously to candidates: {}.'.format(c)) @jpype_missing