diff --git a/docs/conf.py b/docs/conf.py
index 7150dba7..51aa3426 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -69,9 +69,9 @@
# built documents.
#
# The short X.Y version.
-version = u'1.0'
+version = u'1.2.1'
# The full version, including alpha/beta/rc tags.
-release = u'1.0'
+release = u'1.2.1'
# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
diff --git a/docs/doctrees/environment.pickle b/docs/doctrees/environment.pickle
index 299d9755..51653add 100644
Binary files a/docs/doctrees/environment.pickle and b/docs/doctrees/environment.pickle differ
diff --git a/docs/doctrees/idtxl.doctree b/docs/doctrees/idtxl.doctree
index 96f6b0b3..0815cc6d 100644
Binary files a/docs/doctrees/idtxl.doctree and b/docs/doctrees/idtxl.doctree differ
diff --git a/docs/doctrees/idtxl_data_class.doctree b/docs/doctrees/idtxl_data_class.doctree
index 03ca570b..b27b8737 100644
Binary files a/docs/doctrees/idtxl_data_class.doctree and b/docs/doctrees/idtxl_data_class.doctree differ
diff --git a/docs/doctrees/idtxl_estimators.doctree b/docs/doctrees/idtxl_estimators.doctree
index fa56c64a..ec154899 100644
Binary files a/docs/doctrees/idtxl_estimators.doctree and b/docs/doctrees/idtxl_estimators.doctree differ
diff --git a/docs/doctrees/idtxl_helper.doctree b/docs/doctrees/idtxl_helper.doctree
index 06e90ff1..c37d1057 100644
Binary files a/docs/doctrees/idtxl_helper.doctree and b/docs/doctrees/idtxl_helper.doctree differ
diff --git a/docs/doctrees/idtxl_network_comparison.doctree b/docs/doctrees/idtxl_network_comparison.doctree
index a8aaf180..4a422b55 100644
Binary files a/docs/doctrees/idtxl_network_comparison.doctree and b/docs/doctrees/idtxl_network_comparison.doctree differ
diff --git a/docs/doctrees/idtxl_network_inference.doctree b/docs/doctrees/idtxl_network_inference.doctree
index 780a140f..b302e085 100644
Binary files a/docs/doctrees/idtxl_network_inference.doctree and b/docs/doctrees/idtxl_network_inference.doctree differ
diff --git a/docs/doctrees/idtxl_process_analysis.doctree b/docs/doctrees/idtxl_process_analysis.doctree
index 1e5b94b3..228a5574 100644
Binary files a/docs/doctrees/idtxl_process_analysis.doctree and b/docs/doctrees/idtxl_process_analysis.doctree differ
diff --git a/docs/doctrees/idtxl_results_class.doctree b/docs/doctrees/idtxl_results_class.doctree
index 552ddf56..a9dc0634 100644
Binary files a/docs/doctrees/idtxl_results_class.doctree and b/docs/doctrees/idtxl_results_class.doctree differ
diff --git a/docs/doctrees/index.doctree b/docs/doctrees/index.doctree
index 217ea2f4..9d41da1d 100644
Binary files a/docs/doctrees/index.doctree and b/docs/doctrees/index.doctree differ
diff --git a/docs/doctrees/modules.doctree b/docs/doctrees/modules.doctree
new file mode 100644
index 00000000..9c7aa83b
Binary files /dev/null and b/docs/doctrees/modules.doctree differ
diff --git a/docs/html/.buildinfo b/docs/html/.buildinfo
index 3213255e..89ff4c30 100644
--- a/docs/html/.buildinfo
+++ b/docs/html/.buildinfo
@@ -1,4 +1,4 @@
# Sphinx build info version 1
# This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done.
-config: 6c4e3c84cb60f8ef29d1bd405463eb1f
+config: f717522005f8a2bcad31d42c004451be
tags: 645f666f9bcd5a90fca523b33c5a78b7
diff --git a/docs/html/_modules/idtxl/active_information_storage.html b/docs/html/_modules/idtxl/active_information_storage.html
index fd3e7029..a3dd1294 100644
--- a/docs/html/_modules/idtxl/active_information_storage.html
+++ b/docs/html/_modules/idtxl/active_information_storage.html
@@ -1,18 +1,17 @@
-
+
-
+
            raise ValueError('Processes were not specified correctly: '
                             '{0}.'.format(processes))
+        # Check and set defaults for checkpointing.
+        self.settings = self._set_checkpointing_defaults(
+            settings, data, [], processes)
+
        # Perform AIS estimation for each target individually.
        results = ResultsSingleProcessAnalysis(n_nodes=data.n_processes,
@@ -242,6 +245,12 @@
Source code for idtxl.active_information_storage
further settings (default=False) - verbose : bool [optional] - toggle console output (default=True)
+ - write_ckp : bool [optional] - enable checkpointing, writes
+ analysis state to disk every time a variable is selected;
+ resume crashed analysis using
+ network_analysis.resume_checkpoint() (default=False)
+ - filename_ckp : str [optional] - checkpoint file name (without
+ extension) (default='./idtxl_checkpoint') data : Data instance raw data for analysis
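For orientation, a minimal sketch of how these checkpointing settings could be passed to an AIS analysis (the toy data, estimator choice and checkpoint file name are illustrative assumptions, not part of this change):

    import numpy as np
    from idtxl.active_information_storage import ActiveInformationStorage
    from idtxl.data import Data

    data = Data(np.random.rand(3, 1000, 5), dim_order='psr')  # toy data
    settings = {'cmi_estimator': 'JidtKraskovCMI',   # assumed estimator choice
                'max_lag': 5,
                'write_ckp': True,                   # enable checkpointing
                'filename_ckp': './my_ais_ckp'}      # illustrative file name
    analysis = ActiveInformationStorage()
    results = analysis.analyse_network(settings=settings, data=data)
    # A crashed run could later be resumed via resume_checkpoint(), as
    # described above.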
@@ -339,6 +348,10 @@
Source code for idtxl.active_information_storage
# user. This tests if there is sufficient data to do all tests.# surrogates.check_permutations(self, data)
+        # Check and set defaults for checkpointing.
+        self.settings = self._set_checkpointing_defaults(
+            self.settings, data, [], process)
+
        # Reset all attributes to initial values if the instance has been used
        # before.
        if self.selected_vars_full:
@@ -389,8 +402,8 @@
Source code for idtxl.active_information_storage
"""success=Falseifself.settings['verbose']:
- print('testing candidate set: {0}'.format(
- self._idx_to_lag(candidate_set)))
+ print('testing candidate set: {0}'.format(
+ self._idx_to_lag(candidate_set)))whilecandidate_set:# Get realisations for all candidates.cand_real=data.get_realisations(self.current_value,
@@ -410,12 +423,12 @@
Source code for idtxl.active_information_storage
# we'll terminate the search for more candidates,# though those identified already remain validprint('AlgorithmExhaustedError encountered in '
- 'estimations: '+aee.message)
+ 'estimations: '+aee.message)print('Halting current estimation set.')# For now we don't need a stack trace:# traceback.print_tb(aee.__traceback__)break
-
+
# Test max CMI for significance with maximum statistics.te_max_candidate=max(temp_te)max_candidate=candidate_set[np.argmax(temp_te)]
@@ -424,19 +437,20 @@
Source code for idtxl.active_information_storage
self._idx_to_lag([max_candidate])[0]),end='')significant=Falsetry:
- significant=stats.max_statistic(self,data,candidate_set,
- te_max_candidate)[0]
+ significant=stats.max_statistic(
+ self,data,candidate_set,te_max_candidate,
+ conditional=self._selected_vars_realisations)[0]exceptex.AlgorithmExhaustedErrorasaee:# The algorithm cannot continue here, so# we'll terminate the check on the max stats and not let the# source passprint('AlgorithmExhaustedError encountered in '
- 'estimations: '+aee.message)
+ 'estimations: '+aee.message)print('Halting max stats and further selection for target.')# For now we don't need a stack trace:# traceback.print_tb(aee.__traceback__)break
-
+
# If the max is significant keep it and test the next candidate. If# it is not significant break. There will be no further significant# sources b/c they all have lesser TE.
@@ -451,11 +465,12 @@
var2=self._current_value_realisations,conditional=conditional_realisations)exceptex.AlgorithmExhaustedErrorasaee:
- # The algorithm cannot continue here, so
- # we'll terminate the pruning check,
- # assuming that we need not prune any more
- print('AlgorithmExhaustedError encountered in '
- 'estimations: '+aee.message)
- print('Halting current pruning and allowing others to'
- ' remain.')
- # For now we don't need a stack trace:
- # traceback.print_tb(aee.__traceback__)
- break
+ # The algorithm cannot continue here, so we'll terminate the
+ # pruning check, assuming that we need not prune any more
+ print('AlgorithmExhaustedError encountered in '
+ 'estimations: '+aee.message)
+ print('Halting current pruning and allowing others to'
+ ' remain.')
+ # For now we don't need a stack trace:
+ # traceback.print_tb(aee.__traceback__)
+ break# Test min TE for significance with minimum statistics.te_min_candidate=min(temp_te)min_candidate=self.selected_vars_sources[np.argmin(temp_te)]ifself.settings['verbose']:
- print('{0}'.format(self._idx_to_lag([min_candidate])[0]))
+ print('testing candidate: {0}'.format(
+ self._idx_to_lag([min_candidate])[0]))
+ remaining_candidates=set(self.selected_vars_sources).difference(
+ set([min_candidate]))
+ conditional_realisations=data.get_realisations(
+ self.current_value,remaining_candidates)[0]try:[significant,p,surr_table]=stats.min_statistic(
- self,data,
- self.selected_vars_sources,
- te_min_candidate)
+ analysis_setup=self,
+ data=data,
+ candidate_set=self.selected_vars_sources,
+ te_min_candidate=te_min_candidate,
+ conditional=conditional_realisations)exceptex.AlgorithmExhaustedErrorasaee:# The algorithm cannot continue here, so# we'll terminate the min statistics# assuming that we need not prune any moreprint('AlgorithmExhaustedError encountered in '
- 'estimations: '+aee.message)
+ 'estimations: '+aee.message)print('Halting current pruning and allowing others to'
- ' remain.')
+ ' remain.')# For now we don't need a stack trace:# traceback.print_tb(aee.__traceback__)break
@@ -554,6 +575,8 @@
Source code for idtxl.active_information_storage
# if self.settings['verbose']:# print(' -- not significant')self._remove_selected_var(min_candidate)
+ ifself.settings['write_ckp']:
+ self._write_checkpoint()else:ifself.settings['verbose']:print(' -- significant')
@@ -572,8 +595,9 @@
Source code for idtxl.active_information_storage
# The algorithm cannot continue here, so# we'll set the results to zeroprint('AlgorithmExhaustedError encountered in '
- 'estimations: '+aee.message)
- print('Halting AIS final conditional test and setting to not significant.')
+ 'estimations: '+aee.message)
+ print('Halting AIS final conditional test and setting to not '
+ 'significant.')# For now we don't need a stack trace:# traceback.print_tb(aee.__traceback__)ais=0
@@ -598,15 +622,15 @@
Source code for idtxl.active_information_storage
# The algorithm cannot continue here, so# we'll set the results to zeroprint('AlgorithmExhaustedError encountered in '
- 'final local AIS estimations: '+aee.message)
+ 'final local AIS estimations: '+aee.message)print('Setting all local results to zero (but leaving'
- ' surrogate statistical test results)')
+ ' surrogate statistical test results)')# For now we don't need a stack trace:# traceback.print_tb(aee.__traceback__)# Return local AIS values of all zeros:# (length gleaned from line below)local_ais=np.zeros(
- (max(replication_ind)+1)*sum(replication_ind==0));
+ (max(replication_ind)+1)*sum(replication_ind==0))# Reshape local AIS to a [replications x samples] matrix.self.ais=local_ais.reshape(
@@ -642,23 +666,22 @@
        assert (len(sources) == len(targets)), (
            'List of targets and list of sources have to have the same length')
+        # Check and set defaults for checkpointing. If requested, initialise
+        # checkpointing.
+        self.settings = self._set_checkpointing_defaults(
+            settings, data, sources, targets)
+
        # Perform MI estimation for each target individually
        results = ResultsNetworkInference(n_nodes=data.n_processes,
                                          n_realisations=data.n_realisations(),
@@ -289,6 +294,12 @@
Source code for idtxl.bivariate_mi
further settings (default=False) - verbose : bool [optional] - toggle console output (default=True)
+ - write_ckp : bool [optional] - enable checkpointing, writes
+ analysis state to disk every time a variable is selected;
+ resume crashed analysis using
+ network_analysis.resume_checkpoint() (default=False)
+ - filename_ckp : str [optional] - checkpoint file name (without
+ extension) (default='./idtxl_checkpoint') data : Data instance raw data for analysis
@@ -352,23 +363,22 @@
+"""Estimate partial information decomposition (PID).
+
+Estimate PID for two source and one target process using different estimators.
+
+Note:
+ Written for Python 3.4+
+"""
+import numpy as np
+from .single_process_analysis import SingleProcessAnalysis
+from .estimator import find_estimator
+from .results import ResultsPID
+
+
+
[docs]classBivariatePID(SingleProcessAnalysis):
+ """Perform partial information decomposition for individual processes.
+
+ Perform partial information decomposition (PID) for two source processes
+ and one target process in the network. Estimate unique, shared, and
+ synergistic information in the two sources about the target. Call
+ analyse_network() on the whole network or a set of nodes or call
+ analyse_single_target() to estimate PID for a single process. See
+ docstrings of the two functions for more information.
+
+ References:
+
+ - Williams, P. L., & Beer, R. D. (2010). Nonnegative Decomposition of
+ Multivariate Information, 1–14. Retrieved from
+ http://arxiv.org/abs/1004.2515
+ - Bertschinger, N., Rauh, J., Olbrich, E., Jost, J., & Ay, N. (2014).
+ Quantifying Unique Information. Entropy, 16(4), 2161–2183.
+ http://doi.org/10.3390/e16042161
+
+ Attributes:
+ target : int
+ index of target process
+ sources : array type
+ pair of indices of source processes
+ settings : dict
+ analysis settings
+ results : dict
+ estimated PID
+ """
+
+ def__init__(self):
+ super().__init__()
+
+
[docs]defanalyse_network(self,settings,data,targets,sources):
+ """Estimate partial information decomposition for network nodes.
+
+ Estimate partial information decomposition (PID) for multiple nodes in
+ the network.
+
+ Note:
+ For a detailed description of the algorithm and settings see
+ documentation of the analyse_single_target() method and
+ references in the class docstring.
+
+ Example:
+
+ >>> n = 20
+ >>> alph = 2
+ >>> x = np.random.randint(0, alph, n)
+ >>> y = np.random.randint(0, alph, n)
+ >>> z = np.logical_xor(x, y).astype(int)
+ >>> data = Data(np.vstack((x, y, z)), 'ps', normalise=False)
+ >>> settings = {
+ >>> 'lags_pid': [[1, 1], [3, 2], [0, 0]],
+ >>> 'alpha': 0.1,
+ >>> 'alph_s1': alph,
+ >>> 'alph_s2': alph,
+ >>> 'alph_t': alph,
+ >>> 'max_unsuc_swaps_row_parm': 60,
+ >>> 'num_reps': 63,
+ >>> 'max_iters': 1000,
+ >>> 'pid_estimator': 'SydneyPID'}
+ >>> targets = [0, 1, 2]
+ >>> sources = [[1, 2], [0, 2], [0, 1]]
+ >>> pid_analysis = BivariatePID()
+ >>> results = pid_analysis.analyse_network(settings, data, targets,
+ >>> sources)
+
+ Args:
+ settings : dict
+ parameters for estimation and statistical testing, see
+ documentation of analyse_single_target() for details, can
+ contain
+
+ - lags_pid : list of lists of ints [optional] - lags in samples
+ between sources and target (default=[[1, 1], [1, 1] ...])
+
+ data : Data instance
+ raw data for analysis
+ targets : list of int
+ index of target processes
+ sources : list of lists
+ indices of the two source processes for each target, e.g.,
+ [[0, 2], [1, 0]], must have the same length as targets
+
+ Returns:
+ ResultsPID instance
+ results of network inference, see documentation of
+ ResultsPID()
+ """
+ # Set defaults for PID estimation.
+ settings.setdefault('verbose',True)
+ settings.setdefault('lags_pid',np.array([[1,1]]*len(targets)))
+
+ # Check inputs.
+ ifnotlen(targets)==len(sources)==len(settings['lags_pid']):
+ raiseRuntimeError('Lists of targets, sources, and lags must have'
+ 'the same lengths.')
+ list_of_lags=settings['lags_pid']
+
+ # Perform PID estimation for each target individually
+ results=ResultsPID(
+ n_nodes=data.n_processes,
+ n_realisations=data.n_realisations(),
+ normalised=data.normalise)
+ fortinrange(len(targets)):
+ ifsettings['verbose']:
+ print('\n####### analysing target with index {0} from list {1}'
+ .format(t,targets))
+ settings['lags_pid']=list_of_lags[t]
+ res_single=self.analyse_single_target(
+ settings,data,targets[t],sources[t])
+ results.combine_results(res_single)
+ # Get no. realisations actually used for estimation from single target
+ # analysis.
+ results.data_properties.n_realisations=(
+ res_single.data_properties.n_realisations)
+ returnresults
+
+
[docs]defanalyse_single_target(self,settings,data,target,sources):
+ """Estimate partial information decomposition for a network node.
+
+ Estimate partial information decomposition (PID) for a target node in
+ the network.
+
+ Note:
+ For a description of the algorithm and the method see references in
+ the class and estimator docstrings.
+
+ Example:
+
+ >>> n = 20
+ >>> alph = 2
+ >>> x = np.random.randint(0, alph, n)
+ >>> y = np.random.randint(0, alph, n)
+ >>> z = np.logical_xor(x, y).astype(int)
+ >>> data = Data(np.vstack((x, y, z)), 'ps', normalise=False)
+ >>> settings = {
+ >>> 'alpha': 0.1,
+ >>> 'alph_s1': alph,
+ >>> 'alph_s2': alph,
+ >>> 'alph_t': alph,
+ >>> 'max_unsuc_swaps_row_parm': 60,
+ >>> 'num_reps': 63,
+ >>> 'max_iters': 1000,
+ >>> 'pid_calc_name': 'SydneyPID',
+ >>> 'lags_pid': [2, 3]}
+ >>> pid_analysis = BivariatePID()
+ >>> results = pid_analysis.analyse_single_target(settings=settings,
+ >>> data=data,
+ >>> target=0,
+ >>> sources=[1, 2])
+
+ Args: settings : dict parameters for estimator use and statistics:
+
+ - pid_estimator : str - estimator to be used for PID estimation
+ (for estimator settings see the documentation in the
+ estimators_pid modules)
+ - lags_pid : list of ints [optional] - lags in samples between
+ sources and target (default=[1, 1])
+ - verbose : bool [optional] - toggle console output
+ (default=True)
+
+ data : Data instance
+ raw data for analysis
+ target : int
+ index of target processes
+ sources : list of ints
+ indices of the two source processes for the target
+
+ Returns: ResultsPID instance results of
+ network inference, see documentation of
+ ResultsPID()
+ """
+ # Check input and initialise values for analysis.
+ self._initialise(settings,data,target,sources)
+
+ # Estimate PID and significance.
+ self._calculate_pid(data)
+
+ # Add analyis info.
+ results=ResultsPID(
+ n_nodes=data.n_processes,
+ n_realisations=data.n_realisations(self.current_value),
+ normalised=data.normalise)
+ results._add_single_result(
+ settings=self.settings,
+ target=self.target,
+ results=self.results)
+ self._reset()
+ returnresults
+
+ def_initialise(self,settings,data,target,sources):
+ """Check input, set initial or default values for analysis settings."""
+ # Check requested PID estimator.
+ try:
+ EstimatorClass=find_estimator(settings['pid_estimator'])
+ exceptKeyError:
+ raiseRuntimeError('Estimator was not specified!')
+ self._pid_estimator=EstimatorClass(settings)
+
+ self.settings=settings.copy()
+ self.settings.setdefault('lags_pid',[1,1])
+ self.settings.setdefault('verbose',True)
+
+ # Check if provided lags are correct and work with the number of
+ # samples in the data.
+ iflen(self.settings['lags_pid'])!=2:
+ raiseRuntimeError('List of lags must have length 2.')
+ ifself.settings['lags_pid'][0]>=data.n_samples:
+ raiseRuntimeError(
+ 'Lag 1 ({0}) is larger than the number of samples in the data '
+ 'set ({1}).'.format(
+ self.settings['lags_pid'][0],data.n_samples))
+ ifself.settings['lags_pid'][1]>=data.n_samples:
+ raiseRuntimeError(
+ 'Lag 2 ({0}) is larger than the number of samples in the data '
+ 'set ({1}).'.format(
+ self.settings['lags_pid'][1],data.n_samples))
+
+ # Check if target and sources are provided correctly.
+ iftype(target)isnotint:
+ raiseRuntimeError('Target must be an integer.')
+ iflen(sources)!=2:
+ raiseRuntimeError('List of sources must have length 2.')
+ iftargetinsources:
+ raiseRuntimeError('The target ({0}) should not be in the list '
+ 'of sources ({1}).'.format(target,sources))
+
+ self.current_value=(target,max(self.settings['lags_pid']))
+ self.target=target
+ # TODO works for single vars only, change to multivariate?
+ self.sources=self._lag_to_idx([
+ (sources[0],self.settings['lags_pid'][0]),
+ (sources[1],self.settings['lags_pid'][1])])
+
+ def_calculate_pid(self,data):
+
+ # TODO Discuss how and if the following statistical testing should be
+        # included. Remove dummy results.
+ # [orig_pid, sign_1, p_val_1,
+ # sign_2, p_val_2] = stats.unq_against_surrogates(self, data)
+ # [orig_pid, sign_shd,
+ # p_val_shd, sign_syn, p_val_syn] = stats.syn_shd_against_surrogates(
+ # self,
+ # sign_1 = sign_2 = sign_shd = sign_syn = False
+ # p_val_1 = p_val_2 = p_val_shd = p_val_syn = 1.0
+
+ target_realisations=data.get_realisations(
+ self.current_value,
+ [self.current_value])[0]
+ source_1_realisations=data.get_realisations(
+ self.current_value,
+ [self.sources[0]])[0]
+ source_2_realisations=data.get_realisations(
+ self.current_value,
+ [self.sources[1]])[0]
+ orig_pid=self._pid_estimator.estimate(
+ s1=source_1_realisations,
+ s2=source_2_realisations,
+ t=target_realisations)
+
+ ifself.settings['verbose']:
+ print('\nunq information s1: {0:.8f}, s2: {1:.8f}'.format(
+ orig_pid['unq_s1'],
+ orig_pid['unq_s2']))
+ print('shd information: {0:.8f}, syn information: {1:.8f}'.format(
+ orig_pid['shd_s1_s2'],
+ orig_pid['syn_s1_s2']))
+ self.results=orig_pid
+ self.results['source_1']=self._idx_to_lag([self.sources[0]])
+ self.results['source_2']=self._idx_to_lag([self.sources[1]])
+ self.results['selected_vars_sources']=[
+ self.results['source_1'][0],self.results['source_2'][0]]
+ self.results['current_value']=self.current_value
+ # self.results['unq_s1_sign'] = sign_1
+ # self.results['unq_s2_sign'] = sign_2
+ # self.results['unq_s1_p_val'] = p_val_1
+ # self.results['unq_s2_p_val'] = p_val_2
+ # self.results['syn_sign'] = sign_syn
+ # self.results['syn_p_val'] = p_val_syn
+ # self.results['shd_sign'] = sign_shd
+ # self.results['shd_p_val'] = p_val_shd
+
+ # TODO make mi_against_surrogates in stats more generic, such that
+        # it becomes an arbitrary permutation test where one argument gets
+        # shuffled and then all arguments are passed to the provided estimator
+
+ def_reset(self):
+ """Reset instance after analysis."""
+ self.__init__()
+ delself.results
+ delself.settings
+ delself._pid_estimator
        assert (len(sources) == len(targets)), (
            'List of targets and list of sources have to have the same length')
+        # Check and set defaults for checkpointing. If requested, initialise
+        # checkpointing.
+        self.settings = self._set_checkpointing_defaults(
+            settings, data, sources, targets)
+
        # Perform TE estimation for each target individually
        results = ResultsNetworkInference(n_nodes=data.n_processes,
                                          n_realisations=data.n_realisations(),
@@ -297,6 +302,12 @@
Source code for idtxl.bivariate_te
further settings (default=False) - verbose : bool [optional] - toggle console output (default=True)
+ - write_ckp : bool [optional] - enable checkpointing, writes
+ analysis state to disk every time a variable is selected;
+ resume crashed analysis using
+ network_analysis.resume_checkpoint() (default=False)
+ - filename_ckp : str [optional] - checkpoint file name (without
+ extension) (default='./idtxl_checkpoint') data : Data instance raw data for analysis
@@ -356,23 +367,22 @@
(default='psr') normalise : bool [optional] if True, data gets normalised per process (default=True)
+ seed : int [optional]
+            can be set to a fixed integer to get reproducible results on the
+            same data across multiple analysis runs. Otherwise a random
+ seed is set as default. Attributes: data : numpy array
@@ -105,10 +109,13 @@
Source code for idtxl.data
number of samples in time normalise : bool if true, all data gets z-standardised per process
-
+ initial_state : array
+ initial state of the seed for shuffled permutations """
-    def __init__(self, data=None, dim_order='psr', normalise=True):
+    def __init__(self, data=None, dim_order='psr', normalise=True, seed=None):
+        np.random.seed(seed)
+        self.initial_state = np.random.get_state()
        self.normalise = normalise
        if data is not None:
            self.set_data(data, dim_order)
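A short usage sketch of the new seed argument (array shape and seed value are arbitrary): two Data objects built from the same array with the same seed start from identical RNG states, so surrogate creation and permutation tests become repeatable across runs.

    import numpy as np
    from idtxl.data import Data

    raw = np.random.rand(5, 1000, 10)   # processes x samples x replications
    data_a = Data(raw, dim_order='psr', seed=42)
    data_b = Data(raw, dim_order='psr', seed=42)
    # Both objects record the same initial_state, so shuffled permutations
    # drawn from either object are reproducible.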
@@ -241,6 +248,14 @@
assertsettings['tau_source']>=1,'Source tau must be >= 1'assertsettings['history_target']>=0,'Target history must be >= 0'assertsettings['history_source']>=1,'Source history must be >= 1'
- assertsettings['source_target_delay']>=0,'Source-target delay must be >= 0'
+ assertsettings['source_target_delay']>=0,(
+ 'Source-target delay must be >= 0')returnsettings
Set common estimation parameters for JIDT Kraskov-estimators. For usage of these estimators see documentation for the child classes.
+ Results are returned in nats.
+
Args: CalcClass : JAVA class JAVA class returned by jpype.JPackage
@@ -392,7 +395,6 @@
Source code for idtxl.estimators_jidt
[docs]defestimate_surrogates_analytic(self,n_perm=200,**data):"""Estimate the surrogate distribution analytically.
-
This method must be implemented because this class' is_analytic_null_estimator() method returns true
@@ -421,6 +423,8 @@
Source code for idtxl.estimators_jidt
given (is None), the function returns the mutual information between var1 and var2. See parent class for references.
+ Results are returned in nats.
+
Args: settings : dict [optional] set estimator parameters:
@@ -459,11 +463,13 @@
Source code for idtxl.estimators_jidt
assert(settings['algorithm_num']==1)or(settings['algorithm_num']==2),('Algorithm number must be 1 or 2')if(settings['algorithm_num']==1):
- CalcClass=(jp.JPackage('infodynamics.measures.continuous.kraskov').
- ConditionalMutualInfoCalculatorMultiVariateKraskov1)
+ CalcClass=(
+ jp.JPackage('infodynamics.measures.continuous.kraskov').
+ ConditionalMutualInfoCalculatorMultiVariateKraskov1)else:
- CalcClass=(jp.JPackage('infodynamics.measures.continuous.kraskov').
- ConditionalMutualInfoCalculatorMultiVariateKraskov2)
+ CalcClass=(
+ jp.JPackage('infodynamics.measures.continuous.kraskov').
+ ConditionalMutualInfoCalculatorMultiVariateKraskov2)super().__init__(CalcClass,settings)
the third. Call JIDT via jpype and use the discrete estimator. See parent class for references.
+ Results are returned in bits.
+
Args: settings : dict [optional] sets estimation parameters:
@@ -564,11 +572,16 @@
Source code for idtxl.estimators_jidt
super().__init__(settings)# Start JAVA virtual machine and create JAVA object. Add JAVA object to
- # instance, the discrete estimator requires the variable dimensions
- # upon instantiation.
+ # instanceself._start_jvm()
- self.CalcClass=(jp.JPackage('infodynamics.measures.discrete').
- ConditionalMutualInformationCalculatorDiscrete)
+ CalcClass=(jp.JPackage('infodynamics.measures.discrete').
+ ConditionalMutualInformationCalculatorDiscrete)
+ self.calc=CalcClass()
+ self.calc.setDebug(self.settings['debug'])
+
+ # Keep a reference to an MI calculator if we need to use it (memory
+ # used here is minimal, and better than recreating it each time)
+ self.mi_calc=JidtDiscreteMI(self.settings)
"""# Calculate an MI if no conditional was providedif(conditionalisNone)or(self.settings['alphc']==0):
- est=JidtDiscreteMI(self.settings)# Return value will be just the estimate if return_calc is False,# or estimate plus the JIDT MI calculator if return_calc is True:
- returnest.estimate(var1,var2,return_calc)
+ returnself.mi_calc.estimate(var1,var2,return_calc)else:assert(conditional.size!=0),'Conditional Array is empty.'
@@ -635,34 +647,33 @@
Source code for idtxl.estimators_jidt
alph2_base=int(np.power(self.settings['alph2'],var2_dim))cond_base=int(np.power(self.settings['alphc'],cond_dim))try:
- calc=self.CalcClass(alph1_base,alph2_base,cond_base)
- exceptjp.JavaException:
- # Only possible exception that can be raised here
- # (if all bases >= 2) is a Java OutOfMemoryException:
+ self.calc.initialise(alph1_base,alph2_base,cond_base)
+ except:
+ # Handles both jp.JException (JPype v0.7) and jp.JavaException
+ # (JPype < v0.7). Only possible exception that can be raised here
+ # (if all bases >= 2) is a Java OutOfMemoryException:assert(alph1_base>=2)assert(alph2_base>=2)assert(cond_base>=2)
- raiseex.JidtOutOfMemoryError('Cannot instantiate JIDT CMI '
- 'discrete estimator with alph1_base = '+str(alph1_base)+
- ', alph2_base = '+str(alph2_base)+', cond_base = '+
- str(cond_base)+'. Try re-running increasing Java heap size')
- calc.setDebug(self.settings['debug'])
- calc.initialise()
+ raiseex.JidtOutOfMemoryError(
+ 'Cannot instantiate JIDT CMI discrete estimator with '
+ 'alph1_base = {}, alph2_base = {}, cond_base = {}. Try '
+ 're-running increasing Java heap size.'.format(
+ alph1_base,alph2_base,cond_base))# Unfortunately no faster way to pass numpy arrays in than this list# conversion
- calc.addObservations(jp.JArray(jp.JInt,1)(var1.tolist()),
- jp.JArray(jp.JInt,1)(var2.tolist()),
- jp.JArray(jp.JInt,1)(conditional.tolist()))
+ self.calc.addObservations(jp.JArray(jp.JInt,1)(var1.tolist()),
+ jp.JArray(jp.JInt,1)(var2.tolist()),
+ jp.JArray(jp.JInt,1)(conditional.tolist()))ifself.settings['local_values']:
- result=np.array(calc.computeLocalFromPreviousObservations(
+ result=np.array(self.calc.computeLocalFromPreviousObservations(jp.JArray(jp.JInt,1)(var1.tolist()),jp.JArray(jp.JInt,1)(var2.tolist()),
- jp.JArray(jp.JInt,1)(conditional.tolist())
- ))
+ jp.JArray(jp.JInt,1)(conditional.tolist())))else:
- result=calc.computeAverageLocalOfObservations()
+ result=self.calc.computeAverageLocalOfObservations()ifreturn_calc:
- return(result,calc)
+ return(result,self.calc)else:returnresult
@@ -700,6 +711,8 @@
Source code for idtxl.estimators_jidt
Calculate the mutual information (MI) between two variables. Call JIDT via jpype and use the discrete estimator. See parent class for references.
+ Results are returned in bits.
+
Args: settings : dict [optional] sets estimation parameters:
@@ -741,11 +754,12 @@
Source code for idtxl.estimators_jidt
self.settings.setdefault('alph2',int(2))# Start JAVA virtual machine and create JAVA object. Add JAVA object to
- # instance, the discrete estimator requires the variable dimensions
- # upon instantiation.
+ # instance.self._start_jvm()
- self.CalcClass=(jp.JPackage('infodynamics.measures.discrete').
- MutualInformationCalculatorDiscrete)
+ CalcClass=(jp.JPackage('infodynamics.measures.discrete').
+ MutualInformationCalculatorDiscrete)
+ self.calc=CalcClass()
+ self.calc.setDebug(self.settings['debug'])
base_for_var1=int(np.power(self.settings['alph1'],var1_dim))base_for_var2=int(np.power(self.settings['alph2'],var2_dim))try:
- calc=self.CalcClass(base_for_var1,base_for_var2,
- self.settings['lag_mi'])
- exceptjp.JavaException:
- # Only possible exception that can be raised here
- # (if base_for_var* >= 2) is a Java OutOfMemoryException:
+ self.calc.initialise(base_for_var1,base_for_var2,
+ self.settings['lag_mi'])
+ except:
+ # Handles both jp.JException (JPype v0.7) and jp.JavaException
+ # (JPype < v0.7). Only possible exception that can be raised here
+ # (if base_for_var* >= 2) is a Java OutOfMemoryException:assert(base_for_var1>=2)assert(base_for_var2>=2)
- raiseex.JidtOutOfMemoryError('Cannot instantiate JIDT MI '
- 'discrete estimator with bases = '+str(base_for_var1)+
- ' and '+str(base_for_var2)+
- '. Try re-running increasing Java heap size')
- calc.setDebug(self.settings['debug'])
- calc.initialise()
+ raiseex.JidtOutOfMemoryError(
+ 'Cannot instantiate JIDT MI discrete estimator with bases = {}'
+ ' and {}. Try re-running increasing Java heap size.'.format(
+ base_for_var1,base_for_var2))# Unfortunately no faster way to pass numpy arrays in than this list# conversion
- calc.addObservations(jp.JArray(jp.JInt,1)(var1.tolist()),
- jp.JArray(jp.JInt,1)(var2.tolist()))
+ self.calc.addObservations(jp.JArray(jp.JInt,1)(var1.tolist()),
+ jp.JArray(jp.JInt,1)(var2.tolist()))ifself.settings['local_values']:
- result=np.array(calc.computeLocalFromPreviousObservations(
+ result=np.array(self.calc.computeLocalFromPreviousObservations(jp.JArray(jp.JInt,1)(var1.tolist()),jp.JArray(jp.JInt,1)(var2.tolist())))else:
- result=calc.computeAverageLocalOfObservations()
+ result=self.calc.computeAverageLocalOfObservations()ifreturn_calc:
- return(result,calc)
+ return(result,self.calc)else:returnresult
@@ -852,6 +865,8 @@
Source code for idtxl.estimators_jidt
Calculate the mutual information between two variables. Call JIDT via jpype and use the Kraskov 1 estimator. See parent class for references.
+ Results are returned in nats.
+
Args: settings : dict [optional] sets estimation parameters:
@@ -894,10 +909,10 @@
Source code for idtxl.estimators_jidt
'Algorithm number must be 1 or 2')if(settings['algorithm_num']==1):CalcClass=(jp.JPackage('infodynamics.measures.continuous.kraskov').
- MutualInfoCalculatorMultiVariateKraskov1)
+ MutualInfoCalculatorMultiVariateKraskov1)else:CalcClass=(jp.JPackage('infodynamics.measures.continuous.kraskov').
- MutualInfoCalculatorMultiVariateKraskov2)
+ MutualInfoCalculatorMultiVariateKraskov2)super().__init__(CalcClass,settings)# Get lag and shift second variable to account for a lag if requested
@@ -932,7 +947,7 @@
tau describes the embedding delay, i.e., the spacing between every two samples from the processes' past.
- See parent class for references.
+ See parent class for references. Results are returned in nats. Args: settings : dict
@@ -1037,6 +1052,8 @@
Source code for idtxl.estimators_jidt
Calculate the active information storage (AIS) for one process. Call JIDT via jpype and use the discrete estimator. See parent class for references.
+ Results are returned in bits.
+
Args: settings : dict set estimator parameters:
@@ -1076,12 +1093,14 @@
Source code for idtxl.estimators_jidt
pass# Do nothing and use the default for alph set belowsettings.setdefault('alph',int(2))assertsettings['alph']>=2,'Number of bins must be >= 2'
+ super().__init__(settings)# Start JAVA virtual machine and create JAVA object.self._start_jvm()
- self.CalcClass=(jp.JPackage('infodynamics.measures.discrete').
- ActiveInformationCalculatorDiscrete)
- super().__init__(settings)
+ CalcClass=(jp.JPackage('infodynamics.measures.discrete').
+ ActiveInformationCalculatorDiscrete)
+ self.calc=CalcClass()
+ self.calc.setDebug(self.settings['debug'])
[docs]defestimate(self,process,return_calc=False):"""Estimate active information storage.
@@ -1131,26 +1150,28 @@
Source code for idtxl.estimators_jidt
# And finally make the AIS calculation:try:
- calc=self.CalcClass(self.settings['alph'],self.settings['history'])
- exceptjp.JavaException:
- # Only possible exception that can be raised here
- # (if self.settings['alph'] >= 2) is a Java OutOfMemoryException:
+ self.calc.initialise(
+ self.settings['alph'],self.settings['history'])
+ except:
+ # Handles both jp.JException (JPype v0.7) and jp.JavaException
+ # (JPype < v0.7). Only possible exception that can be raised here
+ # (if self.settings['alph'] >= 2) is a Java OutOfMemoryException:assert(self.settings['alph']>=2)
- raiseex.JidtOutOfMemoryError('Cannot instantiate JIDT AIS '
- 'discrete estimator with alph = '+str(self.settings['alph'])+
- ' and history = '+str(self.settings['history'])+
- '. Try re-running increasing Java heap size')
- calc.initialise()
+ raiseex.JidtOutOfMemoryError(
+ 'Cannot instantiate JIDT AIS discrete estimator with alph = {}'
+ ' and history = {}. Try re-running increasing Java heap '
+ 'size.'.format(
+ self.settings['alph'],self.settings['history']))# Unfortunately no faster way to pass numpy arrays in than this list# conversion
- calc.addObservations(jp.JArray(jp.JInt,1)(process.tolist()))
+ self.calc.addObservations(jp.JArray(jp.JInt,1)(process.tolist()))ifself.settings['local_values']:
- result=np.array(calc.computeLocalFromPreviousObservations(
+ result=np.array(self.calc.computeLocalFromPreviousObservations(jp.JArray(jp.JInt,1)(process.tolist())))else:
- result=calc.computeAverageLocalOfObservations()
+ result=self.calc.computeAverageLocalOfObservations()ifreturn_calc:
- return(result,calc)
+ return(result,self.calc)else:returnresult
@@ -1190,7 +1211,7 @@
Source code for idtxl.estimators_jidt
tau describes the embedding delay, i.e., the spacing between every two samples from the processes' past.
- See parent class for references.
+        See parent class for references. Results are returned in nats. Args: settings : dict
@@ -1260,6 +1281,8 @@
Source code for idtxl.estimators_jidt
Calculate the mutual information between two variables. Call JIDT via jpype and use the Gaussian estimator. See parent class for references.
+ Results are returned in nats.
+
Args: settings : dict [optional] sets estimation parameters:
@@ -1317,7 +1340,7 @@
If no conditional is given (is None), the function returns the mutual information between var1 and var2.
- See parent class for references.
+ See parent class for references. Results are returned in nats. Args: settings : dict [optional]
@@ -1403,7 +1426,7 @@
tau describes the embedding delay, i.e., the spacing between every two samples from the processes' past.
- See parent class for references.
+ See parent class for references. Results are returned in nats. Args: settings : dict
@@ -1495,7 +1518,6 @@
[docs]defestimate(self,source,target):"""Estimate transfer entropy from a source to a target variable.
@@ -1540,6 +1562,8 @@
Source code for idtxl.estimators_jidt
state and the target's current value, conditional on the target's past. See parent class for references.
+ Results are returned in bits.
+
Args: settings : dict sets estimation parameters:
@@ -1593,14 +1617,18 @@
Source code for idtxl.estimators_jidt
'Num discrete levels for source has to be an integer.')asserttype(settings['alph2'])isint,('Num discrete levels for target has to be an integer.')
- assertsettings['alph1']>=2,'Num discrete levels for source must be >= 2'
- assertsettings['alph2']>=2,'Num discrete levels for target must be >= 2'
+ assertsettings['alph1']>=2,(
+ 'Num discrete levels for source must be >= 2')
+ assertsettings['alph2']>=2,(
+ 'Num discrete levels for target must be >= 2')super().__init__(settings)# Start JAVA virtual machine and create JAVA object.self._start_jvm()
- self.CalcClass=(jp.JPackage('infodynamics.measures.discrete').
- TransferEntropyCalculatorDiscrete)
+ CalcClass=(jp.JPackage('infodynamics.measures.discrete').
+ TransferEntropyCalculatorDiscrete)
+ self.calc=CalcClass()
+ self.calc.setDebug(self.settings['debug'])
[docs]defestimate(self,source,target,return_calc=False):"""Estimate transfer entropy from a source to a target variable.
@@ -1639,34 +1667,37 @@
Source code for idtxl.estimators_jidt
# And finally make the TE calculation:max_base=max(self.settings['alph1'],self.settings['alph2'])try:
- calc=self.CalcClass(max_base,
- self.settings['history_target'],
- self.settings['tau_target'],
- self.settings['history_source'],
- self.settings['tau_source'],
- self.settings['source_target_delay'])
- exceptjp.JavaException:
- # Only possible exception that can be raised here
- # (if max_base >= 2) is a Java OutOfMemoryException:
+ self.calc.initialise(max_base,
+ self.settings['history_target'],
+ self.settings['tau_target'],
+ self.settings['history_source'],
+ self.settings['tau_source'],
+ self.settings['source_target_delay'])
+ except:
+ # Handles both jp.JException (JPype v0.7) and jp.JavaException
+ # (JPype < v0.7). Only possible exception that can be raised here
+ # (if max_base >= 2) is a Java OutOfMemoryException:assert(max_base>=2)
- raiseex.JidtOutOfMemoryError('Cannot instantiate JIDT TE '
- 'discrete estimator with max_base = '+str(max_base)+
- ' and history_target = '+str(self.settings['history_target'])+
- ' and history_source = '+str(self.settings['history_source'])+
- '. Try re-running increasing Java heap size')
- calc.initialise()
+ raiseex.JidtOutOfMemoryError(
+ 'Cannot instantiate JIDT TE discrete estimator with max_base ='
+ ' {} and history_target = {} and history_source = {}. Try '
+ 're-running increasing Java heap size.'.format(
+ max_base,
+ self.settings['history_target'],
+ self.settings['history_source']))# Unfortunately no faster way to pass numpy arrays in than this list# conversion
- calc.addObservations(jp.JArray(jp.JInt,1)(source.tolist()),
- jp.JArray(jp.JInt,1)(target.tolist()))
+ self.calc.addObservations(
+ jp.JArray(jp.JInt,1)(source.tolist()),
+ jp.JArray(jp.JInt,1)(target.tolist()))ifself.settings['local_values']:
- result=np.array(calc.computeLocalFromPreviousObservations(
+ result=np.array(self.calc.computeLocalFromPreviousObservations(jp.JArray(jp.JInt,1)(source.tolist()),jp.JArray(jp.JInt,1)(target.tolist())))else:
- result=calc.computeAverageLocalOfObservations()
+ result=self.calc.computeAverageLocalOfObservations()ifreturn_calc:
- return(result,calc)
+ return(result,self.calc)else:returnresult
@@ -1709,7 +1740,7 @@
Source code for idtxl.estimators_jidt
tau describes the embedding delay, i.e., the spacing between every two samples from the processes' past.
- See parent class for references.
+ See parent class for references. Results are returned in nats. Args: settings : dict
@@ -1814,23 +1845,22 @@
+"""Multivariate Partical information decomposition for discrete random variables.
+
+This module provides an estimator for multivariate partial information
+decomposition as proposed in
+
+- Makkeh, A. & Gutknecht, A. & Wibral, M. (2020). A Differentiable measure
+  for shared information, 1-27. Retrieved from
+ http://arxiv.org/abs/2002.03356
+"""
+import numpy as np
+from . import lattices as lt
+from . import pid_goettingen
+from .estimator import Estimator
+from .estimators_pid import _join_variables
+
+# TODO add support for multivariate estimation for Tartu and Sydney estimator
+
+
+
[docs]classSxPID(Estimator):
+ """Estimate partial information decomposition for multiple inputs.
+
+    Implementation of the multivariate partial information decomposition (PID)
+    estimator for discrete data with up to four inputs and one output. The
+    estimator finds shared information, unique information and synergistic
+    information between the multiple inputs s1, s2, ..., sn with respect to the
+    output t for each realization (t, s1, ..., sn) and then averages them
+    according to their distribution weights p(t, s1, ..., sn). Both the
+    pointwise (on the realization level) PID and the averaged PID are returned
+    (see the 'return' of 'estimate()').
+
+ The algorithm uses recursion to compute the partial information
+ decomposition.
+
+ References:
+
+ - Makkeh, A. & Wibral, M. (2020). A differentiable pointwise partial
+ Information Decomposition estimator. https://github.com/Abzinger/SxPID.
+
+ Args:
+ settings : dict
+ estimation parameters (with default parameters)
+
+ - verbose : bool [optional] - print output to console
+ (default=False)
+ """
+
+ def__init__(self,settings):
+ # get estimation parameters
+ self.settings=settings.copy()
+ self.settings.setdefault('verbose',False)
+
+
[docs]defestimate(self,s,t):
+ """
+ Args:
+ s : list of numpy arrays
+ 1D arrays containing realizations of a discrete random variable
+ t : numpy array
+ 1D array containing realizations of a discrete random variable
+
+ Returns:
+ dict of dict
+ {
+ 'ptw' -> { realization -> {alpha -> [float, float, float]} }
+
+ 'avg' -> {alpha -> [float, float, float]}
+ }
+ where the list of floats is ordered
+ [informative, misinformative, informative - misinformative]
+ ptw stands for pointwise decomposition
+ avg stands for average decomposition
+ """
+ s,t,self.settings=_check_input(s,t,self.settings)
+ pdf=_get_pdf_dict(s,t)
+
+ # Read lattices from a file
+ # Stored as {
+ # n -> [{alpha -> children}, (alpha_1,...) ]
+ # }
+ # children is a list of tuples
+ lattices=lt.lattices
+ num_source_vars=len(s)
+ retval_ptw,retval_avg=pid_goettingen.pid(
+ num_source_vars,
+ pdf_orig=pdf,
+ chld=lattices[num_source_vars][0],
+ achain=lattices[num_source_vars][1],
+ printing=self.settings['verbose'])
+
+ # TODO AskM: Trivariate: does it make sense to name the alphas
+ # for example shared_syn_s1_s2__syn_s1_s3 ?
+ results={
+ 'ptw':retval_ptw,
+ 'avg':retval_avg,
+ }
+ returnresults
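A hedged usage sketch for the return structure documented above; the XOR toy data mirrors the other PID examples in this changeset, and the module path is assumed from the file this page documents.

    import numpy as np
    from idtxl.estimators_multivariate_pid import SxPID   # assumed module path

    n = 1000
    x = np.random.randint(0, 2, n)
    y = np.random.randint(0, 2, n)
    z = np.logical_xor(x, y).astype(int)
    est = SxPID({'verbose': False})
    results = est.estimate(s=[x, y], t=z)
    # results['avg'] maps each antichain alpha to
    # [informative, misinformative, informative - misinformative].
    for alpha, values in results['avg'].items():
        print(alpha, values)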
+
+
+def _get_pdf_dict(s, t):
+    """Write probability mass function estimated via counting to a dict."""
+ # Create dictionary with probability mass function
+ counts=dict()
+ n_samples=s[0].shape[0]
+
+    # Count occurrences.
+ foriinrange(n_samples):
+ key=tuple([s[j][i]forjinrange(len(s))])+(t[i],)
+ ifkeyincounts.keys():
+ counts[key]+=1
+ else:
+ counts[key]=1
+
+ # Create PMF from counts.
+ pmf=dict()
+ forxyz,cincounts.items():
+ pmf[xyz]=c/float(n_samples)
+ returnpmf
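A small worked example of the counting scheme above (values chosen for illustration): for s = [np.array([0, 1, 0])] and t = np.array([0, 1, 1]), the keys are (s1, t) tuples; the realizations (0, 0), (1, 1) and (0, 1) each occur once, so the returned PMF is {(0, 0): 1/3, (1, 1): 1/3, (0, 1): 1/3}.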
+
+
+def_check_input(s,t,settings):
+ """Check input to PID estimators."""
+ # Check if inputs are numpy arrays.
+ iftype(t)!=np.ndarray:
+ raiseTypeError('Input t must be a numpy array.')
+ foriinrange(len(s)):
+ iftype(s[i])!=np.ndarray:
+ raiseTypeError('All inputs s{0} must be numpy arrays.'.format(i+1))
+
+ # In general, IDTxl expects 2D inputs because JIDT/JPYPE only accepts those
+ # and we have a multivariate approach, i.e., a vector is a special case of
+ # 2D-data. The PID estimators on the other hand, expect 1D data. Squeeze 2D
+ # arrays if the dimension of the second axis is 1. Otherwise combine
+ # multivariate sources into a single variable for estimation.
+ foriinrange(len(s)):
+ ifs[i].ndim!=1:
+ ifs[i].shape[1]==1:
+ s[i]=np.squeeze(s[i])
+ elifs[i].ndim==2ands[i].shape[1]>1:
+ si_joint=s[i][:,0]
+ alph_new=len(np.unique(s[i][:,0]))
+ forcolinrange(1,s[i].shape[1]):
+ alph_col=len(np.unique(s[i][:,col]))
+ si_joint,alph_new=_join_variables(
+ si_joint,s[i][:,col],alph_new,alph_col)
+ settings['alph_s'+str(i+1)]=alph_new
+ else:
+ raiseValueError('Input source {0} s{0} has to be a 1D or 2D '
+ 'numpy array.'.format(i+1))
+
+ ift.ndim!=1:
+ ift.shape[1]==1:
+ t=np.squeeze(t)
+ else:# For now we only allow 1D-targets
+ raiseValueError('Input target t has to be a vector '
+ '(t.shape[1]=1).')
+
+ # Check types of remaining inputs.
+ iftype(settings)!=dict:
+ raiseTypeError('The settings argument should be a dictionary.')
+ foriinrange(len(s)):
+ ifnotissubclass(s[i].dtype.type,np.integer):
+ raiseTypeError('Input s{0} (source {0}) must be an integer numpy '
+ 'array.'.format(i+1))
+ # ^ for
+ ifnotissubclass(t.dtype.type,np.integer):
+ raiseTypeError('Input t (target) must be an integer numpy array.')
+
+ # Check if variables have equal length.
+ foriinrange(len(s)):
+ iflen(t)!=len(s[i]):
+ raiseValueError('Number of samples s and t must be equal')
+
+ returns,t,settings
+
' it using pip or the package manager to use ''OpenCL-powered CMI estimation.')
+logger=logging.getLogger(__name__)
+C=1024**2
+
[docs]classOpenCLKraskov(Estimator):"""Abstract class for implementation of OpenCL estimators.
@@ -98,6 +103,8 @@
Source code for idtxl.estimators_opencl
in KNN and range searches (default=0) - noise_level : float [optional] - random noise added to the data (default=1e-8)
+            - padding : bool [optional] - pad data to a length that is a
 multiple of 1024, workaround for OpenCL sub-buffer memory-alignment requirements - debug : bool [optional] - calculate intermediate results, i.e. neighbour counts from range searches and KNN distances, print debug output to console (default=False)
@@ -115,6 +122,7 @@
Calculate the mutual information (MI) between two variables using OpenCL GPU-code. See parent class for references.
+ Results are returned in nats.
+
Args: settings : dict [optional] set estimator parameters:
@@ -267,12 +276,10 @@
Source code for idtxl.estimators_opencl
max_chunks_per_run=np.floor(max_mem/mem_chunk).astype(int)chunks_per_run=min(max_chunks_per_run,n_chunks)
- ifself.settings['debug']:
- print('Memory per chunk: {0:.5f} MB, GPU global memory: {1} MB, '
- 'chunks per run: {2}.'.format(mem_chunk/1024/1024,
- max_mem/1024/1024,
- chunks_per_run))
-
+ logger.debug(
+ 'Memory per chunk: {0:.5f} MB, GPU global memory: {1} MB, chunks '
+ 'per run: {2}.'.format(
+ mem_chunk/C,max_mem/C,chunks_per_run))ifmem_chunk>max_mem:raiseRuntimeError('Size of single chunk exceeds GPU global ''memory.')
@@ -329,9 +336,6 @@
Source code for idtxl.estimators_opencl
"""# Prepare data and add noise: check if variable realisations are passed# as 1D or 2D arrays and have equal no. observations.
- ifself.settings['debug']:
- print('var1 shape: {0}, {1}, n_chunks: {2}'.format(
- var1.shape[0],var1.shape[1],n_chunks))var1=self._ensure_two_dim_input(var1)var2=self._ensure_two_dim_input(var2)assertvar1.shape[0]==var2.shape[0]
@@ -339,25 +343,40 @@
Source code for idtxl.estimators_opencl
self._check_number_of_points(var1.shape[0])signallength=var1.shape[0]chunklength=signallength//n_chunks
+ assertsignallength%n_chunks==0var1dim=var1.shape[1]var2dim=var2.shape[1]pointdim=var1dim+var2dim
- # Pad time series to make GPU memory regions a multiple of 1024
- pad_target=1024
- pad_size=(int(np.ceil(signallength/pad_target))*pad_target-
- signallength)
- pad_var1=np.vstack(
- [var1,999999+0.1*np.random.rand(pad_size,var1dim)])
- pad_var2=np.vstack(
- [var2,999999+0.1*np.random.rand(pad_size,var2dim)])
- pointset=np.hstack((pad_var1,pad_var2)).T.copy()
- signallength_padded=signallength+pad_size
- ifself.settings['noise_level']>0:
- pointset+=np.random.normal(scale=self.settings['noise_level'],
- size=pointset.shape)
+ # prepare for the padding
+ signallength_orig=signallength# used for clarity at present
+
+        if self.settings['padding']:
+            # Pad time series to make GPU memory regions a multiple of 4096.
+            # Ideally this value should be replaced by
+            # self.devices[self.settings['gpuid']].CL_DEVICE_MEM_BASE_ADDR_ALIGN
+            # or something similar, as professional cards are known to have
+            # base address alignment of 4096 sometimes.
+            pad_target = 4096
+ pad_size=(int(np.ceil(signallength/pad_target))*pad_target-
+ signallength)
+ pad_var1=np.vstack(
+ [var1,999999+0.1*np.random.rand(pad_size,var1dim)])
+ pad_var2=np.vstack(
+ [var2,999999+0.1*np.random.rand(pad_size,var2dim)])
+ pointset=np.hstack((pad_var1,pad_var2)).T.copy()
+ signallength_padded=signallength+pad_size
+ else:
+ pad_size=0
+ pointset=np.hstack((var1,var2)).T.copy()
+ signallength_padded=signallength
+
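To illustrate the padding arithmetic above with concrete (made-up) numbers: for signallength = 10000 and pad_target = 4096, pad_size = ceil(10000 / 4096) * 4096 - 10000 = 12288 - 10000 = 2288, so signallength_padded = 12288; with padding disabled, pad_size = 0 and signallength_padded = signallength.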
ifnotpointset.dtype==np.float32:pointset=pointset.astype(np.float32)
+ ifself.settings['noise_level']>0:
+ pointset+=np.random.normal(
+ scale=self.settings['noise_level'],
+ size=pointset.shape).astype(np.float32)ifself.settings['debug']:# Print memory requirements after padding
@@ -367,8 +386,12 @@
d_distances.release()d_npointsrange_x.release()d_npointsrange_y.release()
+ d_var1.release()
+ d_var2.release()
+ d_vecradius.release()# Calculate and sum digammasifself.settings['local_values']:
@@ -459,9 +507,13 @@
Source code for idtxl.estimators_opencl
digamma(count_var1[c*chunklength:(c+1)*chunklength]+1)+digamma(count_var2[c*chunklength:(c+1)*chunklength]+1)))mi_array[c]=mi
+            assert signallength_orig == (c + 1) * chunklength, (
+                'Original signal length does not match no. processed points.')
- returnmi_array,distances,count_var1,count_var2
+ return(mi_array,
+ distances[:signallength_orig],
+ count_var1[:signallength_orig],
+ count_var2[:signallength_orig])else:returnmi_array
@@ -474,6 +526,8 @@
Source code for idtxl.estimators_opencl
returns the mutual information between var1 and var2. See parent class for references.
+ Results are returned in nats.
+
Args: settings : dict [optional] set estimator parameters:
@@ -557,11 +611,10 @@
Source code for idtxl.estimators_opencl
max_chunks_per_run=np.floor(max_mem/mem_chunk).astype(int)chunks_per_run=min(max_chunks_per_run,n_chunks)
- ifself.settings['debug']:
- print('Memory per chunk: {0:.5f} MB, GPU global memory: {1} MB, '
- 'chunks per run: {2}.'.format(mem_chunk/1024/1024,
- max_mem/1024/1024,
- chunks_per_run))
+ logger.debug(
+ 'Memory per chunk: {0:.5f} MB, GPU global memory: {1} MB, chunks '
+ 'per run: {2}.'.format(
+ mem_chunk/C,max_mem/C,chunks_per_run))ifmem_chunk>max_mem:raiseRuntimeError('Size of single chunk exceeds GPU global ''memory.')
@@ -645,23 +698,35 @@
Source code for idtxl.estimators_opencl
conddim=conditional.shape[1]pointdim=var1dim+var2dim+conddim
- # Pad time series to make GPU memory regions a multiple of 1024
- pad_target=1024
- pad_size=(int(np.ceil(signallength/pad_target))*pad_target-
- signallength)
- pad_var1=np.vstack(
- [var1,999999+0.1*np.random.rand(pad_size,var1dim)])
- pad_var2=np.vstack(
- [var2,999999+0.1*np.random.rand(pad_size,var2dim)])
- pad_conditional=np.vstack(
- [conditional,999999+0.1*np.random.rand(pad_size,conddim)])
- pointset=np.hstack((pad_var1,pad_conditional,pad_var2)).T.copy()
- signallength_padded=signallength+pad_size
- ifself.settings['noise_level']>0:
- pointset+=np.random.normal(scale=self.settings['noise_level'],
- size=pointset.shape)
+ # prepare padding
+ signallength_orig=signallength
+
+ ifself.settings['padding']:
+ # Pad time series to make GPU memory regions a multiple of 4096
+            # 4096 is the largest known value for OpenCL sub-buffer alignment targets
+ # but see comment in MI estimator above
+ pad_target=4096
+ pad_size=(int(np.ceil(signallength/pad_target))*pad_target-
+ signallength)
+ pad_var1=np.vstack(
+ [var1,999999+0.1*np.random.rand(pad_size,var1dim)])
+ pad_var2=np.vstack(
+ [var2,999999+0.1*np.random.rand(pad_size,var2dim)])
+ pad_conditional=np.vstack(
+ [conditional,999999+0.1*np.random.rand(pad_size,conddim)])
+ pointset=np.hstack((pad_var1,pad_conditional,pad_var2)).T.copy()
+ signallength_padded=signallength+pad_size
+ else:
+ pad_size=0
+ pointset=np.hstack((var1,conditional,var2)).T.copy()
+ signallength_padded=signallength
+
ifnotpointset.dtype==np.float32:pointset=pointset.astype(np.float32)
+ ifself.settings['noise_level']>0:
+ pointset+=np.random.normal(
+ scale=self.settings['noise_level'],
+ size=pointset.shape).astype(np.float32)ifself.settings['debug']:# Print memory requirements after padding
@@ -671,8 +736,9 @@
"""s1,s2,t,self.settings=_check_input(s1,s2,t,self.settings)
- # Check if float128 is supported by the architecture
+ # Check if longdouble is supported by the architecturetry:
- np.float128()
+ np.longdouble()exceptAttributeErroraserr:
- if"'module' object has no attribute 'float128'"==err.args[0]:
+ if"'module' object has no attribute 'longdouble'"==err.args[0]:raiseRuntimeError(
- 'This system doesn''t seem to support float128 '
+ 'This system doesn''t seem to support longdouble ''(requirement for using the Sydney PID-estimator.')else:raise
@@ -230,19 +230,19 @@
Source code for idtxl.estimators_pid
joint_t_s1_s2_count[np.nonzero(joint_t_s1_s2_count)])# Fixed probabilities
- t_prob=np.divide(t_count,num_samples).astype('float128')
- s1_prob=np.divide(s1_count,num_samples).astype('float128')
- s2_prob=np.divide(s2_count,num_samples).astype('float128')
+ t_prob=np.divide(t_count,num_samples).astype('longdouble')
+ s1_prob=np.divide(s1_count,num_samples).astype('longdouble')
+ s2_prob=np.divide(s2_count,num_samples).astype('longdouble')joint_t_s1_prob=np.divide(joint_t_s1_count,
- num_samples).astype('float128')
+ num_samples).astype('longdouble')joint_t_s2_prob=np.divide(joint_t_s2_count,
- num_samples).astype('float128')
+ num_samples).astype('longdouble')# Variable probabilitiesjoint_s1_s2_prob=np.divide(joint_s1_s2_count,
- num_samples).astype('float128')
+ num_samples).astype('longdouble')joint_t_s1_s2_prob=np.divide(joint_t_s1_s2_count,
- num_samples).astype('float128')
+ num_samples).astype('longdouble')max_prob=np.max(joint_t_s1_s2_prob[np.nonzero(joint_t_s1_s2_prob)])# # make copies of the variable probabilities for independent second
@@ -297,8 +297,8 @@
Source code for idtxl.estimators_pid
# Replication loopforrepinreps:prob_inc=np.multiply(
- np.float128(max_prob),
- np.divide(np.float128(1),np.float128(rep)))
+ np.longdouble(max_prob),
+ np.divide(np.longdouble(1),np.longdouble(rep)))# Want to store number of succesive unsuccessful swapsunsuccessful_swaps_row=0# SWAP LOOP
@@ -414,7 +414,7 @@
+"""Import external file formats into IDTxl.
+
+Provide functions to import the following into IDTxl:
+
+  - mat-files (version 7.3 or higher, hdf5)
+  - FieldTrip-style mat-files (version 7.3 or higher, hdf5)
+
+Matlab supports hdf5 only for files saved as version 7.3 or higher:
+https://au.mathworks.com/help/matlab/ref/save.html#inputarg_version
+
+Creates a numpy array usable as input to IDTxl.
+
+Methods:
+ ft_trial_2_numpyarray(file_name, ft_struct_name)
+ matarray2idtxlconverter(file_name, array_name, order) = takes a file_name,
+ the name of the array variable (array_name) inside,
+                           and the order of sensor axis, time axis, and (CHECK THIS!!)
+ repetition axis (as a list)
+
+Note:
+ Written for Python 3.4+
+
+Created on Wed Mar 19 12:34:36 2014
+
+@author: Michael Wibral
+"""
+import h5py
+import numpy as np
+from scipy.io import loadmat
+from idtxl.data import Data
+
+VERBOSE=False
+
+
+
[docs]defimport_fieldtrip(file_name,ft_struct_name,file_version,normalise=True):
+ """Convert FieldTrip-style MATLAB-file into an IDTxl Data object.
+
+ Import a MATLAB structure with fields "trial" (data), "label" (channel
+ labels), "time" (time stamps for data samples), and "fsample" (sampling
+ rate). This structure is the standard file format in the MATLAB toolbox
+    FieldTrip and commonly used to represent neurophysiological data (see also
+ http://www.fieldtriptoolbox.org/reference/ft_datatype_raw). The data is
+ returned as a IDTxl Data() object.
+
+    The structure is assumed to be saved as a MATLAB hdf5 file ('-v7.3' or
+ higher, .mat) with a SINGLE FieldTrip data structure inside.
+
+ Args:
+ file_name : string
+ full (matlab) file_name on disk
+ ft_struct_name : string
+ variable name of the MATLAB structure that is in FieldTrip format
+ (autodetect will hopefully be possible later ...)
+ file_version : string
+ version of the file, e.g. 'v7.3' for MATLAB's 7.3 format
+ normalise : bool [optional]
+ normalise data after import (default=True)
+
+ Returns:
+ Data() instance
+ instance of IDTxl Data object, containing data from the 'trial'
+ field
+ list of strings
+ list of channel labels, corresponding to the 'label' field
+ numpy array
+ time stamps for samples, corresponding to one entry in the 'time'
+ field
+ int
+ sampling rate, corresponding to the 'fsample' field
+
+ @author: Michael Wibral
+ """
+    if file_version != "v7.3":
+        raise RuntimeError('At present only mat-files in format 7.3 are '
+                           'supported, please consider reopening and '
+                           'resaving your mat-file in that version.')
+    # TODO we could write a fallback option using scipy's loadmat?
+
+    print('Creating Python dictionary from FT data structure: {0}'
+          .format(ft_struct_name))
+    trial_data = _ft_import_trial(file_name, ft_struct_name)
+    label = _ft_import_label(file_name, ft_struct_name)
+    fsample = _ft_fsample_2_float(file_name, ft_struct_name)
+    timestamps = _ft_import_time(file_name, ft_struct_name)
+
+    dat = Data(data=trial_data, dim_order='spr', normalise=normalise)
+    return dat, label, timestamps, fsample
+
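A minimal usage sketch for import_fieldtrip() above. The file name and struct name are placeholders, and the import path assumes the function lives in idtxl.idtxl_io; adjust to the module that ships it in your version:

    from idtxl.idtxl_io import import_fieldtrip  # module path assumed

    dat, labels, timestamps, fsample = import_fieldtrip(
        file_name='my_recording.mat',   # hypothetical file saved with -v7.3
        ft_struct_name='data',          # name of the FieldTrip struct inside
        file_version='v7.3',
        normalise=True)
    print(dat.n_processes, dat.n_samples, dat.n_replications)
    print(labels[:3], fsample)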
+
+def _ft_import_trial(file_name, ft_struct_name):
+    """Import FieldTrip trial data into Python."""
+    ft_file = h5py.File(file_name)
+    ft_struct = ft_file[ft_struct_name]  # TODO: ft_struct_name = automagic...
+
+    # Get the trial cells that contain the references (pointers) to the data
+    # we need. Then get the data from matrices in cells of a 1 x numtrials
+    # cell array in the original FieldTrip structure.
+    trial = ft_struct['trial']
+
+    # Allocate memory to hold actual data, read shape of first trial to know
+    # the data size.
+    trial_data_tmp = np.array(ft_file[trial[0][0]])  # get data from 1st trial
+    print('Found data with first dimension: {0}, and second: {1}'
+          .format(trial_data_tmp.shape[0], trial_data_tmp.shape[1]))
+    geometry = trial_data_tmp.shape + (trial.shape[0],)
+    trial_data = np.empty(geometry)
+
+    # Get actual data from h5py structure.
+    for tt in range(0, trial.shape[0]):
+        trialref = trial[tt][0]  # get trial reference
+        trial_data[:, :, tt] = np.array(ft_file[trialref])  # get data
+
+    ft_file.close()
+    return trial_data
+
+
+def _ft_import_label(file_name, ft_struct_name):
+    """Import FieldTrip labels into Python."""
+    # For details of the data handling see comments in _ft_import_trial.
+    ft_file = h5py.File(file_name)
+    ft_struct = ft_file[ft_struct_name]
+    ft_label = ft_struct['label']
+
+    if VERBOSE:
+        print('Converting FT labels to python list of strings')
+
+    label = []
+    for ll in range(0, ft_label.shape[0]):
+        # There is only one item in labelref, but we have to index it.
+        # Matlab has character arrays that are read as bytes in Python 3.
+        # Here, map converts the bytes in labeltmp to characters and
+        # "".join() turns them into a real Python string.
+        labelref = ft_label[ll]
+        labeltmp = ft_file[labelref[0]]
+        strlabeltmp = "".join(map(chr, labeltmp[0:]))
+        label.append(strlabeltmp)
+
+    ft_file.close()
+    return label
+
+
+def _ft_import_time(file_name, ft_struct_name):
+    """Import FieldTrip time stamps into Python."""
+    # For details of the data handling see comments in _ft_import_trial.
+    ft_file = h5py.File(file_name)
+    ft_struct = ft_file[ft_struct_name]
+    ft_time = ft_struct['time']
+    if VERBOSE:
+        print('Converting FT time cell array to numpy array')
+
+    np_timeaxis_tmp = np.array(ft_file[ft_time[0][0]])
+    geometry = np_timeaxis_tmp.shape + (ft_time.shape[0],)
+    timestamps = np.empty(geometry)
+    for tt in range(0, ft_time.shape[0]):
+        timeref = ft_time[tt][0]
+        timestamps[:, :, tt] = np.array(ft_file[timeref])
+    ft_file.close()
+    return timestamps
+
+
+def _ft_fsample_2_float(file_name, ft_struct_name):
+    ft_file = h5py.File(file_name)
+    ft_struct = ft_file[ft_struct_name]
+    FTfsample = ft_struct['fsample']
+    fsample = int(FTfsample[0])
+    if VERBOSE:
+        print('Converting FT fsample array (1x1) to numpy array (1x1)')
+    return fsample
+
+
+
[docs]def import_matarray(file_name, array_name, file_version, dim_order,
+                        normalise=True):
+    """Read Matlab hdf5 file into IDTxl.
+
+    Reads a Matlab hdf5 file ('-v7.3' or higher, .mat) with a SINGLE array
+    inside and returns a numpy array with dimensions channel x time x trials,
+    using np.swapaxes where necessary.
+
+ Note:
+ The import function squeezes the loaded mat-file, i.e., any singleton
+ dimension will be removed. Hence do not enter singleton dimension into
+ the 'dim_order', e.g., don't pass dim_order='ps' but dim_order='s' if
+ you want to load a 1D-array where entries represent samples recorded
+ from a single channel.
+
+ Args:
+ file_name : string
+ full (matlab) file_name on disk
+ array_name : string
+ variable name of the MATLAB structure to be read
+        file_version : string
+            version of the file, e.g. 'v7.3' for MATLAB's 7.3 format;
+            currently versions 'v4', 'v6', 'v7', and 'v7.3' are supported
+ dim_order : string
+ order of dimensions, accepts any combination of the characters
+ 'p', 's', and 'r' for processes, samples, and replications; must
+ have the same length as the data dimensionality, e.g., 'ps' for a
+ two-dimensional array of data from several processes over time
+ normalise : bool [optional]
+ normalise data after import (default=True)
+
+ Returns:
+ Data() instance
+ instance of IDTxl Data object, containing data from the 'trial'
+ field
+ list of strings
+ list of channel labels, corresponding to the 'label' field
+ numpy array
+ time stamps for samples, corresponding to one entry in the 'time'
+ field
+ int
+ sampling rate, corresponding to the 'fsample' field
+
+ Created on Wed Mar 19 12:34:36 2014
+
+ @author: Michael Wibral
+ """
+    if file_version == 'v7.3':
+        mat_file = h5py.File(file_name)
+        # Assert that at least one of the keys found at the top level of the
+        # HDF file matches the name of the array we want.
+        if array_name not in mat_file.keys():
+            raise RuntimeError('Array {0} not in mat file or not a variable '
+                               'at the file\'s top level.'.format(array_name))
+
+        # 2. Create an object for the matlab array (from the hdf5 hierarchy);
+        # the trailing [()] ensures everything is read.
+        mat_data = np.squeeze(np.asarray(mat_file[array_name][()]))
+
+    elif file_version in ['v4', 'v6', 'v7']:
+        try:
+            m = loadmat(file_name, squeeze_me=True, variable_names=array_name)
+        except NotImplementedError as err:
+            raise RuntimeError('You may have provided an incorrect file '
+                               'version. The mat file was probably saved as '
+                               'version 7.3 (hdf5).')
+        mat_data = m[array_name]  # loadmat returns a dict of variables
+    else:
+        raise ValueError('Unknown file version: {0}.'.format(file_version))
+
+    # Create output: IDTxl data object, list of labels, sampling info in unit
+    # time steps (sampling rate of 1).
+    print('Creating Data object from matlab array: {0}.'.format(array_name))
+    dat = Data(mat_data, dim_order=dim_order, normalise=normalise)
+    label = []
+    for n in range(dat.n_processes):
+        label.append('channel_{0:03d}'.format(n))
+    fsample = 1
+    timestamps = np.arange(dat.n_samples)
+    return dat, label, timestamps, fsample
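A corresponding usage sketch for import_matarray(); the file name and array name are placeholders, the module path is assumed to be idtxl.idtxl_io, and dim_order must describe the actual layout of the stored array:

    from idtxl.idtxl_io import import_matarray  # module path assumed

    dat, labels, timestamps, fsample = import_matarray(
        file_name='my_array.mat',  # hypothetical file
        array_name='lfp',          # hypothetical variable name inside it
        file_version='v7.3',
        dim_order='psr',           # processes x samples x replications
        normalise=True)
    print(labels)    # ['channel_000', 'channel_001', ...]
    print(fsample)   # always 1, i.e., time stamps are in unit steps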
Provide functions to load and save IDTxl data, provide import functions
(e.g., mat-files, FieldTrip) and export functions (e.g., networkx, BrainNet
Viewer)."""
-# import json
+import json
+import pickle
+import h5py
+import networkx as nx
+from pprint import pprint
+import numpy as np
+import copy as cp
+import itertools as it
@@ -67,72 +68,58 @@
Source code for idtxl.idtxl_io
'https://pypi.python.org/pypi/networkx/2.0 to export and plot IDTxl ''results in this format.'))
-VERBOSE=False
+DEBUG=False
-# def save(data, file_path):
-# """Save IDTxl data to disk.
+
[docs]def save_json(d, file_path):
+ """Save dictionary to disk as JSON file.
-# Save different data types to disk. Supported types are:
+ Writes dictionary to disk at the specified file path.
-# - dictionaries with results, e.g., from MultivariateTE
-# - numpy array
-# - instance of IDTXL Data object
+ Args:
+ d : dict
+ dictionary to be written to disk
+ file_path : str
+ path to file (including extension)
-# Note that while numpy arrays and Data instances are saved in binary for
-# performance, dictionaries are saved in the json format, which is human-
-# readable and also easily read into other programs (e.g., MATLAB:
-# http://undocumentedmatlab.com/blog/json-matlab-integration).
+ Note: JSON does not recognize numpy data types, those are converted to
+ basic Python data types first.
+ """
+    data_json = _remove_numpy(d)
+    with open(file_path, 'w') as outfile:
+        json.dump(obj=data_json, fp=outfile, sort_keys=True)
-# File extensions are
-# - .txt for dictionaries (JSON file)
-# - .npy for numpy array
-# - .npz for Data instances
+
[docs]def load_json(file_path):
+ """Load dictionary saved as JSON file from disk.
-# If the extension is not provided in the file_path, the function will add
-# it depending on the type of the data to be written.
+ Args:
+ file_path : str
+ path to file (including extension)
+ Returns:
+ dict
-# Args:
-# data : dict | numpy array | Data object
-# data to be saved to disk
-# file_path : string
-# string with file name (including the path)
-# """
-# # Check if a file extension is provided in the file_path. Note that the
-# # numpy save functions don't need an extension, they are added if
-# # missing.
-# if file_path.find('.', -4) == -1:
-# add_extension = True
-# else:
-# add_extension = False
-
-# if type(data) is dict:
-# if add_extension:
-# file_path = ''.join([file_path, '.txt'])
-# # JSON does not recognize numpy arrays and data types, they have to
-# # be converted before dumping them.
-# data_json = _remove_numpy(data)
-# if VERBOSE:
-# print('writing file {0}'.format(file_path))
-# with open(file_path, 'w') as outfile:
-# json.dump(obj=data_json, fp=outfile, sort_keys=True)
-# elif type(data) is np.ndarray:
-# # TODO this can't handle scalars, handle this as an exception
-# np.save(file_path, data)
-# elif type(data) is __name__.data.Data:
-# np.savez(file_path, data=data.data, normalised=data.normalise)
+ Note: JSON does not recognize numpy data structures and types. Numpy arrays
+ and data types (float, int) are thus converted to Python types and lists.
+ The loaded dictionary may thus contain different data types than the saved
+ one.
+ """
+    with open(file_path) as infile:
+        d = json.load(infile)
+    if DEBUG:
+        pprint(d)
+    return d
def _remove_numpy(data):
-    """Remove all numpy data structures and types from dictionary.
+    """Replace numpy data types with basic Python types in a dictionary.

    JSON can not handle numpy types and data structures, they have to be
-    convertedto native python types first.
+    converted to native python types first.
    """
    data_json = cp.copy(data)
    for k in data_json.keys():
-        if VERBOSE:
+        if DEBUG:
            print('{0}, type: {1}'.format(data_json[k], type(data_json[k])))
        if type(data_json[k]) is np.ndarray:
            data_json[k] = data_json[k].tolist()
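A short sketch of the save/load round trip described above; the file name is a placeholder. Note that numpy arrays come back as plain Python lists, which is exactly what the load_json() docstring warns about:

    import numpy as np
    from idtxl import idtxl_io as io

    results_dict = {'alpha': 0.05, 'te_values': np.array([0.1, 0.3])}
    io.save_json(results_dict, 'results.json')  # numpy types converted first
    loaded = io.load_json('results.json')
    print(type(loaded['te_values']))            # <class 'list'>, not ndarray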
@@ -191,6 +178,7 @@
Source code for idtxl.idtxl_io
# return d
+
[docs]def save_pickle(obj, name):
    """Save objects using Python's pickle module.
@@ -198,13 +186,13 @@
Source code for idtxl.idtxl_io
    pickle.HIGHEST_PROTOCOL is a binary format, which may be inconvenient,
    but is good for performance. Protocol 0 is a text format.
    """
-    with open(name + '.pkl', 'wb') as f:
+    with open(name, 'wb') as f:
        pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)

[docs]def load_pickle(name):
    """Load objects that have been saved using Python's pickle module."""
-    with open(name + '.pkl', 'rb') as f:
+    with open(name, 'rb') as f:
        return pickle.load(f)
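The matching sketch for the pickle helpers. After this change the caller passes the full file name including extension (the '.pkl' suffix is no longer appended automatically); the data shape below is arbitrary:

    import numpy as np
    from idtxl import idtxl_io as io
    from idtxl.data import Data

    dat = Data(np.random.rand(3, 1000, 5), dim_order='psr')
    io.save_pickle(dat, 'my_data.pkl')        # extension is part of the name
    dat_loaded = io.load_pickle('my_data.pkl')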
@@ -300,7 +288,7 @@
Source code for idtxl.idtxl_io
    ft_struct = ft_file[ft_struct_name]
    ft_label = ft_struct['label']
-    if VERBOSE:
+    if DEBUG:
        print('Converting FT labels to python list of strings')
    label = []
@@ -324,7 +312,7 @@
Source code for idtxl.idtxl_io
    ft_file = h5py.File(file_name)
    ft_struct = ft_file[ft_struct_name]
    ft_time = ft_struct['time']
-    if VERBOSE:
+    if DEBUG:
        print('Converting FT time cell array to numpy array')
    np_timeaxis_tmp = np.array(ft_file[ft_time[0][0]])
@@ -342,7 +330,7 @@
Source code for idtxl.idtxl_io
    ft_struct = ft_file[ft_struct_name]
    FTfsample = ft_struct['fsample']
    fsample = int(FTfsample[0])
-    if VERBOSE:
+    if DEBUG:
        print('Converting FT fsample array (1x1) to numpy array (1x1)')
    return fsample
@@ -436,6 +424,7 @@
Source code for idtxl.idtxl_io
    # use 'weights' parameter (string) as networkx edge property name and
    # use adjacency matrix entries as edge property values
    G = nx.DiGraph()
+    G.add_nodes_from(np.arange(adjacency_matrix.n_nodes()))
    G.add_weighted_edges_from(adjacency_matrix.get_edge_list(), weights)
    return G
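The added add_nodes_from() call keeps nodes without any inferred edges in the exported graph instead of silently dropping them. A self-contained sketch of the effect, with a made-up edge list standing in for AdjacencyMatrix.get_edge_list():

    import numpy as np
    import networkx as nx

    # Stand-in for AdjacencyMatrix.get_edge_list(): (source, target, weight).
    edge_list = [(0, 2, 1.5), (0, 1, 0.8)]
    n_nodes = 4  # node 3 has no edges

    G = nx.DiGraph()
    G.add_nodes_from(np.arange(n_nodes))            # keeps isolated node 3
    G.add_weighted_edges_from(edge_list, weight='weight')
    print(G.number_of_nodes())                      # 4, not 3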
                               'sources have to have the same length')

+        # Check and set defaults for checkpointing.
+        settings = self._set_checkpointing_defaults(
+            settings, data, sources, targets)
+
        # Perform MI estimation for each target individually
        results = ResultsNetworkInference(
            n_nodes=data.n_processes,
            n_realisations=data.n_realisations(),
@@ -290,6 +294,12 @@
Source code for idtxl.multivariate_mi
              Data.permute_samples() for further settings (default=False)
            - verbose : bool [optional] - toggle console output
              (default=True)
+           - write_ckp : bool [optional] - enable checkpointing, writes
+             analysis state to disk every time a variable is selected;
+             resume crashed analysis using
+             network_analysis.resume_checkpoint() (default=False)
+           - filename_ckp : str [optional] - checkpoint file name (without
+             extension) (default='./idtxl_checkpoint'); an example settings
+             dict is sketched below

        data : Data instance
            raw data for analysis
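A sketch of a settings dictionary with checkpointing enabled, as described by the two options above; the estimator and lag parameters are only illustrative:

    settings = {
        'cmi_estimator': 'JidtGaussianCMI',  # illustrative estimator choice
        'max_lag_sources': 5,
        'min_lag_sources': 1,
        'write_ckp': True,                   # enable checkpointing
        'filename_ckp': './my_analysis'      # writes my_analysis.ckp/.dat/.json
    }
    # results = MultivariateMI().analyse_network(settings=settings, data=data)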
@@ -350,23 +360,22 @@
+"""Estimate partial information decomposition (PID).
+
+Estimate PID for multiple sources (up to 4 sources) and one target process
+using SxPID estimator.
+
+Note:
+ Written for Python 3.4+
+"""
+import numpy as np
+from .single_process_analysis import SingleProcessAnalysis
+from .estimator import find_estimator
+from .results import ResultsMultivariatePID
+
+
+
[docs]class MultivariatePID(SingleProcessAnalysis):
+ """Perform partial information decomposition for individual processes.
+
+ Perform partial information decomposition (PID) for multiple source
+ processes (up to 4 sources) and a target process in the network.
+ Estimate unique, shared, and synergistic information in the multiple
+ sources about the target. Call analyse_network() on the whole network
+ or a set of nodes or call analyse_single_target() to estimate PID for
+ a single process. See docstrings of the two functions for more information.
+
+ References:
+
+ - Williams, P. L., & Beer, R. D. (2010). Nonnegative Decomposition of
+ Multivariate Information, 1–14. Retrieved from
+ http://arxiv.org/abs/1004.2515
+ - Makkeh, A. & Gutknecht, A. & Wibral, M. (2020). A Differentiable measure
+ for shared information. 1- 27 Retrieved from
+ http://arxiv.org/abs/2002.03356
+
+ Attributes:
+ target : int
+ index of target process
+ sources : array type
+ multiple of indices of source processes
+ settings : dict
+ analysis settings
+ results : dict
+ estimated PID
+ """
+
+ def__init__(self):
+ super().__init__()
+
+
[docs]def analyse_network(self, settings, data, targets, sources):
+ """Estimate partial information decomposition for network nodes.
+
+ Estimate, for multiple nodes (target processes), the partial
+ information decomposition (PID) for multiple source processes
+ (up to 4 sources) and each of these target processes
+ in the network.
+
+ Note:
+ For a detailed description of the algorithm and settings see
+ documentation of the analyse_single_target() method and
+ references in the class docstring.
+
+ Example:
+
+ >>> n = 20
+ >>> alph = 2
+ >>> s1 = np.random.randint(0, alph, n)
+ >>> s2 = np.random.randint(0, alph, n)
+ >>> s3 = np.random.randint(0, alph, n)
+ >>> target1 = np.logical_xor(s1, s2).astype(int)
+ >>> target = np.logical_xor(target1, s3).astype(int)
+ >>> data = Data(np.vstack((s1, s2, s3, target)), 'ps',
+ >>> normalise=False)
+ >>> settings = {
+ >>> 'lags_pid': [[1, 1, 1], [3, 2, 7]],
+ >>> 'verbose': False,
+ >>> 'pid_estimator': 'SxPID'}
+ >>> targets = [0, 1]
+ >>> sources = [[1, 2, 3], [0, 2, 3]]
+ >>> pid_analysis = MultivariatePID()
+ >>> results = pid_analysis.analyse_network(settings, data, targets,
+ >>> sources)
+
+ Args:
+ settings : dict
+ parameters for estimation and statistical testing, see
+ documentation of analyse_single_target() for details, can
+ contain
+
+ - lags_pid : list of lists of ints [optional] - lags in samples
+ between sources and target
+ (default=[[1, 1, ..., 1], [1, 1, ..., 1], ...])
+
+ data : Data instance
+ raw data for analysis
+ targets : list of int
+ index of target processes
+            sources : list of lists
+                indices of the multiple source processes for each target,
+                e.g., [[0, 1, 2], [1, 0, 3]]; all lists must be of the same
+                length, and the list of lists must have the same length as
+                targets
+
+ Returns:
+ ResultsMultivariatePID instance
+ results of network inference, see documentation of
+ ResultsMultivariatePID()
+ """
+        # Set defaults for PID estimation.
+        settings.setdefault('verbose', True)
+        settings.setdefault(
+            'lags_pid',
+            np.array([[1 for i in range(len(sources[0]))]] * len(targets)))
+
+        # Check inputs.
+        if not len(targets) == len(sources) == len(settings['lags_pid']):
+            raise RuntimeError('Lists of targets, sources, and lags must '
+                               'have the same lengths.')
+        for lis_1 in sources:
+            for lis_2 in sources:
+                if not len(lis_1) == len(lis_2):
+                    raise RuntimeError('Lists in the list sources must have '
+                                       'the same lengths.')
+                #^ if
+            #^ for
+        #^ for
+
+        list_of_lags = settings['lags_pid']
+
+        # Perform PID estimation for each target individually
+        results = ResultsMultivariatePID(
+            n_nodes=data.n_processes,
+            n_realisations=data.n_realisations(),
+            normalised=data.normalise)
+        for t in range(len(targets)):
+            if settings['verbose']:
+                print('\n####### analysing target with index {0} from list '
+                      '{1}'.format(t, targets))
+            settings['lags_pid'] = list_of_lags[t]
+            res_single = self.analyse_single_target(
+                settings, data, targets[t], sources[t])
+            results.combine_results(res_single)
+        # Get no. realisations actually used for estimation from single
+        # target analysis.
+        results.data_properties.n_realisations = (
+            res_single.data_properties.n_realisations)
+        return results
+
+
[docs]def analyse_single_target(self, settings, data, target, sources):
+ """Estimate partial information decomposition for a network node.
+
+ Estimate partial information decomposition (PID) for multiple source
+ processes (up to 4 sources) and a target process in the network.
+
+ Note:
+ For a description of the algorithm and the method see references in
+ the class and estimator docstrings.
+
+ Example:
+
+ >>> n = 20
+ >>> alph = 2
+ >>> s1 = np.random.randint(0, alph, n)
+ >>> s2 = np.random.randint(0, alph, n)
+ >>> s3 = np.random.randint(0, alph, n)
+ >>> target1 = np.logical_xor(s1, s2).astype(int)
+ >>> target = np.logical_xor(target1, s3).astype(int)
+ >>> data = Data(np.vstack((s1, s2, s3, target)), 'ps',
+ >>> normalise=False)
+ >>> settings = {
+        >>>     'verbose': False,
+ >>> 'pid_estimator': 'SxPID',
+ >>> 'lags_pid': [2, 3, 1]}
+ >>> pid_analysis = MultivariatePID()
+ >>> results = pid_analysis.analyse_single_target(settings=settings,
+ >>> data=data,
+ >>> target=0,
+ >>> sources=[1, 2, 3])
+
+        Args:
+            settings : dict
+                parameters for estimator use and statistics:
+
+ - pid_estimator : str - estimator to be used for PID estimation
+ (for estimator settings see the documentation in the
+ estimators_pid modules)
+ - lags_pid : list of ints [optional] - lags in samples between
+ sources and target (default=[1, 1, ..., 1])
+ - verbose : bool [optional] - toggle console output
+ (default=True)
+
+ data : Data instance
+ raw data for analysis
+ target : int
+ index of target processes
+ sources : list of ints
+ indices of the multiple source processes for the target
+
+        Returns:
+            ResultsMultivariatePID instance
+                results of network inference, see documentation of
+                ResultsMultivariatePID()
+ """
+        # Check input and initialise values for analysis.
+        self._initialise(settings, data, target, sources)
+
+        # Estimate PID and significance.
+        self._calculate_pid(data)
+
+        # Add analysis info.
+        results = ResultsMultivariatePID(
+            n_nodes=data.n_processes,
+            n_realisations=data.n_realisations(self.current_value),
+            normalised=data.normalise)
+        results._add_single_result(
+            settings=self.settings,
+            target=self.target,
+            results=self.results)
+        self._reset()
+        return results
+
+    def _initialise(self, settings, data, target, sources):
+        """Check input, set initial or default values for analysis settings."""
+        # Check requested PID estimator.
+        try:
+            EstimatorClass = find_estimator(settings['pid_estimator'])
+        except KeyError:
+            raise RuntimeError('Estimator was not specified!')
+        self._pid_estimator = EstimatorClass(settings)
+
+        self.settings = settings.copy()
+        self.settings.setdefault('lags_pid', [1 for i in range(len(sources))])
+        self.settings.setdefault('verbose', True)
+
+        # Check if provided lags are correct and work with the number of
+        # samples in the data.
+        if len(self.settings['lags_pid']) not in [2, 3, 4]:
+            raise RuntimeError('List of lags must have length 2, 3, or 4.')
+        # The number of lags is equal to the number of sources.
+        if not len(self.settings['lags_pid']) == len(sources):
+            raise RuntimeError('List of lags must have the same length as '
+                               'the list of sources.')
+        for i in range(len(self.settings['lags_pid'])):
+            if self.settings['lags_pid'][i] >= data.n_samples:
+                raise RuntimeError(
+                    'Lag {0} ({1}) is larger than the number of samples in '
+                    'the data set ({2}).'.format(
+                        i, self.settings['lags_pid'][i], data.n_samples))
+
+        # Check if target and sources are provided correctly.
+        if type(target) is not int:
+            raise RuntimeError('Target must be an integer.')
+        if len(sources) not in [2, 3, 4]:
+            raise RuntimeError('List of sources must have length 2, 3, or 4.')
+        if target in sources:
+            raise RuntimeError('The target ({0}) should not be in the list '
+                               'of sources ({1}).'.format(target, sources))
+
+        self.current_value = (target, max(self.settings['lags_pid']))
+        self.target = target
+        # TODO works for single vars only, change to multivariate?
+        self.sources = self._lag_to_idx([
+            (sources[i], self.settings['lags_pid'][i])
+            for i in range(len(sources))])
+
+    def _calculate_pid(self, data):
+
+        # TODO Discuss how and if the following statistical testing should be
+        # included. Remove dummy results.
+        # [orig_pid, sign_1, p_val_1,
+        #  sign_2, p_val_2] = stats.unq_against_surrogates(self, data)
+        # [orig_pid, sign_shd,
+        #  p_val_shd, sign_syn, p_val_syn] = stats.syn_shd_against_surrogates(
+        #      self, data)
+        # sign_1 = sign_2 = sign_shd = sign_syn = False
+        # p_val_1 = p_val_2 = p_val_shd = p_val_syn = 1.0
+
+        target_realisations = data.get_realisations(
+            self.current_value,
+            [self.current_value])[0]
+
+        # CHECK! make sure self.sources has the same idx as sources
+        data.get_realisations(self.current_value, [self.sources[0]])[0]
+        list_sources_var_realisations = [
+            data.get_realisations(
+                self.current_value, [self.sources[i]])[0]
+            for i in range(len(self.sources))]
+
+        orig_pid = self._pid_estimator.estimate(
+            s=list_sources_var_realisations,
+            t=target_realisations)
+
+        self.results = orig_pid
+        for i in range(len(self.sources)):
+            self.results['source_' + str(i + 1)] = self._idx_to_lag(
+                [self.sources[i]])
+        #^ for
+        self.results['selected_vars_sources'] = [
+            self.results['source_' + str(i + 1)][0]
+            for i in range(len(self.sources))]
+        self.results['current_value'] = self.current_value
+        # self.results['unq_s1_sign'] = sign_1
+        # self.results['unq_s2_sign'] = sign_2
+        # self.results['unq_s1_p_val'] = p_val_1
+        # self.results['unq_s2_p_val'] = p_val_2
+        # self.results['syn_sign'] = sign_syn
+        # self.results['syn_p_val'] = p_val_syn
+        # self.results['shd_sign'] = sign_shd
+        # self.results['shd_p_val'] = p_val_shd
+
+        # TODO make mi_against_surrogates in stats more generic, such that
+        # it becomes an arbitrary permutation test where one argument gets
+        # shuffled and then all arguments are passed to the provided
+        # estimator
+
+    def _reset(self):
+        """Reset instance after analysis."""
+        self.__init__()
+        del self.results
+        del self.settings
+        del self._pid_estimator
http://doi.org/10.1103/PhysRevE.83.051112 Attributes:
-
source_set : list indices of source processes tested for their influence on the target
@@ -192,6 +191,10 @@
Source code for idtxl.multivariate_te
                               'sources have to have the same length')

+        # Check and set defaults for checkpointing.
+        settings = self._set_checkpointing_defaults(
+            settings, data, sources, targets)
+
        # Perform TE estimation for each target individually
        results = ResultsNetworkInference(
            n_nodes=data.n_processes,
            n_realisations=data.n_realisations(),
@@ -296,6 +299,12 @@
Source code for idtxl.multivariate_te
              Data.permute_samples() for further settings (default=False)
            - verbose : bool [optional] - toggle console output
              (default=True)
+           - write_ckp : bool [optional] - enable checkpointing, writes
+             analysis state to disk every time a variable is selected;
+             resume crashed analysis using
+             network_analysis.resume_checkpoint() (default=False)
+           - filename_ckp : str [optional] - checkpoint file name (without
+             extension) (default='./idtxl_checkpoint')

        data : Data instance
            raw data for analysis
@@ -352,26 +361,34 @@
"""Parent class for network inference and network comparison."""
+importos.path
+fromdatetimeimportdatetime
+fromshutilimportcopyfile
+frompprintimportpprint
+importastimportcopyascpimportitertoolsasitimportnumpyasnpfrom.estimatorimportfind_estimatorfrom.importidtxl_utilsasutils
+from.importidtxl_ioasio
        cond = self.settings['add_conditionals']
        if type(cond) is tuple:  # easily add single variable
            cond = [cond]
-        cond_idx = self._lag_to_idx(cond)
+        elif type(cond) is dict:  # add conditioning variables per target
+            try:
+                cond = cond[self.target]
+            except KeyError:
+                return  # no additional variables for the current target
+        cond_idx = self._lag_to_idx(cond)
        candidate_set = list(set(candidate_set).difference(set(cond_idx)))
        return candidate_set
@@ -457,7 +468,7 @@
Source code for idtxl.network_analysis
Returns: numpy array estimate of dependency measure for each link
-
+
Raises: ex.AlgorithmExhaustedError Raised from estimate() when calculation cannot be made
@@ -544,26 +555,176 @@
+        return links
+
+    def _set_checkpointing_defaults(self, settings, data, sources, target):
+        """Set defaults for writing analysis checkpoints."""
+        settings.setdefault('write_ckp', False)
+        if settings['write_ckp']:
+            settings.setdefault('filename_ckp', './idtxl_checkpoint')
+            filename_ckp = '{0}.ckp'.format(settings['filename_ckp'])
+            if not os.path.isfile(filename_ckp):
+                self._initialise_checkpoint(settings, data, sources, target)
+            return settings
+        else:
+            return settings
+
+    def _initialise_checkpoint(self, settings, data, sources, targets):
+        """Write first checkpoint file, data, and settings to disk.
+
+        Called once at the beginning of an analysis using checkpointing.
+        Write data and analysis settings to disk. This needs to be done only
+        once. Initialise the checkpoint file: write a header with time stamp,
+        paths to data and settings, and the targets and sources to be
+        analysed. The checkpoint file is updated during the analysis.
+        """
+        # Check if targets is an int, convert to a list.
+        if type(targets) is int:
+            targets = [targets]
+        # Write data to disk.
+        io.save_pickle(data,
+                       '{0}.dat'.format(settings['filename_ckp']))
+        # Write settings to disk.
+        io.save_json(settings,
+                     '{0}.json'.format(settings['filename_ckp']))
+
+        # Initialise checkpoint file for later updates.
+        filename_ckp = '{0}.ckp'.format(settings['filename_ckp'])
+        with open(filename_ckp, 'w') as text_file:
+            text_file.write('IDTxl checkpoint file.\n')
+            timestamp = datetime.now()
+            text_file.write('{:%Y-%m-%d %H:%M:%S}\n'.format(timestamp))
+            text_file.write('Raw data path: {}.dat\n'.format(
+                os.path.abspath(settings['filename_ckp'])))
+            text_file.write('Settings path: {}.json\n'.format(
+                os.path.abspath(settings['filename_ckp'])))
+            text_file.write('Targets to be analyzed: {}\n'.format(targets))
+            text_file.write('Sources to be analyzed: {}\n\n'.format(sources))
+            text_file.write(
+                'Selected variables (target: [sources]: [selected '
+                'variables]):\n{}'.format(targets[0]))
+
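Based on the write statements above, a freshly initialised checkpoint file looks roughly like the listing below; the time stamp, paths, and target/source lists are placeholders that depend on the actual analysis:

    IDTxl checkpoint file.
    2020-01-01 12:00:00
    Raw data path: /abs/path/my_analysis.dat
    Settings path: /abs/path/my_analysis.json
    Targets to be analyzed: [0, 1]
    Sources to be analyzed: [1, 2, 3]

    Selected variables (target: [sources]: [selected variables]):
    0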
+    def _write_checkpoint(self):
+        """Write checkpoint to disk.
+
+        Write checkpoint to disk. The checkpoint contains variables already
+        selected by network analysis algorithms. To recover from a checkpoint
+        use the resume_checkpoint() method.
+
+        Note: IDTxl will always keep the current (*.ckp) and the previous
+        version (*.ckp.old) of the checkpoint file to ensure a recoverable
+        state even if writing of the current checkpoint fails.
+        """
+        filename_ckp = '{0}.ckp'.format(self.settings['filename_ckp'])
+
+        # Check if a checkpoint file already exists. If yes,
+        # 1. make a copy using the same file name plus the .old extension
+        #    (overwriting the last *.ckp.old file);
+        # 2. update the current checkpoint file.
+        if os.path.isfile(filename_ckp):
+            copyfile(filename_ckp, '{}.old'.format(filename_ckp))
+            self._update_checkpoint(filename_ckp)
+        else:
+            raise RuntimeError('Could not find checkpoint file for updating. '
+                               'Initialise checkpoint first.')
+
+    def _update_checkpoint(self, filename_ckp):
+        """Update existing checkpoint file.
+
+        Add the last selected variable to the *.ckp file while keeping the
+        path to data and settings. Overwrite the time stamp in the header.
+        """
+        # We don't expect these files to become very big. Hence, it is
+        # easiest to load the whole file into a data structure and then write
+        # it back (https://stackoverflow.com/a/328007). Alternatively, we can
+        # just add the last selected variable as a tuple -> then we have to
+        # make sure the last selected candidate always ends up at the end of
+        # the selected candidates list.
+
+        # Write time stamp and info.
+        timestamp = datetime.now()
+        # Convert absolute indices to lags with respect to the current value.
+        selected_variables = self._idx_to_lag(self.selected_vars_full,
+                                              self.current_value[1])
+        # Read file as list of lines and replace first and last line. Write
+        # the modified file back to disk.
+        with open(filename_ckp, 'r') as f:
+            lines = f.readlines()
+        lines[1] = '{:%Y-%m-%d %H:%M:%S}\n'.format(timestamp)
+        if int(lines[-1][0]) == self.target:
+            lines[-1] = '{0}: {1}: {2}\n'.format(
+                self.target, self.source_set, selected_variables)
+        else:
+            lines.append('{0}: {1}: {2}\n'.format(
+                self.target, self.source_set, selected_variables))
+        with open(filename_ckp, 'w') as f:
+            f.writelines(lines)
+
+
[docs]def resume_checkpoint(self, file_path):
+        """Resume analysis from a checkpoint saved to disk.
+
+        Args:
+            file_path : str
+                path to checkpoint file (excluding extension: *.ckp)
+        """
+        # Read checkpoint.
+        with open('{}.ckp'.format(file_path), 'r') as f:
+            lines = f.readlines()
+        timestamp = lines[1]
+        data_path = lines[2][15:].strip()
+        settings_path = lines[3][15:].strip()
+        # Load settings and data.
+        data = io.load_pickle(data_path)
+        settings = io.load_json(settings_path)
+        verbose = settings.get('verbose', True)
+        if verbose:
+            print('Resuming analysis from file {}.ckp, saved {}'.format(
+                file_path, timestamp))
+        # Read targets and sources.
+        targets = ast.literal_eval(lines[4].split(':')[1].strip())
+        sources = ast.literal_eval(lines[5].split(':')[1].strip())
+        # Read selected variables.
+        # Format: target - sources analyzed - selected variables
+        selected_variables = {}  # vars as lags wrt. the current value
+        for l in range(8, len(lines)):
+            result = [x.strip() for x in lines[l].split(':')]
+            # ast.literal_eval(result[2]): IndexError: list index out of range
+            try:
+                selected_variables[int(result[0])] = ast.literal_eval(
+                    result[2])
+            except IndexError:
+                if verbose:
+                    print('No variables previously selected.')
+
+        if verbose:
+            print('Selected variables per target:')
+            pprint(selected_variables)
+
+        # Add already selected candidates as conditionals to the settings
+        # dict. Note that the time stamp in the selected variables list is a
+        # lag wrt. the current value. This format is also expected by the
+        # method that manually adds conditionals.
+        settings['add_conditionals'] = selected_variables
+
+        return data, settings, targets, sources
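A sketch of how a crashed analysis would be resumed from the files written above; the checkpoint base name is a placeholder, and the returned settings already carry the previously selected variables in 'add_conditionals':

    from idtxl.multivariate_te import MultivariateTE

    network_analysis = MultivariateTE()
    data, settings, targets, sources = network_analysis.resume_checkpoint(
        './my_analysis')  # base name without the .ckp extension
    results = network_analysis.analyse_network(
        settings=settings, data=data, targets=targets, sources=sources)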
        # Main comparison.
        print('\n-------------------------- (1) create union of networks')
+        network_all = np.hstack((network_set_a, network_set_b))
        self._create_union(*network_all)
        print('\n-------------------------- (2) calculate differences in TE '
              'values')
@@ -504,7 +505,7 @@
Source code for idtxl.network_comparison
        # Compare raw TE values between conditions.
        self.cmi_comp = self._compare_union_cmi_within(cmi_a, cmi_b)

-    def _calculate_cmi_diff_between(self, data_set_a, data_set_b):
+    def _calculate_cmi_diff_between(self):
        """Calculate the difference in CMI between two groups of subjects.

        Calculate the difference in the conditional mutual information (CMI)
@@ -1026,23 +1027,22 @@
"""success=Falseifself.settings['verbose']:
- print('candidate set: {0}'.format(
- self._idx_to_lag(candidate_set)))
+ print('candidate set: {0}'.format(
+ self._idx_to_lag(candidate_set)))whilecandidate_set:# Get realisations for all candidates.cand_real=data.get_realisations(self.current_value,
@@ -167,9 +167,9 @@
Source code for idtxl.network_inference
                    # The algorithm cannot continue here, so we'll terminate
                    # the search for more candidates, though those identified
                    # already remain valid
-                    print('AlgorithmExhaustedError encountered in '
-                          'estimations: ' + aee.message)
-                    print('Halting current estimation set.')
+                    print(
+                        'AlgorithmExhaustedError encountered in estimations: '
+                        '{}. Halting current estimation set.'.format(
+                            aee.message))
                    # For now we don't need a stack trace:
                    # traceback.print_tb(aee.__traceback__)
                    break
@@ -181,14 +181,14 @@
Source code for idtxl.network_inference
print('testing candidate: {0} '.format(self._idx_to_lag([max_candidate])[0]),end='')try:
- significant=stats.max_statistic(self,data,candidate_set,
- te_max_candidate)[0]
+ significant=stats.max_statistic(
+ self,data,candidate_set,te_max_candidate)[0]exceptex.AlgorithmExhaustedErrorasaee:
- # The algorithm cannot continue here, so
- # we'll terminate the check of significance for this candidate,
- # though those identified already remain valid
+ # The algorithm cannot continue here, so we'll terminate the
+ # check of significance for this candidate, though those
+ # identified already remain validprint('AlgorithmExhaustedError encountered in '
- 'estimations: '+aee.message)
+ 'estimations: '+aee.message)print('Halting candidate max stats test')# For now we don't need a stack trace:# traceback.print_tb(aee.__traceback__)
@@ -206,6 +206,8 @@
returnsuccessdef_force_conditionals(self,cond,data):
- """Enforce a given conditioning set."""
+ """Enforce a given conditioning set.
+
+ Manually add variables to the conditioning set before analysis. Added
+ variables are not tested in the inclusion step of the algorithm, but
+ are tested in the pruning step and may be removed there. Source and
+ target past and current variables can be included.
+
+ Args:
+ cond : str | dict | list | tuple
+ variables added to the conditioning set, 'faes' adds all source
+ variables with zero-lag to condition out shared information due
+ to instantaneous mixing, a dict can contain a list of variables
+ for each target ({target ind: [(source ind, lag), ...]}), a list
+ of the same variables added for each target ([(source ind, lag),
+ ...]), a tuple with a single variable that is added for each
+ target
+ data : Data instance
+ input data
+ """iftype(cond)isstr:# Get realisations and indices of source variables with lag 0. Note# that _define_candidates returns tuples with absolute indices and
@@ -229,6 +249,11 @@
Source code for idtxl.network_inference
        # lags to absolute sample indices and add variables.
        if type(cond) is tuple:  # easily add single variable
            cond = [cond]
+        elif type(cond) is dict:  # add conditioning variables per target
+            try:
+                cond = cond[self.target]
+            except KeyError:
+                return  # no additional variables for the current target
        print('Adding the following variables to the conditioning set: '
              '{0}.'.format(cond))
        cond_idx = self._lag_to_idx(cond)
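A sketch of the 'add_conditionals' formats accepted by _force_conditionals(), as described in its docstring above; the process indices and lags are arbitrary examples:

    settings = dict()

    # Single variable, added for every target: (process index, lag)
    settings['add_conditionals'] = (3, 1)

    # Same list of variables added for every target
    settings['add_conditionals'] = [(3, 1), (0, 2)]

    # Per-target conditioning sets, as used when resuming from a checkpoint
    settings['add_conditionals'] = {0: [(1, 2)], 4: [(0, 1), (2, 3)]}

    # Condition out instantaneous mixing via all zero-lag source variables
    settings['add_conditionals'] = 'faes'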
@@ -323,6 +348,11 @@
Source code for idtxl.network_inference
# user. This tests if there is sufficient data to do all tests.# surrogates.check_permutations(self, data)
+ # Check and set defaults for checkpointing. If requested, initialise
+ # checkpointing.
+ self.settings=self._set_checkpointing_defaults(
+ self.settings,data,sources,target)
+
        # Reset all attributes to initial values if the instance of
        # MultivariateTE has been used before.
        if self.selected_vars_full:
@@ -441,6 +471,11 @@
Source code for idtxl.network_inference
# user. This tests if there is sufficient data to do all tests.# surrogates.check_permutations(self, data)
+ # Check and set defaults for checkpointing. If requested, initialise
+ # checkpointing.
+ self.settings=self._set_checkpointing_defaults(
+ self.settings,data,sources,target)
+
        # Reset all attributes to initial values if the instance of
        # MultivariateTE has been used before.
        if self.selected_vars_full:
@@ -538,11 +573,12 @@
Source code for idtxl.network_inference
forsourceinself.source_set:candidate_set=self._define_candidates([source],samples)ifself.settings['verbose']:
- print('candidate set current source: {0}\n'.format(
- self._idx_to_lag(candidate_set)),end='')
+ print('candidate set current source: {0}\n'.format(
+ self._idx_to_lag(candidate_set)),end='')# Initialise conditional realisations. This gets updated if sources
- # are selected in the iterative conditioning.
+ # are selected in the iterative conditioning. For MI calculation
+ # this is None.conditional_realisations=conditional_realisations_targetwhilecandidate_set:
@@ -566,7 +602,7 @@
Source code for idtxl.network_inference
# we'll terminate the search for more candidates,# though those identified already remain validprint('AlgorithmExhaustedError encountered in '
- 'estimations: '+aee.message)
+ 'estimations: '+aee.message)print('Halting current estimation set.')# For now we don't need a stack trace:# traceback.print_tb(aee.__traceback__)
@@ -584,10 +620,10 @@
Source code for idtxl.network_inference
te_max_candidate,conditional_realisations)[0]exceptex.AlgorithmExhaustedErrorasaee:# The algorithm cannot continue here, so
- # we'll terminate the significance check for this candidate,
- # though those identified already remain valid
+ # we'll terminate the significance check for this
+ # candidate, though those identified already remain valid.print('AlgorithmExhaustedError encountered in '
- 'estimations: '+aee.message)
+ 'estimations: '+aee.message)print('Halting candidate max stats test')# For now we don't need a stack trace:# traceback.print_tb(aee.__traceback__)
@@ -611,6 +647,9 @@
conditional_realisations_target=(self._selected_vars_target_realisations)cond_target_dim=conditional_realisations_target.shape[1]
+
# Prune all selected sources separately. This way, the conditioning# uses past variables from the current source only (opposed to past# variables from all sources as in multivariate network inference).
@@ -650,7 +690,8 @@
Source code for idtxl.network_inference
[s[0]forsinself.selected_vars_sources])forsourceinsignificant_sources:# Find selected past variables for current source
- print('selected vars sources {0}'.format(self.selected_vars_sources))
+ print('selected vars sources {0}'.format(
+ self.selected_vars_sources))source_vars=[sforsinself.selected_vars_sourcesifs[0]==source]print('selected candidates current source: {0}'.format(
@@ -660,7 +701,7 @@
Source code for idtxl.network_inference
# maximum statistic for this variable.iflen(source_vars)==1:ifself.settings['verbose']:
- print(' -- significant')
+ print(' -- significant')continue# Find the candidate with the minimum TE/MI into the target.
@@ -712,10 +753,10 @@
Source code for idtxl.network_inference
# The algorithm cannot continue here, so# we'll terminate the pruning check,# assuming that we need not prune any more
- print('AlgorithmExhaustedError encountered in '
- 'estimations: '+aee.message)
- print('Halting current pruning and allowing others to'
- ' remain.')
+ print(
+ 'AlgorithmExhaustedError encountered in estimations: '
+ '{}. Halting current estimation set.'.format(
+ aee.message))# For now we don't need a stack trace:# traceback.print_tb(aee.__traceback__)break
@@ -752,9 +793,9 @@
Source code for idtxl.network_inference
# we'll terminate the pruning check,# assuming that we need not prune any moreprint('AlgorithmExhaustedError encountered in '
- 'estimations: '+aee.message)
+ 'estimations: '+aee.message)print('Halting current pruning and allowing others to'
- ' remain.')
+ ' remain.')# For now we don't need a stack trace:# traceback.print_tb(aee.__traceback__)break
@@ -768,6 +809,8 @@
# The algorithm cannot continue here, so# we'll set the results to zeroprint('AlgorithmExhaustedError encountered in '
- 'estimations: '+aee.message)
+ 'estimations: '+aee.message)print('Halting omnibus test and setting to not significant.')# For now we don't need a stack trace:# traceback.print_tb(aee.__traceback__)
@@ -834,7 +877,7 @@
Source code for idtxl.network_inference
# it seems ok to let everything through still but# just write a 0 for final valuesprint('AlgorithmExhaustedError encountered in '
- 'final_conditional estimations: '+aee.message)
+ 'final_conditional estimations: '+aee.message)print('Halting final_conditional estimations')# For now we don't need a stack trace:# traceback.print_tb(aee.__traceback__)
@@ -939,9 +982,9 @@
Source code for idtxl.network_inference
# we'll terminate the pruning check,# assuming that we need not prune any moreprint('AlgorithmExhaustedError encountered in '
- 'estimations: '+aee.message)
+ 'estimations: '+aee.message)print('Halting current pruning and allowing others to'
- ' remain.')
+ ' remain.')# For now we don't need a stack trace:# traceback.print_tb(aee.__traceback__)break
@@ -970,9 +1013,9 @@
Source code for idtxl.network_inference
# we'll terminate the pruning check,# assuming that we need not prune any moreprint('AlgorithmExhaustedError encountered in '
- 'estimations: '+aee.message)
+ 'estimations: '+aee.message)print('Halting current pruning and allowing others to'
- ' remain.')
+ ' remain.')# For now we don't need a stack trace:# traceback.print_tb(aee.__traceback__)break
@@ -986,6 +1029,8 @@
exceptex.AlgorithmExhaustedErrorasaee:# The algorithm cannot continue here, so# we'll set the results to zero
- print('AlgorithmExhaustedError encountered in '
- 'estimations: '+aee.message)
- print('Halting omnibus test and setting to not significant.')
+ print(
+ 'AlgorithmExhaustedError encountered in estimations: {}. '
+ 'Halting current estimation set.'.format(aee.message))# For now we don't need a stack trace:# traceback.print_tb(aee.__traceback__)stat=0
@@ -1035,10 +1080,6 @@
Source code for idtxl.network_inference
self.statistic_sign_sources=stat# Calculate TE for all links in the network. Calculate local TE# if requested by the user.
- ifself.measure=='te':
- conditioning='target'
- elifself.measure=='mi':
- conditioning='none'try:self.statistic_single_link=self._calculate_single_link(data=data,
@@ -1046,7 +1087,7 @@
Source code for idtxl.network_inference
source_vars=self.selected_vars_sources,target_vars=self.selected_vars_target,sources='all',
- conditioning=conditioning)
+ conditioning='full')exceptex.AlgorithmExhaustedErrorasaee:# The algorithm cannot continue here, so# we'll terminate the computation of single link stats.
@@ -1054,7 +1095,7 @@
Source code for idtxl.network_inference
# it seems ok to let everything through still but# just write a 0 for final valuesprint('AlgorithmExhaustedError encountered in '
- 'final_conditional estimations: '+aee.message)
+ 'final_conditional estimations: '+aee.message)print('Halting final_conditional estimations')# For now we don't need a stack trace:# traceback.print_tb(aee.__traceback__)
@@ -1068,23 +1109,22 @@
MIN_INT = -sys.maxsize - 1  # minimum integer for initializing adj. matrix

-[docs]class DotDict(dict):
+class DotDict(dict):
    """Dictionary with dot-notation access to values.

    Provides the same functionality as a regular dict, but also allows
@@ -95,42 +95,42 @@
Source code for idtxl.results
def__setstate__(self,state):# For un-pickling the object
- self.update(state)
+ self.update(state)# self.__dict__ = self
-[docs]class AdjacencyMatrix():
+class AdjacencyMatrix():
    """Adjacency matrix representing inferred networks."""
    def __init__(self, n_nodes, weight_type):
        self._edge_matrix = np.zeros((n_nodes, n_nodes), dtype=bool)
        self._weight_matrix = np.zeros((n_nodes, n_nodes), dtype=weight_type)
        if np.issubdtype(weight_type, np.integer):
            self._weight_type = np.integer
-        elif np.issubdtype(weight_type, np.float):
-            self._weight_type = np.float
+        elif np.issubdtype(weight_type, np.floating):
+            self._weight_type = np.floating
        elif weight_type is bool:
            self._weight_type = weight_type
        else:
            raise RuntimeError(
                'Unknown weight data type {0}.'.format(weight_type))

-[docs]    def n_nodes(self):
+    def n_nodes(self):
        """Return number of nodes."""
        return self._edge_matrix.shape[0]

-[docs]    def add_edge(self, i, j, weight):
+    def add_edge(self, i, j, weight):
        """Add weighted edge (i, j) to adjacency matrix."""
        if not np.issubdtype(type(weight), self._weight_type):
            raise TypeError(
                'Can not add weight of type {0} to adjacency matrix of type '
                '{1}.'.format(type(weight), self._weight_type))
        self._edge_matrix[i, j] = True
        self._weight_matrix[i, j] = weight

-[docs]class Results():
+class Results():
    """Parent class for results of network analysis algorithms.

    Provide a container for results of network analysis algorithms, e.g.,
@@ -229,7 +229,7 @@
Source code for idtxl.results
else:returnFalse
-
[docs]defcombine_results(self,*results):
+ defcombine_results(self,*results):"""Combine multiple (partial) results objects. Combine a list of partial network analysis results into a single
@@ -282,8 +282,7 @@
Source code for idtxl.results
            raise AttributeError('Did not find any method attributes to '
                                 'combine (.single_process or '
                                 '._single_target).')
-            self._add_single_result(p, results_to_add, r.settings)

-[docs]    def get_single_target(self, target, fdr=True):
+    def get_single_target(self, target, fdr=True):
        """Return results for a single target in the network.

        Return results for individual processes, contains for each process
@@ -523,9 +522,9 @@
Source code for idtxl.results
raiseRuntimeError('No results have been added.')exceptKeyError:raiseRuntimeError(
- 'No results for target {0}.'.format(target))
+ 'No results for target {0}.'.format(target))
-
[docs]defget_target_sources(self,target,fdr=True):
+ defget_target_sources(self,target,fdr=True):"""Return list of sources (parents) for given target. Args:
@@ -536,7 +535,7 @@
-[docs]class ResultsPartialInformationDecomposition(ResultsNetworkAnalysis):
+class ResultsPID(ResultsNetworkAnalysis):
    """Store results of Partial Information Decomposition (PID) analysis.

    Provide a container for results of Partial Information Decomposition (PID)
@@ -795,7 +794,7 @@
[docs]defget_single_target(self,target):
+ defget_single_target(self,target):"""Return results for a single target in the network. Results for single targets include for each target
@@ -823,8 +822,78 @@
Source code for idtxl.results
(result['selected_vars_sources']) or via dot-notation (result.selected_vars_sources). """
- returnsuper(ResultsPartialInformationDecomposition,
- self).get_single_target(target,fdr=False)
+ returnsuper(ResultsPID,
+ self).get_single_target(target,fdr=False)
+
+
+classResultsMultivariatePID(ResultsNetworkAnalysis):
+ """Store results of Multivariate Partial Information Decomposition (PID)
+analysis.
+
+ Provide a container for results of Multivariate Partial Information
+ Decomposition (PID) algorithms.
+
+ Note that for convenience all dictionaries in this class can additionally
+ be accessed using dot-notation:
+
+ >>> res_pid._single_target[2].source_1
+
+ or
+
+    >>> res_pid._single_target[2]['source_1']
+
+ Attributes:
+ settings : dict
+ settings used for estimation of information theoretic measures and
+ statistical testing
+ data_properties : dict
+ data properties, contains
+
+ - n_nodes : int - total number of nodes in the network
+ - n_realisations : int - number of samples available for
+ analysis given the settings (e.g., a high maximum lag used in
+ network inference, results in fewer data points available for
+ estimation)
+ - normalised : bool - indicates if data were z-standardised
+ before the estimation
+
+ targets_analysed : list
+ list of analysed targets
+ """
+
+ def__init__(self,n_nodes,n_realisations,normalised):
+ super().__init__(n_nodes,n_realisations,normalised)
+
+ defget_single_target(self,target):
+ """Return results for a single target in the network.
+
+ Results for single targets include for each target
+
+ - source_i : tuple - source variable i
+ - selected_vars_sources : list of tuples - source variables used in PID
+ estimation
+ - avg : dict - avg pid {alpha -> float} where alpha is a redundancy
+ lattice node
+ - ptw : dict of dicts - ptw pid {rlz -> {alpha -> float} } where rlz is
+ a single realisation of the random variables and alpha is a redundancy
+ lattice node
+ - current_value : tuple - current value used for analysis, described by
+ target and sample index in the data
+ - [estimator-specific settings]
+
+ Args:
+ target : int
+ target id
+
+ Returns:
+ dict
+ Results for single target. Note that for convenience
+ dictionary entries can either be accessed via keywords
+ (result['selected_vars_sources']) or via dot-notation
+ (result.selected_vars_sources).
+ """
+ returnsuper(ResultsMultivariatePID,
+ self).get_single_target(target,fdr=False)
            transfer entropy value to be tested
        conditional : numpy array [optional]
            realisations of conditional, 2D numpy array where array dimensions
-            represent [realisations x variable dimension] (per default all
-            already selected source and target variables from the
-            analysis_setup are used)
+            represent [realisations x variable dimension] (default=None, no
+            conditioning performed)

    Returns:
        bool
@@ -481,7 +480,8 @@
Source code for idtxl.stats
Raises: ex.AlgorithmExhaustedError
- Raised from _create_surrogate_table() when calculation cannot be made
+ Raised from _create_surrogate_table() when calculation cannot be
+ made """# Set defaults and get parameters from settings dictionaryanalysis_setup.settings.setdefault('n_perm_max_stat',200)
@@ -494,6 +494,7 @@
numpy array, float TE values for individual sources """
- try:
- n_permutations=analysis_setup.settings['n_perm_max_seq']
- exceptKeyError:
- try:# use the same n_perm as for min_stats if surr table is reused
- n_permutations=analysis_setup._min_stats_surr_table.shape[1]
- analysis_setup.settings['n_perm_max_seq']=n_permutations
- exceptAttributeError:# is surr table is None, use default
- analysis_setup.settings['n_perm_max_seq']=500
- n_permutations=analysis_setup.settings['n_perm_max_seq']
+    # Set defaults and get test parameters.
+    analysis_setup.settings.setdefault('n_perm_max_seq', 500)
+    n_permutations = analysis_setup.settings['n_perm_max_seq']
    analysis_setup.settings.setdefault('alpha_max_seq', 0.05)
    alpha = analysis_setup.settings['alpha_max_seq']
    _check_permute_in_time(analysis_setup, data, n_permutations)
+    permute_in_time = analysis_setup.settings['permute_in_time']
+
    if analysis_setup.settings['verbose']:
-        print('sequential maximum statistic, n_perm: {0}'.format(
-            n_permutations))
+        print('sequential maximum statistic, n_perm: {0}, testing {1} '
+              'selected sources'.format(
+                  n_permutations,
+                  len(analysis_setup.selected_vars_sources)))
    assert analysis_setup.selected_vars_sources, 'No sources to test.'
@@ -584,29 +582,73 @@
Source code for idtxl.stats
# Calculate TE for each candidate in the conditional source set, i.e.,# calculate the conditional MI between each candidate and the current
- # value, conditional on all selected variables in the conditioning set.
- # Then sort the estimated TE values.
+    # value, conditional on all selected variables in the conditioning set,
+    # excluding the current source. Calculate surrogates for each candidate
+    # by shuffling the candidate realisations n_perm times. Afterwards, sort
+    # the estimated TE values.
    i_1 = 0
    i_2 = data.n_realisations(analysis_setup.current_value)
+    surr_table = np.zeros((len(analysis_setup.selected_vars_sources),
+                           n_permutations))
    # Collect data for each candidate and the corresponding conditioning set.
-    for candidate in analysis_setup.selected_vars_sources:
-        [temp_cond, temp_cand] = analysis_setup._separate_realisations(
-            idx_conditional,
-            candidate)
+    # Use realisations for parallel estimation of the test statistic later.
+    for idx_c, candidate in enumerate(analysis_setup.selected_vars_sources):
+        [conditional_realisations_current,
+         candidate_realisations_current] = (
+             analysis_setup._separate_realisations(idx_conditional,
+                                                   candidate))
+        # The following may happen if either the requested conditioning is
+        # 'none' or if the conditioning set that is tested consists only of
+        # a single candidate.
-        if temp_cond is None:
+        if conditional_realisations_current is None:
            conditional_realisations = None
            re_use = ['var2', 'conditional']
        else:
-            conditional_realisations[i_1:i_2, ] = temp_cond
+            conditional_realisations[i_1:i_2, ] = (
+                conditional_realisations_current)
            re_use = ['var2']
-        candidate_realisations[i_1:i_2, ] = temp_cand
+        candidate_realisations[i_1:i_2, ] = candidate_realisations_current
        i_1 = i_2
        i_2 += data.n_realisations(analysis_setup.current_value)
+ # Generate surrogates for the current candidate.
+ if(analysis_setup._cmi_estimator.is_analytic_null_estimator()and
+ permute_in_time):
+ # Generate the surrogates analytically
+ surr_table[idx_c,:]=(
+ analysis_setup._cmi_estimator.estimate_surrogates_analytic(
+ n_perm=n_permutations,
+ var1=data.get_realisations(analysis_setup.current_value,
+ [candidate])[0],
+ var2=analysis_setup._current_value_realisations,
+ conditional=conditional_realisations_current))
+ else:
+ analysis_setup.settings['analytical_surrogates']=False
+ surr_candidate_realisations=_get_surrogates(
+ data,
+ analysis_setup.current_value,
+ [candidate],
+ n_permutations,
+ analysis_setup.settings)
+ try:
+ surr_table[idx_c,:]=(
+ analysis_setup._cmi_estimator.estimate_parallel(
+ n_chunks=n_permutations,
+ re_use=['var2','conditional'],
+ var1=surr_candidate_realisations,
+ var2=analysis_setup._current_value_realisations,
+ conditional=conditional_realisations_current))
+            except ex.AlgorithmExhaustedError as aee:
+                # The algorithm cannot continue here, so we'll terminate the
+                # max sequential stats test and declare all not significant.
+                print('AlgorithmExhaustedError encountered in estimations: '
+                      '{}.'.format(aee.message))
+                print('Stopping sequential max stats at candidate with '
+                      'rank 0')
+ return \
+ (np.zeros(len(analysis_setup.selected_vars_sources)).astype(bool),
+ np.ones(len(analysis_setup.selected_vars_sources)),
+ np.zeros(len(analysis_setup.selected_vars_sources)))
+
# Calculate original statistic (multivariate/bivariate TE/MI)try:individual_stat=analysis_setup._cmi_estimator.estimate_parallel(
@@ -619,48 +661,19 @@
Source code for idtxl.stats
# The aglorithm cannot continue here, so# we'll terminate the max sequential stats test,# and declare all not significant
- print('AlgorithmExhaustedError encountered in '
- 'estimations: '+aee.message)
+ print('AlgorithmExhaustedError encountered in estimations: {}.'.format(
+ aee.message))print('Stopping sequential max stats at candidate with rank 0')# For now we don't need a stack trace:# traceback.print_tb(aee.__traceback__)# Return (signficance, pvalue, TEs):return \
(np.zeros(len(analysis_setup.selected_vars_sources)).astype(bool),
- np.ones(len(analysis_setup.selected_vars_sources)),
- np.zeros(len(analysis_setup.selected_vars_sources)))
+ np.ones(len(analysis_setup.selected_vars_sources)),
+ np.zeros(len(analysis_setup.selected_vars_sources)))selected_vars_order=utils.argsort_descending(individual_stat)individual_stat_sorted=utils.sort_descending(individual_stat)
-
- # Re-use surrogate table from previous pruning using min stats, if it
- # already exists. This saves some time. Otherwise create surrogate table.
- # Sort surrogate table.
- if(analysis_setup._min_stats_surr_tableisnotNoneand
- n_permutations<=analysis_setup._min_stats_surr_table.shape[1]):
- surr_table=analysis_setup._min_stats_surr_table[:,:n_permutations]
- assertlen(analysis_setup.selected_vars_sources)==surr_table.shape[0]
- else:
- try:
- surr_table=_create_surrogate_table(
- analysis_setup=analysis_setup,
- data=data,
- idx_test_set=analysis_setup.selected_vars_sources,
- n_perm=n_permutations)
- exceptex.AlgorithmExhaustedErrorasaee:
- # The aglorithm cannot continue here, so
- # we'll terminate the max sequential stats test,
- # and declare all not significant
- print('AlgorithmExhaustedError encountered in '
- 'estimations: '+aee.message)
- print('Stopping sequential max stats at candidate with rank 0')
- # For now we don't need a stack trace:
- # traceback.print_tb(aee.__traceback__)
- # Return (signficance, pvalue, TEs):
- return \
- (np.zeros(len(analysis_setup.selected_vars_sources)).astype(bool),
- np.ones(len(analysis_setup.selected_vars_sources)),
- np.zeros(len(analysis_setup.selected_vars_sources)))max_distribution=_sort_table_max(surr_table)# Compare each original value with the distribution of the same rank,
@@ -735,21 +748,18 @@
Source code for idtxl.stats
numpy array, float TE values for individual sources """
- try:
- n_permutations=analysis_setup.settings['n_perm_max_seq']
- except KeyError:
- try:  # use the same n_perm as for min_stats if surr table is reused
- n_permutations = analysis_setup._min_stats_surr_table.shape[1]
- analysis_setup.settings['n_perm_max_seq'] = n_permutations
- except AttributeError:  # if surr table is None, use default
- analysis_setup.settings['n_perm_max_seq']=500
- n_permutations=analysis_setup.settings['n_perm_max_seq']
+ # Set defaults and get test parameters.
+ analysis_setup.settings.setdefault('n_perm_max_seq',500)
+ n_permutations = analysis_setup.settings['n_perm_max_seq']
analysis_setup.settings.setdefault('alpha_max_seq', 0.05)
alpha = analysis_setup.settings['alpha_max_seq']
_check_permute_in_time(analysis_setup, data, n_permutations)
+ permute_in_time=analysis_setup.settings['permute_in_time']
+
if analysis_setup.settings['verbose']:
- print('sequential maximum statistic, n_perm: {0}'.format(
- n_permutations))
+ print('sequential maximum statistic, n_perm: {0}, testing {1} selected'
+ ' sources'.format(n_permutations,
+ len(analysis_setup.selected_vars_sources)))
assert analysis_setup.selected_vars_sources, 'No sources to test.'
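A short sketch of the settings-default pattern used in this hunk: dict.setdefault() writes the default back into the settings dictionary, so the values actually used in the test remain recorded there.

settings = {'verbose': True}
settings.setdefault('n_perm_max_seq', 500)   # only set if the key is missing
settings.setdefault('alpha_max_seq', 0.05)
print(settings['n_perm_max_seq'], settings['alpha_max_seq'])  # 500 0.05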
@@ -792,16 +802,17 @@
Source code for idtxl.stats
# conditioning set. Then sort the estimated TE/MI values.
i_1 = 0
i_2 = data.n_realisations(analysis_setup.current_value)
+ surr_table = np.zeros((len(source_vars), n_permutations))
# Collect data for each candidate and the corresponding conditioning set.
- for candidate in source_vars:
+ for idx_c, candidate in enumerate(source_vars):
temp_cond = data.get_realisations(
- analysis_setup.current_value,
- set(source_vars).difference(set([candidate])))[0]
+ analysis_setup.current_value,
+ set(source_vars).difference(set([candidate])))[0]
temp_cand = data.get_realisations(
- analysis_setup.current_value,[candidate])[0]
- # The following may happen if either the requested conditing is 'none'
- # or if the conditiong set that is tested consists only of a single
- # candidate.
+ analysis_setup.current_value,[candidate])[0]
+ # The following may happen if either the requested conditioning is
+ # 'none' or if the conditioning set that is tested consists only of
+ # a single candidate.
if temp_cond is None:
conditional_realisations = conditional_realisations_target
re_use = ['var2', 'conditional']
@@ -816,6 +827,45 @@
Source code for idtxl.stats
i_1 = i_2
i_2 += data.n_realisations(analysis_setup.current_value)
+ # Generate surrogates for the current candidate.
+ if (analysis_setup._cmi_estimator.is_analytic_null_estimator() and
+ permute_in_time):
+ # Generate the surrogates analytically
+ surr_table[idx_c,:]=(
+ analysis_setup._cmi_estimator.estimate_surrogates_analytic(
+ n_perm=n_permutations,
+ var1=data.get_realisations(analysis_setup.current_value,
+ [candidate])[0],
+ var2=analysis_setup._current_value_realisations,
+ conditional=temp_cond))
+ else:
+ analysis_setup.settings['analytical_surrogates']=False
+ surr_candidate_realisations=_get_surrogates(
+ data,
+ analysis_setup.current_value,
+ [candidate],
+ n_permutations,
+ analysis_setup.settings)
+ try:
+ surr_table[idx_c,:]=(
+ analysis_setup._cmi_estimator.estimate_parallel(
+ n_chunks=n_permutations,
+ re_use=['var2','conditional'],
+ var1=surr_candidate_realisations,
+ var2=analysis_setup._current_value_realisations,
+ conditional=temp_cond))
+ except ex.AlgorithmExhaustedError as aee:
+ # The algorithm cannot continue here, so
+ # we'll terminate the max sequential stats test,
+ # and declare all not significant
+ print('AlgorithmExhaustedError encountered in estimations: {}.'.format(
+ aee.message))
+ print('Stopping sequential max stats at candidate with rank 0')
+ return \
+ (np.zeros(len(analysis_setup.selected_vars_sources)).astype(bool),
+ np.ones(len(analysis_setup.selected_vars_sources)),
+ np.zeros(len(analysis_setup.selected_vars_sources)))
+
# Calculate original statistic (multivariate/bivariate TE/MI)
try:
individual_stat = analysis_setup._cmi_estimator.estimate_parallel(
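A sketch of the branch added in this hunk: estimators that can derive the null distribution analytically skip the permutation loop, everything else gets explicit surrogate data. The estimator method names are taken from the diff; the chunk-stacking helper is a simplified stand-in for the library's surrogate generation.

import numpy as np

def _stack_permutations(var1, n_perm, rng=None):
    # Simplified surrogate generation: one permuted copy per chunk.
    rng = np.random.default_rng() if rng is None else rng
    return np.vstack([rng.permutation(var1) for _ in range(n_perm)])

def surrogate_distribution(estimator, var1, var2, conditional, n_perm,
                           permute_in_time):
    if estimator.is_analytic_null_estimator() and permute_in_time:
        return estimator.estimate_surrogates_analytic(
            n_perm=n_perm, var1=var1, var2=var2, conditional=conditional)
    surrogates = _stack_permutations(var1, n_perm)
    return estimator.estimate_parallel(
        n_chunks=n_perm, re_use=['var2', 'conditional'],
        var1=surrogates, var2=var2, conditional=conditional)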
@@ -829,52 +879,18 @@
Source code for idtxl.stats
# we'll terminate the max sequential stats test,
# and declare all not significant
print('AlgorithmExhaustedError encountered in '
- 'estimations: '+aee.message)
+ 'estimations: {}.'.format(aee.message))
print('Stopping sequential max stats at candidate with rank 0')
# For now we don't need a stack trace:
# traceback.print_tb(aee.__traceback__)
# Return (significance, pvalue, TEs):
- return \
- (np.zeros(len(analysis_setup.selected_vars_sources)).astype(bool),
+ return (
+ np.zeros(len(analysis_setup.selected_vars_sources)).astype(bool),
np.ones(len(analysis_setup.selected_vars_sources)),
np.zeros(len(analysis_setup.selected_vars_sources)))
selected_vars_order = utils.argsort_descending(individual_stat)
individual_stat_sorted = utils.sort_descending(individual_stat)
-
- # Don't re-use surrogate table from previous pruning using min stats
- # like for the multivariate algorithm. There is no longer a global
- # min_stats including all sources variables, but a separate table per
- # source.
- conditional_realisations_sources=data.get_realisations(
- analysis_setup.current_value,source_vars)[0]
- if conditional_realisations_target is None:
- conditional_realisations=conditional_realisations_sources
- else:
- conditional_realisations=np.hstack((
- conditional_realisations_sources,
- conditional_realisations_target))
- try:
- surr_table=_create_surrogate_table(
- analysis_setup=analysis_setup,
- data=data,
- idx_test_set=analysis_setup.selected_vars_sources,
- n_perm=n_permutations,
- conditional=conditional_realisations)
- except ex.AlgorithmExhaustedError as aee:
- # The algorithm cannot continue here, so
- # we'll terminate the max sequential stats test,
- # and declare all not significant
- print('AlgorithmExhaustedError encountered in '
- 'estimations: '+aee.message)
- print('Stopping sequential max stats at candidate with rank 0')
- # For now we don't need a stack trace:
- # traceback.print_tb(aee.__traceback__)
- # Return (signficance, pvalue, TEs):
- return \
- (np.zeros(len(analysis_setup.selected_vars_sources)).astype(bool),
- np.ones(len(analysis_setup.selected_vars_sources)),
- np.zeros(len(analysis_setup.selected_vars_sources)))
max_distribution = _sort_table_max(surr_table)
# Compare each original value with the distribution of the same rank,
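A hedged sketch of the rank-wise comparison referenced above (this is not the library's exact _sort_table_max implementation, only one plausible reading of it): sort each surrogate column in descending order and compare the k-th largest original statistic against the distribution of k-th largest surrogate values.

import numpy as np

def sequential_max_pvalues(individual_stat, surr_table):
    individual_stat = np.asarray(individual_stat)
    order = np.argsort(individual_stat)[::-1]        # candidate ranks, largest first
    stats_sorted = individual_stat[order]
    max_dist = np.sort(surr_table, axis=0)[::-1, :]  # each permutation column, descending
    pvalues = np.array([(max_dist[k, :] >= stat).mean()
                        for k, stat in enumerate(stats_sorted)])
    return order, pvalues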
@@ -930,9 +946,8 @@
Source code for idtxl.stats
transfer entropy value to be tested
conditional : numpy array [optional]
realisations of conditional, 2D numpy array where array dimensions
- represent [realisations x variable dimension] (per default all
- already selected source and target variables from the
- analysis_setup are used)
+ represent [realisations x variable dimension] (default=None, no
+ conditioning performed)

Returns:
bool
@@ -944,7 +959,8 @@
Source code for idtxl.stats
Raises:
ex.AlgorithmExhaustedError
- Raised from _create_surrogate_table() when calculation cannot be made
+ Raised from _create_surrogate_table() when calculation cannot be
+ made """# Set defaults and get parameters from settings dictionaryanalysis_setup.settings.setdefault('n_perm_min_stat',500)
@@ -1148,7 +1164,7 @@
Source code for idtxl.stats
i_1 = 0
i_2 = chunk_size
if analysis_setup.settings['verbose']:
- print('\nTesting unq information in s1')
+ print('\nTesting unq information in s1')
for p in range(n_perm):
if analysis_setup.settings['verbose']:
print('\tperm {0} of {1}'.format(p, n_perm))
@@ -1174,7 +1190,7 @@
Source code for idtxl.stats
i_1 = 0
i_2 = chunk_size
if analysis_setup.settings['verbose']:
- print('\nTesting unq information in s2')
+ print('\nTesting unq information in s2')
for p in range(n_perm):
if analysis_setup.settings['verbose']:
print('\tperm {0} of {1}'.format(p, n_perm))
@@ -1270,7 +1286,7 @@
Source code for idtxl.stats
i_1 = 0
i_2 = chunk_size
if analysis_setup.settings['verbose']:
- print('\nTesting shd and syn information in both sources')
+ print('\nTesting shd and syn information in both sources')
for p in range(n_perm):
if analysis_setup.settings['verbose']:
print('\tperm {0} of {1}'.format(p, n_perm))
@@ -1336,9 +1352,8 @@
Source code for idtxl.stats
number of permutations for testing
conditional : numpy array [optional]
realisations of conditional, 2D numpy array where array dimensions
- represent [realisations x variable dimension] (per default all
- already selected source and target variables from the
- analysis_setup are used)
+ represent [realisations x variable dimension] (default=None, no
+ conditioning performed)

Returns:
numpy array
surrogate MI/CMI/TE values, dimensions: (length test set, number of
@@ -1351,22 +1366,11 @@
Source code for idtxl.stats
# Check which permutation type is requested by the calling function.
permute_in_time = analysis_setup.settings['permute_in_time']
- # Check what type of conditioning is requested.
- if conditional is None:
- conditional=analysis_setup._selected_vars_realisations
-
# Create surrogate table.
- # if analysis_setup.settings['verbose']:
- # print('\ncreating surrogate table with {0} permutations:'.format(
- # n_perm))
- # print('\tcand.', end='')
surr_table = np.zeros((len(idx_test_set), n_perm))
current_value_realisations = analysis_setup._current_value_realisations
idx_c = 0
for candidate in idx_test_set:
- # if analysis_setup.settings['verbose']:
- # print('\t{0}'.format(analysis_setup._idx_to_lag([candidate])[0]),
- # end='')
if (analysis_setup._cmi_estimator.is_analytic_null_estimator() and
permute_in_time):
# Generate the surrogates analytically
@@ -1623,23 +1627,22 @@
# Adjust color and position of nodes (variables).
pos = nx.spring_layout(graph)
color = ['lavender' for c in range(graph.number_of_nodes())]
- for (ind, n) in enumerate(graph.node):
+ for (ind, n) in enumerate(graph.nodes):
# Adjust positions of nodes.
if n == current_value:
@@ -176,8 +176,7 @@
Source code for idtxl.visualise_graph
fig = plt.figure()
nx.draw(graph, pos=pos, with_labels=True, font_weight='bold',
- node_size=900,alpha=0.7,node_shape='s',node_color=color,
- hold=True)
+ node_size=900, alpha=0.7, node_shape='s', node_color=color)
# Optionally display edge labels showing the TE value
if display_edge_labels:
edge_labels = nx.get_edge_attributes(graph, 'te')
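A standalone sketch of the updated plotting calls in this hunk: iterating graph.nodes (graph.node was removed in newer networkx) and calling nx.draw() without the deprecated hold argument. The toy graph and edge labels are illustrative only.

import matplotlib.pyplot as plt
import networkx as nx

graph = nx.DiGraph()
graph.add_edge('source 0', 'target', te=0.25)       # hypothetical edge with a TE attribute
pos = nx.spring_layout(graph, seed=1)
color = ['lavender' for _ in range(graph.number_of_nodes())]
nx.draw(graph, pos=pos, with_labels=True, font_weight='bold',
        node_size=900, alpha=0.7, node_shape='s', node_color=color)
plt.show()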
@@ -196,7 +195,7 @@
return cbar
+def _plot_adj_matrix(adj_matrix, mat_color='gray_r', diverging=False,
+ cbar_label='delay', cbar_stepsize=1):
+ """Plot adjacency matrix."""
+ # Plot matrix, set minimum and maximum values to the same value for
+ # diverging plots to center colormap at 0, i.e., 0 is plotted in white
+ # https://stackoverflow.com/questions/25500541/
+ # matplotlib-bwr-colormap-always-centered-on-zero
+ if diverging:
+ max_val=np.max(abs(adj_matrix))
+ min_val=-max_val
+ else:
+ max_val=np.max(adj_matrix)
+ min_val=-np.min(adj_matrix)
+ plt.imshow(adj_matrix,cmap=mat_color,interpolation='nearest',
+ vmin=min_val,vmax=max_val)
+
+ # Set the colorbar and make colorbar match the image in size using the
+ # fraction and pad parameters (see https://stackoverflow.com/a/26720422).
+ if cbar_label == 'delay':
+ cbar_label='delay [samples]'
+ cbar_ticks=np.arange(0,max_val+1,cbar_stepsize)
+ else:
+ cbar_ticks=np.arange(min_val,max_val+0.01*max_val,
+ cbar_stepsize)
+ cbar=plt.colorbar(fraction=0.046,pad=0.04,ticks=cbar_ticks)
+ cbar.set_label(cbar_label,rotation=90)
+
+ # Set x- and y-ticks.
+ plt.xticks(np.arange(adj_matrix.shape[1]))
+ plt.yticks(np.arange(adj_matrix.shape[0]))
+ ax=plt.gca()
+ ax.xaxis.tick_top()
+ return cbar
+
+
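A usage sketch of the colormap-centring trick used by _plot_adj_matrix above: for diverging data, giving imshow symmetric vmin/vmax keeps zero at the centre of the colormap. The example matrix and colormap name are assumptions for illustration.

import matplotlib.pyplot as plt
import numpy as np

adj_matrix = np.array([[0.0, 0.4], [-0.2, 0.0]])  # hypothetical signed weights
max_val = np.max(abs(adj_matrix))
plt.imshow(adj_matrix, cmap='seismic', interpolation='nearest',
           vmin=-max_val, vmax=max_val)           # zero maps to the colormap centre
plt.colorbar(fraction=0.046, pad=0.04, label='weight')
plt.show()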
[docs]def plot_mute_graph():
"""Plot MuTE example network.
@@ -349,23 +383,22 @@