Skip to content

Commit

Permalink
fix sorting_index
Browse files Browse the repository at this point in the history
  • Loading branch information
zm711 authored Nov 3, 2023
1 parent 9423d3a commit ffa3286
Showing 1 changed file with 13 additions and 8 deletions.
21 changes: 13 additions & 8 deletions src/spikeanalysis/spike_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def set_analysis(self, analysis: SpikeAnalysis):
def plot_zscores(
self,
figsize: Optional[tuple] = (24, 10),
sorting_index: Optional[int] = None,
sorting_index: Optional[int] | list[int] = None,
z_bar: Optional[list[int]] = None,
indices: bool = False,
show_stim: bool = True,
Expand All @@ -95,7 +95,7 @@ def plot_zscores(
----------
figsize : Optional[tuple], optional
Matplotlib figsize tuple. For multiple trial groups bigger is better. The default is (24, 10).
sorting_index : Optional[int], optional
sorting_index : Optional[int] | list[int], optional
The trial group to sort all values on. The default is None (which uses the largest trial group).
z_bar: list[int]
If given a list with min z score for the cbar at index 0 and the max at index 1. Overrides cbar generation
Expand Down Expand Up @@ -132,7 +132,7 @@ def plot_zscores(
def plot_raw_firing(
self,
figsize: Optional[tuple] = (24, 10),
sorting_index: Optional[int] = None,
sorting_index: Optional[int] | list[int] = None,
bar: Optional[list[int]] = None,
indices: bool = False,
show_stim: bool = True,
Expand All @@ -148,7 +148,7 @@ def plot_raw_firing(
----------
figsize : Optional[tuple], optional
Matplotlib figsize tuple. For multiple trial groups bigger is better. The default is (24, 10).
sorting_index : Optional[int], optional
sorting_index : Optional[int] | list[int], optional
The trial group to sort all values on. The default is None (which uses the largest trial group).
bar: list[int]
If given a list with min firing rate for the cbar at index 0 and the max at index 1. Overrides cbar generation
Expand Down Expand Up @@ -182,7 +182,7 @@ def _plot_scores(
self,
data: str = "zscore",
figsize: Optional[tuple] = (24, 10),
sorting_index: Optional[int] = None,
sorting_index: Optional[int] | list[int] = None,
bar: Optional[list[int]] = None,
indices: bool = False,
show_stim: bool = True,
Expand Down Expand Up @@ -211,7 +211,7 @@ def _plot_scores(
if indices is True, the function will return the cluster ids as displayed in the z bar graph
"""

if data == "zscore":
z_scores = self.data.z_scores
elif data == "raw-data":
Expand All @@ -234,7 +234,7 @@ def _plot_scores(

stim_lengths = self._get_event_lengths()
sorted_cluster_ids = {}
for stimulus in z_scores.keys():
for stim_idx, stimulus in enumerate(z_scores.keys()):
if len(np.shape(z_scores)) < 3:
sub_zscores = np.expand_dims(z_scores[stimulus], axis=1)
sub_zscores = z_scores[stimulus]
Expand All @@ -259,9 +259,14 @@ def _plot_scores(

else:
RESET_INDEX = False
assert isinstance(sorting_index, (list,int)), "sorting_index must be list or int"
if isinstance(sorting_index, list):
current_sorting_index = sorting_index[stim_idx]
else:
current_sorting_index = sorting_index
event_window = np.logical_and(bins >= 0, bins <= length)

z_score_sorting_index = np.argsort(-np.sum(sub_zscores[:, sorting_index, event_window], axis=1))
z_score_sorting_index = np.argsort(-np.sum(sub_zscores[:, current_sorting_index, event_window], axis=1))
sorted_cluster_ids[stimulus] = self.data.cluster_ids[z_score_sorting_index]
sorted_z_scores = sub_zscores[z_score_sorting_index, :, :]

Expand Down

0 comments on commit ffa3286

Please sign in to comment.