Skip to content

Commit

Permalink
Merge conflict
Browse files Browse the repository at this point in the history
  • Loading branch information
ntolley authored and jasmainak committed May 19, 2021
1 parent d187a58 commit 79e92f1
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 1 deletion.
29 changes: 28 additions & 1 deletion hnn_core/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from .cells_default import pyramidal, basket
from .cell_response import CellResponse
from .params import _long_name, _short_name
from .viz import plot_cells, plot_cell_morphology
from .viz import plot_cells, plot_cell_morphology, plot_connectivity_matrix
from .externals.mne import _validate_type, _check_option


Expand Down Expand Up @@ -1030,6 +1030,7 @@ def add_connection(self, src_gids, target_gids, loc, receptor,
raise AssertionError(
'All target_gids must be of the same type')
conn['target_type'] = target_type
conn['target_range'] = self.gid_ranges[_long_name(target_type)]
conn['num_targets'] = len(target_set)

if len(target_gids) != len(src_gids):
Expand All @@ -1048,6 +1049,7 @@ def add_connection(self, src_gids, target_gids, loc, receptor,
raise AssertionError('All src_gids must be of the same type')
gid_pairs[src_gid] = target_src_pair
conn['src_type'] = src_type
conn['src_range'] = self.gid_ranges[_long_name(src_type)]
conn['num_srcs'] = len(src_gids)

conn['gid_pairs'] = gid_pairs
Expand Down Expand Up @@ -1143,6 +1145,10 @@ class _Connectivity(dict):
Number of unique source gids.
num_targets : int
Number of unique target gids.
src_range : range
Range of gids identified by src_type.
target_range : target_range
Range of gids identified by target_type.
loc : str
Location of synapse on target cell. Must be
'proximal', 'distal', or 'soma'. Note that inhibitory synapses
Expand All @@ -1163,6 +1169,11 @@ class _Connectivity(dict):
probability : float
Probability of connection between any src-target pair.
Defaults to 1.0 producing an all-to-all pattern.
Notes
-----
The len() of src_range or target_range may not match
num_srcs and num_targets for probability < 1.0.
"""

def __repr__(self):
Expand Down Expand Up @@ -1230,6 +1241,22 @@ def drop(self, probability):

self['probability'] = probability

def plot(self, ax=None, show=True):
"""Plot connectivity matrix for instance of _Connectivity object.
Parameters
----------
conn : Instance of _Connectivity object
The _Connectivity object
Returns
-------
fig : instance of matplotlib Figure
The matplotlib figure handle.
"""

return plot_connectivity_matrix(self, ax=ax, show=show)


class _NetworkDrive(dict):
"""A class for containing the parameters of external drives
Expand Down
40 changes: 40 additions & 0 deletions hnn_core/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,3 +593,43 @@ def plot_cell_morphology(axes=None, cell_types=None, show=True):
plt.tight_layout()
plt_show(show)
return axes
def plot_connectivity_matrix(conn, ax=None, show=True):
"""Plot connectivity matrix for instance of _Connectivity object.
Parameters
----------
conn : Instance of _Connectivity object
The _Connectivity object
Returns
-------
fig : instance of matplotlib Figure
The matplotlib figure handle.
"""
import matplotlib.pyplot as plt
from.network import _Connectivity

if not isinstance(conn, _Connectivity):
raise TypeError('conn must be instance of _Connectivity')
if ax is None:
_, ax = plt.subplots(1, 1)

src_range = np.array(conn['src_range'])
target_range = np.array(conn['target_range'])
connectivity_matrix = np.zeros((len(src_range), len(target_range)))
for src_gid, target_src_pair in conn['gid_pairs'].items():
src_idx = np.where(src_range == src_gid)[0][0]
target_indeces = np.in1d(target_range, target_src_pair)
connectivity_matrix[src_idx, :] = target_indeces

ax.imshow(connectivity_matrix, cmap='Greys', interpolation='none')
ax.set_xlabel(f'target gids ({target_range[0]}-{target_range[-1]})')
ax.set_xticklabels(list())
ax.set_ylabel(f'source gids ({src_range[0]}-{src_range[-1]})')
ax.set_yticklabels(list())
ax.set_title(f"{conn['src_type']} -> {conn['target_type']} "
f"({conn['loc']}, {conn['receptor']})")

plt_show(show)
return ax.get_figure()

0 comments on commit 79e92f1

Please sign in to comment.