diff --git a/idtxl/network_inference.py b/idtxl/network_inference.py index e66af7ed..c672007b 100755 --- a/idtxl/network_inference.py +++ b/idtxl/network_inference.py @@ -332,7 +332,7 @@ def _initialise(self, settings, data, sources, target): '(''max_lag_sources'') needs to be specified.') if 'min_lag_sources' not in self.settings: raise RuntimeError('The minimum lag for source embedding ' - '(''max_lag_sources'') needs to be specified.') + '(''min_lag_sources'') needs to be specified.') self.settings.setdefault('max_lag_target', settings['max_lag_sources']) if (type(self.settings['min_lag_sources']) is not int or diff --git a/test/data/continuous_results_bmi_JidtGaussianCMI.p b/test/data/continuous_results_bmi_JidtGaussianCMI.p index 1c6bcc2c..90a6d59c 100644 Binary files a/test/data/continuous_results_bmi_JidtGaussianCMI.p and b/test/data/continuous_results_bmi_JidtGaussianCMI.p differ diff --git a/test/data/continuous_results_bmi_JidtKraskovCMI.p b/test/data/continuous_results_bmi_JidtKraskovCMI.p index a4293f14..f634b091 100644 Binary files a/test/data/continuous_results_bmi_JidtKraskovCMI.p and b/test/data/continuous_results_bmi_JidtKraskovCMI.p differ diff --git a/test/data/continuous_results_bte_JidtGaussianCMI.p b/test/data/continuous_results_bte_JidtGaussianCMI.p index 95159e21..a8717b56 100644 Binary files a/test/data/continuous_results_bte_JidtGaussianCMI.p and b/test/data/continuous_results_bte_JidtGaussianCMI.p differ diff --git a/test/data/continuous_results_bte_JidtKraskovCMI.p b/test/data/continuous_results_bte_JidtKraskovCMI.p index 88c63f5f..8067441f 100644 Binary files a/test/data/continuous_results_bte_JidtKraskovCMI.p and b/test/data/continuous_results_bte_JidtKraskovCMI.p differ diff --git a/test/data/continuous_results_mmi_JidtGaussianCMI.p b/test/data/continuous_results_mmi_JidtGaussianCMI.p index b6312c9a..29223010 100644 Binary files a/test/data/continuous_results_mmi_JidtGaussianCMI.p and b/test/data/continuous_results_mmi_JidtGaussianCMI.p differ diff --git a/test/data/continuous_results_mmi_JidtKraskovCMI.p b/test/data/continuous_results_mmi_JidtKraskovCMI.p index 497bba5f..1103c665 100644 Binary files a/test/data/continuous_results_mmi_JidtKraskovCMI.p and b/test/data/continuous_results_mmi_JidtKraskovCMI.p differ diff --git a/test/data/continuous_results_mte_JidtGaussianCMI.p b/test/data/continuous_results_mte_JidtGaussianCMI.p index a37ec566..dcb45e6c 100644 Binary files a/test/data/continuous_results_mte_JidtGaussianCMI.p and b/test/data/continuous_results_mte_JidtGaussianCMI.p differ diff --git a/test/data/continuous_results_mte_JidtKraskovCMI.p b/test/data/continuous_results_mte_JidtKraskovCMI.p index cd03933a..a3caeb84 100644 Binary files a/test/data/continuous_results_mte_JidtKraskovCMI.p and b/test/data/continuous_results_mte_JidtKraskovCMI.p differ diff --git a/test/data/discrete_results_bmi_JidtDiscreteCMI.p b/test/data/discrete_results_bmi_JidtDiscreteCMI.p index c5859517..fa6c10bd 100644 Binary files a/test/data/discrete_results_bmi_JidtDiscreteCMI.p and b/test/data/discrete_results_bmi_JidtDiscreteCMI.p differ diff --git a/test/data/discrete_results_bte_JidtDiscreteCMI.p b/test/data/discrete_results_bte_JidtDiscreteCMI.p index 072a17cd..514f0ff3 100644 Binary files a/test/data/discrete_results_bte_JidtDiscreteCMI.p and b/test/data/discrete_results_bte_JidtDiscreteCMI.p differ diff --git a/test/data/discrete_results_mmi_JidtDiscreteCMI.p b/test/data/discrete_results_mmi_JidtDiscreteCMI.p index 1be8c0e5..a6faa674 100644 Binary files a/test/data/discrete_results_mmi_JidtDiscreteCMI.p and b/test/data/discrete_results_mmi_JidtDiscreteCMI.p differ diff --git a/test/data/discrete_results_mte_JidtDiscreteCMI.p b/test/data/discrete_results_mte_JidtDiscreteCMI.p index 54e35a36..8d809de1 100644 Binary files a/test/data/discrete_results_mte_JidtDiscreteCMI.p and b/test/data/discrete_results_mte_JidtDiscreteCMI.p differ diff --git a/test/data/mute_results_0.p b/test/data/mute_results_0.p index 4a4a6e82..dfb5572d 100644 Binary files a/test/data/mute_results_0.p and b/test/data/mute_results_0.p differ diff --git a/test/data/mute_results_1.p b/test/data/mute_results_1.p index 0f9de355..c0020f20 100644 Binary files a/test/data/mute_results_1.p and b/test/data/mute_results_1.p differ diff --git a/test/data/mute_results_2.p b/test/data/mute_results_2.p index 219f9c93..c58980b3 100644 Binary files a/test/data/mute_results_2.p and b/test/data/mute_results_2.p differ diff --git a/test/data/mute_results_3.p b/test/data/mute_results_3.p index 70656858..a17374e8 100644 Binary files a/test/data/mute_results_3.p and b/test/data/mute_results_3.p differ diff --git a/test/data/mute_results_4.p b/test/data/mute_results_4.p index 0f9de355..c0020f20 100644 Binary files a/test/data/mute_results_4.p and b/test/data/mute_results_4.p differ diff --git a/test/data/mute_results_full.p b/test/data/mute_results_full.p index 43137ee5..ff70fbe1 100644 Binary files a/test/data/mute_results_full.p and b/test/data/mute_results_full.p differ diff --git a/test/generate_test_data.py b/test/generate_test_data.py index bc831c3e..a9a37177 100644 --- a/test/generate_test_data.py +++ b/test/generate_test_data.py @@ -215,7 +215,7 @@ def _print_result(res): if __name__ == '__main__': - analyse_mute_te_data() analyse_discrete_data() + analyse_mute_te_data() analyse_continuous_data() assert_results() diff --git a/test/systemtest_multivariate_te.py b/test/systemtest_multivariate_te.py index 22f808bc..9243a3f9 100644 --- a/test/systemtest_multivariate_te.py +++ b/test/systemtest_multivariate_te.py @@ -105,7 +105,7 @@ def test_multivariate_te_lagged_copies(): results = random_analysis.analyse_single_target(settings, data, t) assert len(results.get_single_target(t, fdr=False).selected_vars_full) == 1, ( 'Conditional contains more/less than 1 variables.') - assert not results.get_single_target(t, fdr=False).selected_vars_sources.size, ( + assert not results.get_single_target(t, fdr=False).selected_vars_sources, ( 'Conditional sources is not empty.') assert len(results.get_single_target(t, fdr=False).selected_vars_target) == 1, ( 'Conditional target contains more/less than 1 variable.') @@ -136,6 +136,7 @@ def test_multivariate_te_random(): settings = { 'cmi_estimator': 'JidtKraskovCMI', 'max_lag_sources': 5, + 'min_lag_sources': 1, 'n_perm_max_stat': 200, 'n_perm_min_stat': 200, 'n_perm_omnibus': 500, @@ -149,7 +150,7 @@ def test_multivariate_te_random(): results = random_analysis.analyse_single_target(settings, data, t) assert len(results.get_single_target(t, fdr=False).selected_vars_full) == 1, ( 'Conditional contains more/less than 1 variables.') - assert not results.get_single_target(t, fdr=False).selected_vars_sources.size, ( + assert not results.get_single_target(t, fdr=False).selected_vars_sources, ( 'Conditional sources is not empty.') assert len(results.get_single_target(t, fdr=False).selected_vars_target) == 1, ( 'Conditional target contains more/less than 1 variable.') @@ -208,7 +209,7 @@ def test_multivariate_te_lorenz_2(): # Just analyse the direction of coupling results = lorenz_analysis.analyse_single_target(settings, data, target=1) print(results._single_target) - print(results.get_adjacency_matrix('binary')) + print(results.get_adjacency_matrix('binary', fdr=False)) def test_multivariate_te_mute(): @@ -272,6 +273,7 @@ def test_multivariate_te_multiple_runs(): test_multivariate_te_mute() test_multivariate_te_lorenz_2() test_multivariate_te_random() + test_multivariate_te_lagged_copies() test_multivariate_te_multiple_runs() test_multivariate_te_corr_gaussian() test_multivariate_te_corr_gaussian('OpenCLKraskovCMI') diff --git a/test/systemtest_multivariate_te_discrete.py b/test/systemtest_multivariate_te_discrete.py index 3d018550..a19431f6 100644 --- a/test/systemtest_multivariate_te_discrete.py +++ b/test/systemtest_multivariate_te_discrete.py @@ -7,7 +7,7 @@ from idtxl.idtxl_utils import calculate_mi -def test_multivariate_te_corr_gaussian(): +def test_multivariate_te_corr_gaussian(estimator=None): """Test multivariate TE estimation on correlated Gaussians. Run the multivariate TE algorithm on two sets of random Gaussian data with @@ -25,6 +25,9 @@ def test_multivariate_te_corr_gaussian(): This test runs considerably faster than other system tests. This produces strange small values for non-coupled sources. TODO """ + if estimator is None: + estimator = 'JidtKraskovCMI' + cov = 0.4 expected_mi, source1, source2, target = _get_gauss_data(covariance=cov) # n = 1000 @@ -39,7 +42,7 @@ def test_multivariate_te_corr_gaussian(): data = Data(normalise=True) data.set_data(np.vstack((source1[1:].T, target[:-1].T)), 'ps') settings = { - 'cmi_estimator': 'JidtDiscreteCMI', + 'cmi_estimator': estimator, 'discretise_method': 'max_ent', 'max_lag_sources': 5, 'min_lag_sources': 1, @@ -120,7 +123,7 @@ def test_multivariate_te_lagged_copies(): results = random_analysis.analyse_single_target(settings, data, t) assert len(results.get_single_target(t, fdr=False).selected_vars_full) == 1, ( 'Conditional contains more/less than 1 variables.') - assert not results.get_single_target(t, fdr=False).selected_vars_sources.size, ( + assert not results.get_single_target(t, fdr=False).selected_vars_sources, ( 'Conditional sources is not empty.') assert len(results.get_single_target(t, fdr=False).selected_vars_target) == 1, ( 'Conditional target contains more/less than 1 variable.') @@ -151,6 +154,7 @@ def test_multivariate_te_random(): settings = { 'cmi_estimator': 'JidtDiscreteCMI', 'discretise_method': 'max_ent', + 'min_lag_sources': 1, 'max_lag_sources': 5, 'n_perm_max_stat': 200, 'n_perm_min_stat': 200, @@ -165,7 +169,7 @@ def test_multivariate_te_random(): results = random_analysis.analyse_single_target(settings, data, t) assert len(results.get_single_target(t, fdr=False).selected_vars_full) == 1, ( 'Conditional contains more/less than 1 variables.') - assert not results.get_single_target(t, fdr=False).selected_vars_sources.size, ( + assert not results.get_single_target(t, fdr=False).selected_vars_sources, ( 'Conditional sources is not empty.') assert len(results.get_single_target(t, fdr=False).selected_vars_target) == 1, ( 'Conditional target contains more/less than 1 variable.') @@ -223,7 +227,7 @@ def test_multivariate_te_lorenz_2(): # Just analyse the coupled direction results = lorenz_analysis.analyse_single_target(settings, data, 1) print(results._single_target) - print(results.adjacency_matrix) + print(results.get_adjacency_matrix(weights='binary', fdr=False)) def test_multivariate_te_mute(): @@ -263,19 +267,29 @@ def test_multivariate_te_mute(): results_eq = network_analysis.analyse_network(settings, data, targets=[1, 2]) - assert (np.isclose( - results_eq.get_single_target(1, fdr=False).omnibus_te, - results_me.get_single_target(1, fdr=False).omnibus_te, rtol=0.05)), ( - 'TE into first target is not equal for both binning methods.') - assert (np.isclose( - results_eq.get_single_target(2, fdr=False).omnibus_te, - results_me.get_single_target(2, fdr=False).omnibus_te, rtol=0.05)), ( - 'TE into second target is not equal for both binning methods.') + for t in [1, 2]: + print('Target {0}: equal binning: {1}, max. ent. binning: {2}'.format( + t, + results_eq.get_single_target(t, fdr=False).omnibus_te, + results_me.get_single_target(t, fdr=False).omnibus_te + )) + # Skip comparison of estimates if analyses returned different source + # sets. This will always lead to different estimates. + if (results_eq.get_single_target(t, fdr=False).selected_vars_sources == + results_me.get_single_target(t, fdr=False).selected_vars_sources): + assert (np.isclose( + results_eq.get_single_target(1, fdr=False).omnibus_te, + results_me.get_single_target(1, fdr=False).omnibus_te, + rtol=0.05)), ('Target {0}: unequl results for both binning ' + 'methods.'.format(t)) + else: + continue if __name__ == '__main__': test_multivariate_te_lorenz_2() test_multivariate_te_mute() test_multivariate_te_random() + test_multivariate_te_lagged_copies() test_multivariate_te_corr_gaussian() test_multivariate_te_corr_gaussian('OpenCLKraskovCMI') diff --git a/test/systemtest_mute.py b/test/systemtest_mute.py index be88bb21..190c1632 100644 --- a/test/systemtest_mute.py +++ b/test/systemtest_mute.py @@ -1,4 +1,3 @@ -import os import pickle import time from idtxl.multivariate_te import MultivariateTE @@ -22,5 +21,5 @@ runtime = time.time() - start_time print("---- {0} minutes".format(runtime / 60)) -path = '{0}output/'.format(os.path.dirname(__file__)) -pickle.dump(results, open('{0}test_mute_res_{1}'.format(path, 0), 'wb')) +# Save results +# pickle.dump(results, open('test_mute_results.p', 'wb')) diff --git a/test/systemtest_network_comparison.py b/test/systemtest_network_comparison.py index 3b91f3a2..cd30893d 100644 --- a/test/systemtest_network_comparison.py +++ b/test/systemtest_network_comparison.py @@ -85,24 +85,23 @@ def test_network_comparison(): def _verify_test(c_within, c_between, res): # Test values for verification - p = 0.25 # max. attainable p-value - tp = res.adjacency_matrix > 0 # get true positives - print(c_within.adjacency_matrix_union) - assert (c_between.adjacency_matrix_union[tp] == 1).all(), ( + tp = res.get_adjacency_matrix('binary') > 0 # get true positives + print(c_within.get_adjacency_matrix('union')) + assert (c_between.get_adjacency_matrix('union')[tp] == 1).all(), ( 'Missing union link in between network comparison.') - assert (c_within.adjacency_matrix_union[tp] == 1).all(), ( + assert (c_within.get_adjacency_matrix('union')[tp] == 1).all(), ( 'Missing union link in wihin network comparison.') - assert (c_between.adjacency_matrix_pvalue[tp] < 1).all(), ( + assert (c_between.get_adjacency_matrix('pvalue')[tp] < 1).all(), ( 'Wrong p-value in between network comparison.') - assert (c_within.adjacency_matrix_pvalue[tp] < 1).all(), ( + assert (c_within.get_adjacency_matrix('pvalue')[tp] < 1).all(), ( 'Wrong p-value in wihin network comparison.') # assert (c_between.adjacency_matrix_comparison[tp]).all(), ( # 'Wrong comparison in between network comparison.') # assert (c_within.adjacency_matrix_comparison[tp]).all(), ( # 'Wrong comparison in wihin network comparison.') - assert (c_between.adjacency_matrix_diff_abs[tp] > 0).all(), ( + assert (c_between.get_adjacency_matrix('diff_abs')[tp] > 0).all(), ( 'Missed difference in between network comparison.') - assert (c_within.adjacency_matrix_diff_abs[tp] > 0).all(), ( + assert (c_within.get_adjacency_matrix('diff_abs')[tp] > 0).all(), ( 'Missed difference in wihin network comparison.') diff --git a/test/systemtest_partial_information_decomposition.py b/test/systemtest_partial_information_decomposition.py index 09b3c8a4..7e5b932a 100644 --- a/test/systemtest_partial_information_decomposition.py +++ b/test/systemtest_partial_information_decomposition.py @@ -9,7 +9,7 @@ def test_pid_xor_data(): """Test basic calls to PID class.""" - n = 20 + n = 100 alph = 2 x = np.random.randint(0, alph, n) y = np.random.randint(0, alph, n) @@ -43,9 +43,9 @@ def test_pid_xor_data(): t_sydney = tm.time() - tic print('\nResults Tartu estimator:') - utils.print_dict(est_tartu) + utils.print_dict(est_tartu.get_single_target(2)) print('\nResults Sydney estimator:') - utils.print_dict(est_sydney) + utils.print_dict(est_sydney.get_single_target(2)) print('\nLogical XOR') print('Estimator Sydney\t\tTartu\n') diff --git a/test/systemtest_pid_sydney.py b/test/systemtest_pid_sydney.py index 3cafb80d..dfd51009 100644 --- a/test/systemtest_pid_sydney.py +++ b/test/systemtest_pid_sydney.py @@ -5,7 +5,7 @@ import time as tm from bitstring import BitArray, Bits # import estimators_fast_pid as epid -from idtxl.set_estimator import Estimator_pid +from idtxl.estimators_pid import SydneyPID # LOGICAL AND alph_x = 2 @@ -18,19 +18,17 @@ y = np.random.randint(0, alph_y, n) z = np.logical_and(x, y).astype(int) -cfg = { +settings = { 'alph_s1': alph_x, 'alph_s2': alph_y, 'alph_t': alph_z, 'max_unsuc_swaps_row_parm': 3, 'num_reps': 63, - 'max_iters': 10000 -} - -pid_sydney = Estimator_pid('pid_sydney') + 'max_iters': 10000} tic = tm.time() -est = pid_sydney.estimate(x, y, z, cfg) +pid_sydney = SydneyPID(settings) +est = pid_sydney.estimate(x, y, z) toc = tm.time() print('\n\nLOGICAL AND') @@ -43,7 +41,7 @@ # LOGICAL XOR z = np.logical_xor(x, y).astype(int) -cfg = { +settings = { 'alph_s1': alph_x, 'alph_s2': alph_y, 'alph_t': alph_z, @@ -53,7 +51,8 @@ } tic = tm.time() -est = pid_sydney.estimate(x, y, z, cfg) +pid_sydney = SydneyPID(settings) +est = pid_sydney.estimate(x, y, z) toc = tm.time() print('\nPID evaluation {:.3f} seconds\n'.format(toc - tic)) @@ -65,7 +64,7 @@ # SINGLE INPUT COPY z = x -cfg = { +settings = { 'alph_s1': alph_x, 'alph_s2': alph_y, 'alph_t': alph_z, @@ -75,7 +74,8 @@ } tic = tm.time() -est = pid_sydney.estimate(x, y, z, cfg) +pid_sydney = SydneyPID(settings) +est = pid_sydney.estimate(x, y, z) toc = tm.time() print('\nPID evaluation {:.3f} seconds\n'.format(toc - tic)) @@ -102,6 +102,7 @@ def parity(bytestring): return par + x = np.zeros((n,), dtype=np.int) y = np.zeros((n,), dtype=np.int) z = np.zeros((n,), dtype=np.int) @@ -115,7 +116,7 @@ def parity(bytestring): alph_y = 4 alph_z = 2 -cfg = { +settings = { 'alph_s1': alph_x, 'alph_s2': alph_y, 'alph_t': alph_z, @@ -125,7 +126,8 @@ def parity(bytestring): } tic = tm.time() -est = pid_sydney.estimate(x, y, z, cfg) +pid_sydney = SydneyPID(settings) +est = pid_sydney.estimate(x, y, z) toc = tm.time() print('\n\nPARITY') diff --git a/test/systemtest_visualise_graph.py b/test/systemtest_visualise_graph.py index 2be9889c..2743f88a 100644 --- a/test/systemtest_visualise_graph.py +++ b/test/systemtest_visualise_graph.py @@ -30,13 +30,22 @@ def test_visualise_multivariate_te(): results = network_analysis.analyse_network(settings, data, targets=[0, 1, 2]) # generate graph plots - visualise_graph.plot_selected_vars(results, target=1, sign_sources=False) - plt.show() - visualise_graph.plot_network(results, fdr=False) - plt.show() - visualise_graph.plot_network(results, fdr=True) + try: + visualise_graph.plot_selected_vars( + results, target=1, sign_sources=False, fdr=True) + plt.show() + except RuntimeError: + print('No FDR-corrected results.') + try: + visualise_graph.plot_network(results, weights='binary', fdr=True) + plt.show() + except RuntimeError: + print('No FDR-corrected results.') + + visualise_graph.plot_network(results, weights='binary', fdr=False) plt.show() - visualise_graph.plot_selected_vars(results, target=1, sign_sources=True) + visualise_graph.plot_selected_vars( + results, target=1, sign_sources=True, fdr=False) plt.show()