diff --git a/hnn_core/network.py b/hnn_core/network.py index d7e6f34b8..75b57f8c2 100644 --- a/hnn_core/network.py +++ b/hnn_core/network.py @@ -16,7 +16,7 @@ from .cells_default import pyramidal, basket from .cell_response import CellResponse from .params import _long_name, _short_name -from .viz import plot_cells, plot_cell_morphology +from .viz import plot_cells, plot_cell_morphology, plot_connectivity_matrix from .externals.mne import _validate_type, _check_option @@ -1030,6 +1030,7 @@ def add_connection(self, src_gids, target_gids, loc, receptor, raise AssertionError( 'All target_gids must be of the same type') conn['target_type'] = target_type + conn['target_range'] = self.gid_ranges[_long_name(target_type)] conn['num_targets'] = len(target_set) if len(target_gids) != len(src_gids): @@ -1048,6 +1049,7 @@ def add_connection(self, src_gids, target_gids, loc, receptor, raise AssertionError('All src_gids must be of the same type') gid_pairs[src_gid] = target_src_pair conn['src_type'] = src_type + conn['src_range'] = self.gid_ranges[_long_name(src_type)] conn['num_srcs'] = len(src_gids) conn['gid_pairs'] = gid_pairs @@ -1143,6 +1145,10 @@ class _Connectivity(dict): Number of unique source gids. num_targets : int Number of unique target gids. + src_range : range + Range of gids identified by src_type. + target_range : target_range + Range of gids identified by target_type. loc : str Location of synapse on target cell. Must be 'proximal', 'distal', or 'soma'. Note that inhibitory synapses @@ -1163,6 +1169,11 @@ class _Connectivity(dict): probability : float Probability of connection between any src-target pair. Defaults to 1.0 producing an all-to-all pattern. + + Notes + ----- + The len() of src_range or target_range may not match + num_srcs and num_targets for probability < 1.0. """ def __repr__(self): @@ -1230,6 +1241,22 @@ def drop(self, probability): self['probability'] = probability + def plot(self, ax=None, show=True): + """Plot connectivity matrix for instance of _Connectivity object. + + Parameters + ---------- + conn : Instance of _Connectivity object + The _Connectivity object + + Returns + ------- + fig : instance of matplotlib Figure + The matplotlib figure handle. + """ + + return plot_connectivity_matrix(self, ax=ax, show=show) + class _NetworkDrive(dict): """A class for containing the parameters of external drives diff --git a/hnn_core/viz.py b/hnn_core/viz.py index 099d0cc91..76553b50c 100644 --- a/hnn_core/viz.py +++ b/hnn_core/viz.py @@ -593,3 +593,43 @@ def plot_cell_morphology(axes=None, cell_types=None, show=True): plt.tight_layout() plt_show(show) return axes +def plot_connectivity_matrix(conn, ax=None, show=True): + """Plot connectivity matrix for instance of _Connectivity object. + + Parameters + ---------- + conn : Instance of _Connectivity object + The _Connectivity object + + Returns + ------- + fig : instance of matplotlib Figure + The matplotlib figure handle. + + """ + import matplotlib.pyplot as plt + from.network import _Connectivity + + if not isinstance(conn, _Connectivity): + raise TypeError('conn must be instance of _Connectivity') + if ax is None: + _, ax = plt.subplots(1, 1) + + src_range = np.array(conn['src_range']) + target_range = np.array(conn['target_range']) + connectivity_matrix = np.zeros((len(src_range), len(target_range))) + for src_gid, target_src_pair in conn['gid_pairs'].items(): + src_idx = np.where(src_range == src_gid)[0][0] + target_indeces = np.in1d(target_range, target_src_pair) + connectivity_matrix[src_idx, :] = target_indeces + + ax.imshow(connectivity_matrix, cmap='Greys', interpolation='none') + ax.set_xlabel(f'target gids ({target_range[0]}-{target_range[-1]})') + ax.set_xticklabels(list()) + ax.set_ylabel(f'source gids ({src_range[0]}-{src_range[-1]})') + ax.set_yticklabels(list()) + ax.set_title(f"{conn['src_type']} -> {conn['target_type']} " + f"({conn['loc']}, {conn['receptor']})") + + plt_show(show) + return ax.get_figure()