From 0497e55fc994fc6bd1dfbcc84945bdbdfacd8f6f Mon Sep 17 00:00:00 2001
From: Zach McKenzie <92116279+zm711@users.noreply.github.com>
Date: Wed, 11 Dec 2024 12:10:14 -0500
Subject: [PATCH] Change z-score (#57)

* wip mask
* bug fix
* oops
* improve check
* another fix
* more fixes
* fix correlation bug
* playing with latency values
* wip exclusion
* fix typo
* add save function
* wip analysis
* Revert "wip analysis"

This reverts commit 93c31a6e8cd4b0236ea2d60bb4bf8c90359ca225.

* another plotter tweak
* wip
* fix
* fix trials
* try again
* oops
* fix mean
* fix mean and dtype
* wip
* wip
* updates and add latency filter
* black
* wip
* wip2
* fix test for new way to z score data
* add test for baseline
* add set_mask test
* test auc filter
---
 .../analysis_utils/histogram_functions.py   |   4 +-
 .../analysis_utils/latency_functions.py     |   2 +-
 src/spikeanalysis/curated_spike_analysis.py | 101 +++++++++++++--
 src/spikeanalysis/plotbase.py               |  32 ++++-
 src/spikeanalysis/spike_analysis.py         | 119 ++++++++++++++++--
 src/spikeanalysis/spike_plotter.py          |  91 ++++++++++----
 src/spikeanalysis/utils.py                  |   3 +-
 test/histogram_test.py                      |   8 +-
 test/test_curated_spike_analysis.py         |  20 +++
 test/test_spike_analysis.py                 |  27 ++++
 10 files changed, 357 insertions(+), 50 deletions(-)

diff --git a/src/spikeanalysis/analysis_utils/histogram_functions.py b/src/spikeanalysis/analysis_utils/histogram_functions.py
index 7a315f7..99bf284 100644
--- a/src/spikeanalysis/analysis_utils/histogram_functions.py
+++ b/src/spikeanalysis/analysis_utils/histogram_functions.py
@@ -187,10 +187,10 @@ def convert_bins(bins: np.array, bin_number: np.int32) -> np.array:
 
 
 @jit(nopython=True, cache=True)
-def z_score_values(z_trial: numba.float32[:, :, :], mean_fr: numba.float32[:], std_fr: numba.float32[:]) -> np.array:
+def z_score_values(z_trial: numba.float32[:, :, :], mean_fr: numba.float32[:, :], std_fr: numba.float32[:, :]) -> np.array:
     z_trials = np.zeros(np.shape(z_trial))
     for idx in range(len(mean_fr)):
         for idy in range(np.shape(z_trial)[1]):
-            z_trials[idx, idy, :] = (z_trial[idx, idy] - mean_fr[idx]) / std_fr[idx]
+            z_trials[idx, idy, :] = (z_trial[idx, idy] - mean_fr[idx, idy]) / std_fr[idx, idy]
 
     return z_trials
 
diff --git a/src/spikeanalysis/analysis_utils/latency_functions.py b/src/spikeanalysis/analysis_utils/latency_functions.py
index d0bbf2c..7179276 100644
--- a/src/spikeanalysis/analysis_utils/latency_functions.py
+++ b/src/spikeanalysis/analysis_utils/latency_functions.py
@@ -31,7 +31,7 @@ def latency_core_stats(bsl_fr: float, firing_data: np.array, time_bin_size: floa
         )
         if final_prob <= 10e-6:
             break
-        elif n_bin * time_bin_size >= 0.400:  # past 400 ms is not really a true latency
+        elif n_bin * time_bin_size >= 5:  # past 5 s is not really a true latency
             n_bin = np.shape(firing_data)[1] - 1
             break
 
diff --git a/src/spikeanalysis/curated_spike_analysis.py b/src/spikeanalysis/curated_spike_analysis.py
index dc29ada..d388fb4 100644
--- a/src/spikeanalysis/curated_spike_analysis.py
+++ b/src/spikeanalysis/curated_spike_analysis.py
@@ -48,7 +48,6 @@ class CuratedSpikeAnalysis(SpikeAnalysis):
     def __init__(
         self, curation: dict | None = None, st: SpikeAnalysis | None = None, save_parameters=False, verbose=False
     ):
-
        """
        Parameters
        ----------
@@ -61,6 +60,7 @@ def __init__(
        super().__init__(save_parameters=save_parameters, verbose=verbose)
        if st is not None:
            self.set_spike_analysis(st=st)
+        self.mask = None
 
     def set_curation(
         self,
@@ -93,7 +93,6 @@ def set_spike_data(self, sp: SpikeData):
 
         super().set_spike_data(sp=sp)
self._original_cluster_ids = deepcopy(self.cluster_ids) - def set_spike_data_si(self, sp: "Sorting"): """ Function for setting a spikeinterface sorting @@ -125,6 +124,15 @@ def set_spike_analysis(self, st: SpikeAnalysis): self.cluster_ids = st.cluster_ids self.si_units = st.si_units + def set_mask(self, mask: list[bool]): + + if len(mask) != len(self.cluster_ids): + raise ValueError( + f"mask len {len(mask)} must be same as cluster ids len {len(self.cluster_ids)}. Maybe you need to revert curation." + ) + + self.mask = mask + def curate( self, criteria: str | dict, @@ -132,6 +140,7 @@ def curate( by_response: bool = False, by_trial: Literal["all"] | bool = False, trial_index: Optional[int] = None, + apply_mask: bool = False, ): """Function for loading the current curation Parameters @@ -144,9 +153,11 @@ def curate( Whether to analyze data by a particular response profile by_trial Literal['all'] | bool, default: False ***** - trial_index: Optional[int], default: None + trial_index: Optional[int | np.array], default: None Must be given if by_trial=True, to indicate which specific trial to be used + apply_mask: bool, default: False + If an additional mask is desired. If mask has not been set then this argument does nothing. """ curation = self.curation if len(curation) == 0: @@ -171,20 +182,21 @@ def curate( if len(sub_curation.shape) == 1: sub_curation = np.expand_dims(sub_curation, axis=1) mask = np.all(sub_curation, axis=1) - self.cluster_ids = self.cluster_ids[mask] else: assert trial_index is not None, "must give the trial index to look at only the trial" if len(sub_curation.shape) == 1: sub_curation = np.expand_dims(sub_curation, axis=1) - mask = sub_curation[:, trial_index] - self.cluster_ids = self.cluster_ids[mask] + if isinstance(trial_index, (int, float)): + mask = sub_curation[:, trial_index] + else: + mask = np.all(sub_curation[:, np.array(trial_index)], axis=1) + else: if len(sub_curation.shape) == 1: sub_curation = np.expand_dims(sub_curation, axis=1) mask = np.any(sub_curation, axis=1) - self.cluster_ids = self.cluster_ids[mask] elif by_stim: assert isinstance(criteria, str), "must give single stim" @@ -207,8 +219,6 @@ def curate( else: mask = np.any(mask_array, axis=1) - self.cluster_ids = self.cluster_ids[mask] - elif by_response: assert isinstance(criteria, str), "must give single response" @@ -230,11 +240,80 @@ def curate( else: mask = np.any(mask_array, axis=1) - self.cluster_ids = self.cluster_ids[mask] - else: raise Exception("must be by_stim, by_response, or both") + if self.mask is not None and apply_mask: + mask = np.logical_and(mask, self.mask) + + self.cluster_ids = self.cluster_ids[mask] + def revert_curation(self): """Function to revert to the pre-curated state""" self.cluster_ids = self._original_cluster_ids + + def filter_mask( + self, + window, + datatype="zscore", + filter="auc", + filter_params=None, + ): + + if filter == "auc": + if filter_params is None: + filter_params = {"all": dict(min=-50, max=50)} + else: + assert all(['min' in value for value in filter_params.values()]) + assert all(['max' in value for value in filter_params.values()]) + operator = np.nansum + else: + raise ValueError("only auc is implemented") + + if datatype == "zscore": + + data = self.z_scores + bins = self.z_bins + else: + data = self.mean_firing_rate + bins = self.fr_bins + + if isinstance(window, list): + window_is_list = True + if isinstance(window[0], list): + assert len(window) == len(data.keys()) + else: + assert len(window) == 2, "only give start stop" + window = [window 
for _ in range(len(data.keys()))] + elif isinstance(window, dict): + window_is_list = False + assert len(window.keys()) == len(data.keys()), "for dict please give one list of stims per stim" + + mask = np.ones((len(self.cluster_ids))) + for stim_index, (stim, scores) in enumerate(data.items()): + + if "all" in filter_params.keys(): + current_params = filter_params["all"] + else: + current_params = filter_params[stim] + + if window_is_list: + current_window = window[stim_index] + else: + current_window = window[stim] + + current_bins = bins[stim] + bin_window = np.logical_and(current_window[0] <= current_bins, current_bins <= current_window[1]) + + final_scores = scores[:, :, bin_window] + + final_scores_summed = operator(final_scores, axis=2) + + final_scores_masked = np.logical_or( + np.any(final_scores_summed > current_params["max"], axis=1), + np.any(final_scores_summed < current_params["min"], axis=1), + ) + mask = np.logical_and(mask, final_scores_masked) + + self.mask = mask + diff --git a/src/spikeanalysis/plotbase.py b/src/spikeanalysis/plotbase.py index 51f0d78..8b357df 100644 --- a/src/spikeanalysis/plotbase.py +++ b/src/spikeanalysis/plotbase.py @@ -19,6 +19,9 @@ "fontname": "The font to use", "fontstyle": "The style to use for the font", "fontsize": "The size of the text", + "save": "Whether to save images", + "format": "The format to save the image", + "extra_title": "Additional info to add to image title", } @@ -85,6 +88,10 @@ def _convert_plot_kwargs(self, plot_kwargs: dict) -> namedtuple: x_axis = plot_kwargs.pop("x_axis", self.x_axis) y_axis = plot_kwargs.pop("y_axis", self.y_axis) + save = plot_kwargs.pop("save", False) + format = plot_kwargs.pop("format", "png") + extra_title = plot_kwargs.pop("extra_title", "") + PlotKwargs = namedtuple( "PlotKwargs", [ @@ -99,10 +106,28 @@ def _convert_plot_kwargs(self, plot_kwargs: dict) -> namedtuple: "fontname", "fontstyle", "fontsize", + "save", + "format", + "extra_title", ], ) - plot_kwargs = PlotKwargs(figsize, dpi, x_lim, y_lim, title, cmap, x_axis, y_axis, fontname, fontstyle, fontsize) + plot_kwargs = PlotKwargs( + figsize, + dpi, + x_lim, + y_lim, + title, + cmap, + x_axis, + y_axis, + fontname, + fontstyle, + fontsize, + save, + format, + extra_title, + ) return plot_kwargs @@ -116,3 +141,8 @@ def set_plot_kwargs(self, ax: plt.axes, plot_kwargs: namedtuple): if plot_kwargs.ylim is not None: ax.set_ylim(plot_kwargs.ylim) + + def _save_fig(self, cluster_number, extra_title="", format="png"): + + title = f"{cluster_number}_{extra_title}" + plt.savefig(title, format=format) diff --git a/src/spikeanalysis/spike_analysis.py b/src/spikeanalysis/spike_analysis.py index b3c34a9..f6fec26 100644 --- a/src/spikeanalysis/spike_analysis.py +++ b/src/spikeanalysis/spike_analysis.py @@ -414,6 +414,10 @@ def get_raw_firing_rate( self.fr_bins[stim] = bins[fr_window_values] self.mean_firing_rate = final_fr + def zscore_data(self, time_bin_ms, bsl_window, z_window, eps:float=0): + + self.z_score_data(time_bin_ms=time_bin_ms, bsl_window=bsl_window, z_window=z_window, eps=eps) + def z_score_data( self, time_bin_ms: Union[list[float], float], @@ -468,6 +472,8 @@ def z_score_data( self.z_windows = {} self.z_bins = {} self.raw_zscores = {} + self.keep_trials = {} + self.raw_baselines = {} for idx, stim in enumerate(self.psths.keys()): if self._verbose: print(stim) @@ -496,16 +502,61 @@ def z_score_data( z_psth = psth[:, :, z_window_values] z_scores[stim] = np.zeros(np.shape(z_psth)) self.raw_zscores[stim] = np.zeros(np.shape(z_psth)) + 
self.keep_trials[stim] = {} + final_z_scores[stim] = np.zeros((np.shape(z_psth)[0], len(trial_set), np.shape(z_psth)[2])) + + # use median instead for determining good trials + bsl_mean_global = np.median( + np.sum(bsl_psth, axis=2) / (bsl_current[1] - bsl_current[0]), axis=1 + ) # test median + + bsl_std_global = ( + np.median( + np.abs(np.sum(bsl_psth, axis=2) / (bsl_current[1] - bsl_current[0]) - bsl_mean_global[:, None]), + axis=1, + ) + / 0.6744897501960817 + ) + + self.raw_baselines[stim] = np.sum(bsl_psth, axis=2) / (bsl_current[1] - bsl_current[0]) + # to get baseline firing we do a per trial baseline for the neuron. To get an estimate + # we divide the baseline into 3 periods and iterate through those chunks of data to get + # the sub firing rate. Then we average those. + n_chunks = sum(bsl_values) // 3 for trial_number, trial in enumerate(tqdm(trial_set)): + self.keep_trials[stim][trial] = np.zeros((z_psth.shape[0], sum(trials == trial)), dtype=bool) bsl_trial = bsl_psth[:, trials == trial, :] - mean_fr = np.mean(np.sum(bsl_trial, axis=2), axis=1) / ((bsl_current[1] - bsl_current[0])) + bsl_chunks = [] + # iterate over baseline chunks and do sum to get point firing rate + for bsl_chunk_index in range(3): + bsl_chunk = bsl_trial[:, :, (bsl_chunk_index * n_chunks) : (bsl_chunk_index + 1) * n_chunks] + # neuron x trial x value + bsl_chunk_sum = np.sum(bsl_chunk, axis=2) / ((bsl_current[1] - bsl_current[0]) / 3) + bsl_chunks.append(bsl_chunk_sum) + + # stack chunks in order to take the mean of the chunks + bsl_chunks = np.stack(bsl_chunks, axis=1) + mean_fr = np.mean(bsl_chunks, axis=1) # for future computations may be beneficial to have small eps to std to prevent divide by 0 - std_fr = np.std(np.sum(bsl_trial, axis=2), axis=1) / ((bsl_current[1] - bsl_current[0])) + eps + std_fr = np.std(bsl_chunks, axis=1) + eps + z_trial = z_psth[:, trials == trial, :] / time_bin_current z_trials = hf.z_score_values(z_trial, mean_fr, std_fr) z_scores[stim][:, trials == trial, :] = z_trials[:, :, :] - final_z_scores[stim][:, trial_number, :] = np.nanmean(z_trials, axis=1) + # if we are > 3 mads away from the tg mean then we eliminate a trial. 
+ for neuron_bsl_idx in range(bsl_mean_global.shape[0]): + keep_trials = np.logical_and( + mean_fr[neuron_bsl_idx] + < (bsl_mean_global[neuron_bsl_idx] + (3 * bsl_std_global[neuron_bsl_idx])), + mean_fr[neuron_bsl_idx] + > (bsl_mean_global[neuron_bsl_idx] - (3 * bsl_std_global[neuron_bsl_idx])), + ) + final_z_scores[stim][neuron_bsl_idx, trial_number, :] = np.nanmean( + z_trials[neuron_bsl_idx, keep_trials, :], axis=0 + ) + + self.keep_trials[stim][trial][neuron_bsl_idx, :] = keep_trials self.raw_zscores[stim][:, trials == trial, :] = z_trials[:, :, :] self.z_bins[stim] = bins[z_window_values] self.z_scores = final_z_scores @@ -734,6 +785,7 @@ def trial_correlation( data = getattr(self, "psths") elif dataset == "raw": data = getattr(self, "raw_firing_rate") + bins = self.fr_bins elif dataset == "z_scores": data = getattr(self, "raw_zscores") bins = self.z_bins @@ -751,13 +803,14 @@ def trial_correlation( number of bins is{len(time_bin_ms)} and should be {self._total_stim}" time_bin_size = np.array(time_bin_ms) / 1000 - try: - stim_dict = self._get_key_for_stim() - except AttributeError: - pass else: time_bin_size = [None] * self._total_stim + try: + stim_dict = self._get_key_for_stim() + except AttributeError: + pass + correlations = {} for idx, stimulus in enumerate(data.keys()): trial_groups = np.array(self.events[stim_dict[stimulus]]["trial_groups"]) @@ -830,8 +883,39 @@ def autocorrelogram(self, time_ms: float = 500): self.acg = acg + def calculate_baseline_values(self, mode: str = "mean"): + + if not hasattr(self, "raw_baselines"): + raise ValueError("must run zscore_data in order to collect trial baselines") + + if mode == "mean": + func = np.mean + elif mode == "median": + func = np.median + elif mode == "max": + func = np.max + elif callable(mode): + func = mode + else: + raise ValueError("enter a function or one of ['mean', 'median', 'max']") + + baselines = {} + for stim, baseline in self.raw_baselines.items(): + baselines[stim] = func(baseline, axis=1) + + self.baselines = baselines + def return_value(self, value: str): - _values = ("z_scores", "raw_zscores", "mean_firing_rate", "raw_firing_rate", "correlations", "latency", "psths") + _values = ( + "z_scores", + "raw_zscores", + "mean_firing_rate", + "raw_firing_rate", + "correlations", + "latency", + "psths", + "baselines", + ) if hasattr(self, value): return getattr(self, value) @@ -889,7 +973,7 @@ def save_z_parameters(self, z_parameters: dict, overwrite: bool = False): with open(self._file_path / "z_parameters.json", "w") as write_file: json.dump(z_parameters, write_file) - def get_responsive_neurons(self, z_parameters: Optional[dict] = None): + def get_responsive_neurons(self, z_parameters: Optional[dict] = None, latency_threshold_ms: Optional[dict] = None): """ function for assessing only responsive neurons based on z scored parameters. 
@@ -934,12 +1018,16 @@ def get_responsive_neurons(self, z_parameters: Optional[dict] = None): else: same_params = False + if latency_threshold_ms is None: + latency_threshold_ms = {k: None for k in self.z_scores.keys()} + self.responsive_neurons = {} for stim in self.z_scores.keys(): self.responsive_neurons[stim] = {} bins = self.z_bins[stim] current_z_scores = self.z_scores[stim] + current_latency_threshold = latency_threshold_ms[stim] if same_params: current_z_params = z_parameters["all"] else: @@ -961,13 +1049,26 @@ def get_responsive_neurons(self, z_parameters: Optional[dict] = None): f"Not implemented for window of size {len(current_window)} possible lengths are 2 or 4" ) + current_bin_size = bins[1] - bins[0] # likely in ms + if current_latency_threshold is not None: + bins_to_threshold = int(current_latency_threshold // current_bin_size) + else: + bins_to_threshold = -1 + current_z_scores_sub = current_z_scores[:, :, window_index] + bin_threshold_z_score = current_z_scores_sub[:, :, :bins_to_threshold] + + # final should not be any. Need to think about how we want latency to be incorporated.... if current_score > 0 or "inhib" not in key.lower(): z_above_threshold = np.sum(np.where(current_z_scores_sub > current_score, 1, 0), axis=2) + latency_resp_neurons = np.any(np.where(bin_threshold_z_score > current_score, True, False), axis=2) else: z_above_threshold = np.sum(np.where(current_z_scores_sub < current_score, 1, 0), axis=2) + latency_resp_neurons = np.any(np.where(bin_threshold_z_score < current_score, True, False), axis=2) responsive_neurons = np.where(z_above_threshold > current_n_bins, True, False) + + responsive_neurons = np.logical_and(responsive_neurons, latency_resp_neurons) self.responsive_neurons[stim][key] = responsive_neurons def save_responsive_neurons(self, overwrite: bool = False): diff --git a/src/spikeanalysis/spike_plotter.py b/src/spikeanalysis/spike_plotter.py index cfacb9e..95ad027 100644 --- a/src/spikeanalysis/spike_plotter.py +++ b/src/spikeanalysis/spike_plotter.py @@ -83,6 +83,7 @@ def plot_zscores( z_bar: Optional[list[int]] = None, indices: bool = False, show_stim: bool = True, + exclusion_dict: dict = None, plot_kwargs: dict = {}, ) -> Optional[np.array]: """ @@ -112,14 +113,15 @@ def plot_zscores( if indices is True, the function will return the cluster ids as displayed in the z bar graph """ - reset = False if self.cmap is None: - reset = True - try: - import seaborn - self.cmap = "vlag" - except ImportError: - self.cmap = "bwr" + if plot_kwargs.get("cmap", None) is None: + try: + import seaborn + + cmap = "vlag" + except ImportError: + cmap = "bwr" + plot_kwargs["cmap"] = cmap sorted_cluster_ids = self._plot_scores( data="zscore", @@ -128,10 +130,9 @@ def plot_zscores( bar=z_bar, indices=indices, show_stim=show_stim, + exclusion_dict=exclusion_dict, plot_kwargs=plot_kwargs, ) - if reset: - self.cmap = None if indices: return sorted_cluster_ids @@ -143,6 +144,7 @@ def plot_raw_firing( bar: Optional[list[int]] = None, indices: bool = False, show_stim: bool = True, + exclusion_dict: dict | None = None, plot_kwargs: dict = {}, ) -> Optional[np.array]: """ @@ -174,10 +176,9 @@ def plot_raw_firing( if indices is True, the function will return the cluster ids as displayed in the z bar graph """ - reset = False if self.cmap is None: - reset = True - self.cmap = "viridis" + if plot_kwargs.get("cmap", None) is None: + plot_kwargs["cmap"] = "viridis" sorted_cluster_ids = self._plot_scores( data="raw-data", @@ -186,12 +187,10 @@ def plot_raw_firing( bar=bar, 
indices=indices, show_stim=show_stim, + exclusion_dict=exclusion_dict, plot_kwargs=plot_kwargs, ) - if reset: - self.cmap = None - if indices: return sorted_cluster_ids @@ -203,6 +202,7 @@ def _plot_scores( bar: Optional[list[int]] = None, indices: bool = False, show_stim: bool = True, + exclusion_dict: None | dict = None, plot_kwargs: dict = {}, ) -> Optional[np.array]: """ @@ -220,7 +220,7 @@ def _plot_scores( If given a list with min for the cbar at index 0 and the max at index 1. Overrides cbar generation indices: bool, default False If true will return the cluster ids sorted in the order they appear in the graph as a dict of stimuli - show_stim: bool, default True + show_stim: bool |int, default True Show lines where stim onset and offset are plot_kwargs: dict default: {} matplot lib kwargs to overide the global kwargs for just the function @@ -287,18 +287,26 @@ def _plot_scores( if sorting_index is None: current_sorting_index = np.shape(sub_zscores)[1] - 1 reset_index = True + is_the_sorting_index_the_filter = False else: reset_index = False assert isinstance(sorting_index, (list, int)), "sorting_index must be list or int" if isinstance(sorting_index, list): + is_the_sorting_index_the_filter = len(sorting_index) == sub_zscores.shape[0] current_sorting_index = sorting_index[stim_idx] else: + is_the_sorting_index_the_filter = False current_sorting_index = sorting_index event_window = np.logical_and(bins >= 0, bins <= lengths[current_sorting_index]) + if is_the_sorting_index_the_filter: + z_score_sorting_index = sorting_index + else: + z_score_sorting_index = np.argsort( + -np.nansum(sub_zscores[:, current_sorting_index, event_window], axis=1) + ) - z_score_sorting_index = np.argsort(-np.sum(sub_zscores[:, current_sorting_index, event_window], axis=1)) if indices: if len(self.data.si_units) > 0: sorted_cluster_ids[stimulus] = {} @@ -308,15 +316,45 @@ def _plot_scores( ] else: sorted_cluster_ids[stimulus] = self.data.cluster_ids[z_score_sorting_index] + sorted_z_scores = sub_zscores[z_score_sorting_index, :, :] if len(np.shape(sorted_z_scores)) == 2: sorted_z_scores = np.expand_dims(sorted_z_scores, axis=1) - nan_mask = np.all( - np.all(np.isnan(sorted_z_scores) | np.equal(sorted_z_scores, 0) | np.isinf(sorted_z_scores), axis=2), + nan_mask = np.any( + np.any(np.isnan(sorted_z_scores) | np.isinf(sorted_z_scores), axis=2) + | np.all(np.equal(sorted_z_scores, 0), axis=2), axis=1, ) + + # exclusion_dict = {stim: {type: any/all, index: None/array}} + if exclusion_dict is not None: + + current_exclusion = exclusion_dict[stimulus] + any_values = current_exclusion["type"] == "any" + + if current_exclusion["index"] is None: + masked_indices = np.arange(0, sorted_z_scores.shape[1], 1, dtype=int) + else: + masked_indices = np.asarray(current_exclusion["index"]) + + if any_values: + mask_func = np.any + else: + mask_func = np.all + + extra_mask = mask_func( + np.all( + np.isnan(sorted_z_scores[:, masked_indices, :]) + | np.equal(sorted_z_scores[:, masked_indices, :], 0) + | np.isinf(sorted_z_scores[:, masked_indices, :]), + axis=2, + ), + axis=1, + ) + nan_mask = np.logical_or(nan_mask, extra_mask) + sorted_z_scores = sorted_z_scores[~nan_mask] if bar is not None: @@ -351,6 +389,8 @@ def _plot_scores( if idx == 0: sub_ax.set_ylabel(y_axis, fontsize="small") if show_stim: + if isinstance(show_stim, bool): + show_stim = 0.5 end_point = np.where((bins > lengths[idx] - bin_size) & (bins < lengths[idx] + bin_size))[0][ 0 ] # aim for nearest bin at end of stim @@ -360,7 +400,7 @@ def _plot_scores( 
np.shape(sorted_z_scores)[0], color="black", linestyle=":", - linewidth=0.5, + linewidth=show_stim, ) sub_ax.axvline( end_point, @@ -368,7 +408,7 @@ def _plot_scores( np.shape(sorted_z_scores)[0], color="black", linestyle=":", - linewidth=0.5, + linewidth=show_stim, ) self._despine(sub_ax) sub_ax.spines["bottom"].set_visible(False) @@ -403,6 +443,10 @@ def _plot_scores( fontname=plot_kwargs.fontname, ) plt.figure(dpi=plot_kwargs.dpi) + if plot_kwargs.save and plot_kwargs.title is not None: + self._save_fig(plot_kwargs.title, extra_title=plot_kwargs.extra_title, format=plot_kwargs.format) + elif plot_kwargs.title is None: + print("give title to save heat map") plt.show() if reset_index: @@ -558,6 +602,8 @@ def plot_raster( fontname=plot_kwargs.fontname, ) plt.figure(dpi=plot_kwargs.dpi) + if plot_kwargs.save: + self._save_fig(title, extra_title=plot_kwargs.extra_title, format=plot_kwargs.format) plt.show() def plot_sm_fr( @@ -738,6 +784,9 @@ def plot_sm_fr( fontname=plot_kwargs.fontname, ) plt.figure(dpi=plot_kwargs.dpi) + + if plot_kwargs.save: + self._save_fig(title, extra_title=plot_kwargs.extra_title, format=plot_kwargs.format) plt.show() def plot_zscores_ind(self, z_bar: Optional[list[int]] = None, show_stim: bool = True): diff --git a/src/spikeanalysis/utils.py b/src/spikeanalysis/utils.py index 6c4fd85..cbed2da 100644 --- a/src/spikeanalysis/utils.py +++ b/src/spikeanalysis/utils.py @@ -208,12 +208,13 @@ def prevalence_counts( prevalence_dict = {} rt_0 = responsive_neurons[stim[0]][list(responsive_neurons[stim[0]].keys())[0]] - n_tgs = rt_0.shape[1] + n_tgs_dict = {st: rt[list(rt.keys())[0]].shape[1] for st, rt in responsive_neurons.items()} n_neurons = rt_0.shape[0] if by_trialgroup: for st in stim: prevalence_dict[st] = {} response_types = responsive_neurons[st] + n_tgs = n_tgs_dict[st] for n_trial in range(n_tgs): response_list = [] response_labels = [] diff --git a/test/histogram_test.py b/test/histogram_test.py index e1e16fb..cf34e82 100644 --- a/test/histogram_test.py +++ b/test/histogram_test.py @@ -209,8 +209,8 @@ def test_z_score_values(): z_trial = np.ones((3, 4, 10), dtype=np.float32) z_trial[0, 0, 9] = 10 z_trial[2, 2, 2] = -5 - mean_fr = np.zeros(3, dtype=np.float32) - std_fr = np.ones(3, dtype=np.float32) + mean_fr = np.zeros((3,4), dtype=np.float32) + std_fr = np.ones((3,4), dtype=np.float32) z_trials = hf.z_score_values(z_trial, mean_fr, std_fr) @@ -219,8 +219,8 @@ def test_z_score_values(): assert z_trials[0, 0, 9] == 10 assert z_trials[2, 2, 2] == -5 - mean_fr2 = np.ones(3, dtype=np.float32) - std_fr2 = 0.5 * np.ones(3, dtype=np.float32) + mean_fr2 = np.ones((3,4), dtype=np.float32) + std_fr2 = 0.5 * np.ones((3,4), dtype=np.float32) z_trials_2 = hf.z_score_values(z_trial, mean_fr2, std_fr2) assert z_trials_2[0, 0, 9] == 18 diff --git a/test/test_curated_spike_analysis.py b/test/test_curated_spike_analysis.py index 639950e..5e5c9f9 100644 --- a/test/test_curated_spike_analysis.py +++ b/test/test_curated_spike_analysis.py @@ -83,3 +83,23 @@ def test_curation_both_trial(csa): def test_curation_wrong_value(csa): with pytest.raises(Exception): csa.curate(criteria="test", by_stim=False, by_respone=False, by_trial=False) + +def test_set_mask(csa): + csa.revert_curation() + csa.set_mask([True, True]) + + with pytest.raises(ValueError): + csa.set_mask([True, True, True]) + + +def test_auc_filter(csa): + z_scores = {'test' :np.vstack((np.ones((1,2,100)), 2*np.ones((1,2,100))))} + z_bins = {'test': np.linspace(0,100,100)} + csa.z_scores = z_scores + csa.z_bins = z_bins + + 
csa.filter_mask(window=[20,40], filter_params={'test': {'min':-40, 'max': 30}}) + + assert sum(csa.mask) ==1 + csa.filter_mask(window=[20,40], filter_params={'test': {'min':-40, 'max': 80}}) + assert sum(csa.mask) == 0 \ No newline at end of file diff --git a/test/test_spike_analysis.py b/test/test_spike_analysis.py index f03ce7c..fb828bf 100644 --- a/test/test_spike_analysis.py +++ b/test/test_spike_analysis.py @@ -430,3 +430,30 @@ def test_autocorrelogram(sa): assert np.shape(sa.acg) == (2, 24) nptest.assert_array_equal(sa.acg[0], np.zeros((24,))) + +def test_baselines(sa): + + sa.events = { + "0": { + "events": np.array([100, 200]), + "lengths": np.array([100, 100]), + "trial_groups": np.array([1, 1]), + "stim": "test", + } + } + sa.get_raw_psth(window=[0, 300], time_bin_ms=50) + + psths = sa.psths + psths["test"]["psth"][0, 0, 0:200] = 1 + psths["test"]["psth"][0, 1, 100:300] = 2 + psths["test"]["psth"][0, 0, 3000:4000] = 5 + sa.psths = psths + # print(f"PSTH {sa.psths}") + sa.zscore_data(time_bin_ms=1000, bsl_window=[0, 50], z_window=[0, 300]) + + for func in ['mean', 'median', 'max']: + sa.calculate_baseline_values(func) + + bsls = sa.baselines + assert isinstance(bsls, dict) + nptest.assert_allclose(bsls['test'],np.array([8.04,.1])) \ No newline at end of file
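
Notes on the new behavior. The sketches below restate the main changes in plain NumPy; all variable names and data in them are illustrative, not part of the patch.

z_score_values now takes per-trial baseline statistics: mean_fr and std_fr are (n_units, n_trials) instead of (n_units,), so each trial is normalized by its own baseline. A minimal NumPy equivalent of the numba kernel:

    import numpy as np

    def z_score_values_np(z_trial, mean_fr, std_fr):
        # z_trial: (n_units, n_trials, n_bins); mean/std: (n_units, n_trials)
        return (z_trial - mean_fr[:, :, None]) / std_fr[:, :, None]

    z_trial = np.ones((3, 4, 10), dtype=np.float32)
    mean_fr = np.zeros((3, 4), dtype=np.float32)
    std_fr = 0.5 * np.ones((3, 4), dtype=np.float32)
    assert np.allclose(z_score_values_np(z_trial, mean_fr, std_fr), 2.0)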
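
The per-trial statistics themselves come from splitting the baseline window into three chunks, computing a firing rate for each, and taking the mean/std across those chunks, so the std now reflects within-trial baseline variability. A sketch of that estimate with made-up counts:

    import numpy as np

    rng = np.random.default_rng(0)
    bsl_trial = rng.poisson(2, size=(3, 20, 90)).astype(float)  # units x trials x bins
    bsl_duration = 9.0  # baseline window length in seconds, illustrative

    n_chunks = bsl_trial.shape[2] // 3
    chunks = np.stack(
        [
            np.sum(bsl_trial[:, :, i * n_chunks:(i + 1) * n_chunks], axis=2)
            / (bsl_duration / 3)
            for i in range(3)
        ],
        axis=1,
    )                                    # (n_units, 3, n_trials)
    mean_fr = np.mean(chunks, axis=1)    # per-trial baseline rate, (n_units, n_trials)
    std_fr = np.std(chunks, axis=1)      # per-trial std across the three chunks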
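
Trials with outlier baselines are now dropped from the averaged z-scores (raw_zscores keeps every trial, and keep_trials records each decision). A unit's global baseline is summarized by the median across trials with a MAD-based sigma, and a trial is kept only when it lies within 3 of those sigmas:

    import numpy as np

    rng = np.random.default_rng(0)
    bsl_rates = rng.normal(5.0, 1.0, size=(3, 20))  # (n_units, n_trials), illustrative

    bsl_mean_global = np.median(bsl_rates, axis=1)
    bsl_std_global = (
        np.median(np.abs(bsl_rates - bsl_mean_global[:, None]), axis=1)
        / 0.6744897501960817  # scales the MAD to a Gaussian sigma
    )
    # equivalent to the patch's two-sided < / > comparison
    keep_trials = np.abs(bsl_rates - bsl_mean_global[:, None]) < 3 * bsl_std_global[:, None]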
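
CuratedSpikeAnalysis gains a user-settable boolean mask that can be ANDed into any curation. A hedged usage sketch, assuming csa is an existing CuratedSpikeAnalysis with two clusters and a loaded curation (the criteria value is made up):

    csa.set_mask([True, False])  # one bool per current cluster id; a wrong length raises ValueError
    csa.curate(criteria="stim1", by_stim=True, apply_mask=True)
    # cluster_ids is now cluster_ids[np.logical_and(curation_mask, csa.mask)]
    csa.revert_curation()        # restores the pre-curation cluster ids

Since set_mask checks the mask length against the current cluster_ids, set the mask before curating (or revert first).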
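
filter_mask builds that mask from the area under the curve of the z-scores (the only filter implemented) inside a response window. window may be a single [start, stop] applied to every stimulus, one list per stimulus, or a dict keyed by stimulus; filter_params maps "all" or each stimulus name to min/max bounds. A unit is flagged when its nansum over the window falls outside [min, max] for at least one trial group (ANDed across stimuli). Mirroring the new test:

    csa.filter_mask(window=[20, 40], filter_params={"test": {"min": -40, "max": 30}})
    # csa.mask[i] is True when unit i's summed z-score over bins 20-40 of
    # stimulus "test" exceeds 30 or drops below -40 for some trial group
    csa.curate(criteria="test", by_stim=True, apply_mask=True)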
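
z_score_data (or its new zscore_data alias) now also stashes the raw per-trial baseline rates, which calculate_baseline_values reduces to one value per unit. A sketch, assuming sa is a SpikeAnalysis with PSTHs already computed (the window arguments are illustrative):

    sa.z_score_data(time_bin_ms=50, bsl_window=[0, 50], z_window=[0, 300])
    sa.calculate_baseline_values("median")    # "mean", "median", "max", or any
                                              # callable accepting (array, axis=...)
    baselines = sa.return_value("baselines")  # dict of stim -> per-unit baseline rate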
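
get_responsive_neurons can now also require that a response start early: latency_threshold_ms (a dict keyed like z_scores, in the same units as z_bins) is converted to a bin count, and a unit must cross the z threshold at least once inside those leading bins in addition to the sustained n-bin criterion. The core of the check, sketched for the excitatory case with made-up data:

    import numpy as np

    rng = np.random.default_rng(1)
    window_scores = rng.normal(size=(5, 2, 40))  # units x trial groups x window bins
    current_score, current_n_bins = 2.0, 3
    bins_to_threshold = 10  # int(latency_threshold // bin size); the patch uses -1
                            # (the window minus its last bin) when no threshold is given

    sustained = np.sum(window_scores > current_score, axis=2) > current_n_bins
    early = np.any(window_scores[:, :, :bins_to_threshold] > current_score, axis=2)
    responsive = np.logical_and(sustained, early)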
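
The heat-map plotters accept an exclusion_dict of the form {stim: {"type": "any" | "all", "index": None | array}}: a row is dropped when its z-scores are entirely NaN/0/inf in the listed trial columns ("any" ORs across those columns, "all" ANDs; index=None checks every column), on top of the default NaN/inf row mask. A hedged usage sketch, assuming plotter is an existing SpikePlotter:

    plotter.plot_zscores(
        z_bar=[-5, 5],
        exclusion_dict={"test": {"type": "any", "index": [0]}},
    )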
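
Finally, plots can be written to disk through three new plot kwargs (save, format, extra_title) handled by the new _save_fig helper in plotbase.py; heat maps additionally need a title to build the file name. An illustrative call:

    plotter.plot_zscores(plot_kwargs={"title": "zscores", "save": True,
                                      "format": "svg", "extra_title": "session1"})
    # saved as "zscores_session1" via plt.savefig(..., format="svg"); without a
    # title the heat-map plotter prints "give title to save heat map" and skips
    # saving, while rasters and smoothed firing-rate plots use their per-cluster
    # titles instead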