diff --git a/docs/source/conf.py b/docs/source/conf.py index 5fd6a08..8fc1cae 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -24,7 +24,7 @@ author = "Zach McKenzie" # The full version, including alpha/beta/rc tags -release = "0.0.8" +release = "0.0.9" # -- General configuration --------------------------------------------------- diff --git a/pyproject.toml b/pyproject.toml index b36d863..b5b3110 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "spikeanalysis" -version = '0.0.8' +version = '0.0.9' authors = [{name="Zach McKenzie", email="mineurs-torrent0x@icloud.com"}] description = "Analysis of Spike Trains" requires-python = ">=3.9" diff --git a/src/spikeanalysis/spike_plotter.py b/src/spikeanalysis/spike_plotter.py index beadec0..1022b92 100644 --- a/src/spikeanalysis/spike_plotter.py +++ b/src/spikeanalysis/spike_plotter.py @@ -66,7 +66,7 @@ def set_analysis(self, analysis: SpikeAnalysis): The SpikeAnalysis object for plotting """ - assert isinstance(analysis, SpikeAnalysis), "analysis must be a SpikeAnaysis dataset" + assert isinstance(analysis, SpikeAnalysis), "analysis must be a SpikeAnalysis dataset" self.data = analysis def plot_zscores( @@ -74,7 +74,8 @@ def plot_zscores( figsize: Optional[tuple] = (24, 10), sorting_index: Optional[int] = None, z_bar: Optional[list[int]] = None, - ): + indices: bool = False, + ) -> Optional[np.array]: """ Function to plot heatmaps of z scored firing rate. All trial groups are plotted on the same axes. So it is best to have a figsize that wide to fit all different trial groups. In this plot each @@ -89,6 +90,13 @@ def plot_zscores( The trial group to sort all values on. The default is None (which uses the largest trial group). z_bar: list[int] If given a list with min z score for the cbar at index 0 and the max at index 1. Overrides cbar generation + indices: bool, default False + If true will return the cluster ids sorted in the order they appear in the graph + + Returns + ------- + ordered_cluster_ids: np.array + if indices is True, the function will return the cluster ids as displayed in the z bar graph """ @@ -214,6 +222,9 @@ def plot_zscores( if RESET_INDEX: sorting_index = None + if indices: + return self.data.cluster_ids[z_score_sorting_index] + def plot_raster(self, window: Union[list, list[list]]): """ Function to plot rasters @@ -223,8 +234,6 @@ def plot_raster(self, window: Union[list, list[list]]): window : Union[list, list[list]] The window [start, stop] to plot the raster over. Either one global list or nested list of [start, stop] format - - """ from .analysis_utils import histogram_functions as hf