Skip to content

Commit

Permalink
update plotting functions
Browse files Browse the repository at this point in the history
  • Loading branch information
zm711 committed Oct 3, 2023
1 parent e7c5529 commit 95dc72f
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 41 deletions.
2 changes: 1 addition & 1 deletion docs/source/submodules/intrinsic_plotter.rst
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ lamina of the spinal cord has most of the units found during sorting.

.. code-block:: python
iplotter.plot_spike_dpeth_fr(sp=spikes) # spikes is SpikeData
iplotter.plot_spike_depth_fr(sp=spikes) # spikes is SpikeData
CDF
---
Expand Down
85 changes: 45 additions & 40 deletions src/spikeanalysis/spike_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,14 +102,14 @@ def plot_zscores(
Returns
-------
ordered_cluster_ids: np.array
sorted_cluster_ids: np.array
if indices is True, the function will return the cluster ids as displayed in the z bar graph
"""
if self.cmap == "viridis":
if self.cmap is None:
self.cmap = "vlag"

index_array = self._plot_scores(
sorted_cluster_ids = self._plot_scores(
data="zscore",
figsize=figsize,
sorting_index=sorting_index,
Expand All @@ -118,7 +118,7 @@ def plot_zscores(
show_stim=show_stim,
)
if indices:
return index_array
return sorted_cluster_ids

def plot_raw_firing(
self,
Expand Down Expand Up @@ -150,19 +150,19 @@ def plot_raw_firing(
Returns
-------
ordered_cluster_ids: Optional[np.array]
ordered_cluster_ids: Optional[dict]
if indices is True, the function will return the cluster ids as displayed in the z bar graph
"""
if self.cmap == "vlag":
if self.cmap is None:
self.cmap = "viridis"

index_array = self._plot_scores(
sorted_cluster_ids = self._plot_scores(
data="raw-data", figsize=figsize, sorting_index=sorting_index, bar=bar, indices=indices, show_stim=show_stim
)

if indices:
return index_array
return sorted_cluster_ids

def _plot_scores(
self,
Expand All @@ -187,13 +187,13 @@ def _plot_scores(
bar: list[int]
If given a list with min for the cbar at index 0 and the max at index 1. Overrides cbar generation
indices: bool, default False
If true will return the cluster ids sorted in the order they appear in the graph
If true will return the cluster ids sorted in the order they appear in the graph as a dict of stimuli
show_stim: bool, default True
Show lines where stim onset and offset are
Returns
-------
ordered_cluster_ids: Optional[np.array]
sorted_cluster_ids: Optional[dict]
if indices is True, the function will return the cluster ids as displayed in the z bar graph
"""
Expand All @@ -208,10 +208,7 @@ def _plot_scores(
if figsize is None:
figsize = self.figsize

if self.cmap is None:
cmap = "vlag"
else:
cmap = self.cmap
cmap = self.cmap

if self.y_axis is None:
y_axis = "Units"
Expand All @@ -222,7 +219,7 @@ def _plot_scores(
assert len(bar) == 2, f"Please give z_bar as [min, max], you entered {bar}"

stim_lengths = self._get_event_lengths()

sorted_cluster_ids = {}
for stimulus in z_scores.keys():
if len(np.shape(z_scores)) < 3:
sub_zscores = np.expand_dims(z_scores[stimulus], axis=1)
Expand Down Expand Up @@ -251,17 +248,18 @@ def _plot_scores(
event_window = np.logical_and(bins >= 0, bins <= length)

z_score_sorting_index = np.argsort(-np.sum(sub_zscores[:, sorting_index, event_window], axis=1))

sorted_cluster_ids[stimulus] = self.data.cluster_ids[z_score_sorting_index]
sorted_z_scores = sub_zscores[z_score_sorting_index, :, :]
# nan_mask = np.all(
# np.isnan(sorted_z_scores) | np.equal(sorted_z_scores, 0) | np.isinf(sorted_z_scores), axis=2
# )

# sorted_z_scores = sorted_z_scores[~nan_mask]

if len(np.shape(sorted_z_scores)) == 2:
sorted_z_scores = np.expand_dims(sorted_z_scores, axis=1)

nan_mask = np.all(
np.all(np.isnan(sorted_z_scores) | np.equal(sorted_z_scores, 0) | np.isinf(sorted_z_scores), axis=2),
axis=1,
)
sorted_z_scores = sorted_z_scores[~nan_mask]

if bar is not None:
vmax = bar[1]
vmin = bar[0]
Expand Down Expand Up @@ -320,7 +318,11 @@ def _plot_scores(
]
)
cax.spines["bottom"].set_visible(False)
plt.colorbar(im, cax=cax, label="Z scores") # Similar to fig.colorbar(im, cax = cax)
if data == "zscore":
cbar_label = "Z scores"
else:
cbar_label = "Raw Firing"
plt.colorbar(im, cax=cax, label=cbar_label) # Similar to fig.colorbar(im, cax = cax)
plt.title(f"{stimulus}")
plt.figure(dpi=self.dpi)
plt.show()
Expand All @@ -329,7 +331,7 @@ def _plot_scores(
sorting_index = None

if indices:
return self.data.cluster_ids[z_score_sorting_index]
return sorted_cluster_ids

def plot_raster(self, window: Union[list, list[list]], show_stim: bool = True):
"""
Expand Down Expand Up @@ -556,7 +558,7 @@ def plot_sm_fr(
plt.figure(dpi=self.dpi)
plt.show()

def plot_zscores_ind(self, z_bar: Optional[list[int]] = None):
def plot_zscores_ind(self, z_bar: Optional[list[int]] = None, show_stim: bool = True):
"""
Function for plotting z scored heatmaps by trial group rather than all trial groups on the same set of axes. In
This function all data is ordered based on the most responsive unit/trial group. Rows can be different units
Expand All @@ -567,6 +569,8 @@ def plot_zscores_ind(self, z_bar: Optional[list[int]] = None):
----------
z_bar: list[int]
If given a list with min z score for the cbar at index 0 and the max at index 1. Overrides cbar generation
show_stim: bool, default: True
Whether to mark at the stim onset and offset
"""
try:
z_scores = self.data.z_scores
Expand Down Expand Up @@ -631,22 +635,23 @@ def plot_zscores_ind(self, z_bar: Optional[list[int]] = None):
ax.set_xticks([i * bins_length for i in range(7)])
ax.set_xticklabels([round(bins[i * bins_length], 4) if i < 7 else z_window[1] for i in range(7)])
ax.set_ylabel(y_axis, fontsize="small")
ax.axvline(
zero_point,
0,
np.shape(sorted_z_scores)[0],
color="black",
linestyle=":",
linewidth=0.5,
)
ax.axvline(
end_point,
0,
np.shape(sorted_z_scores)[0],
color="black",
linestyle=":",
linewidth=0.5,
)
if show_stim:
ax.axvline(
zero_point,
0,
np.shape(sorted_z_scores)[0],
color="black",
linestyle=":",
linewidth=0.5,
)
ax.axvline(
end_point,
0,
np.shape(sorted_z_scores)[0],
color="black",
linestyle=":",
linewidth=0.5,
)
self._despine(ax)
ax.spines["bottom"].set_visible(False)
ax.spines["left"].set_visible(False)
Expand Down

0 comments on commit 95dc72f

Please sign in to comment.