From bda3609553f27be3c12ec21a0f9e6964869306b9 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Mon, 16 Oct 2023 14:20:00 -0400 Subject: [PATCH 01/25] WIP --- src/spikeanalysis/merged_spike_analysis.py | 134 +++++++++++++++++++++ 1 file changed, 134 insertions(+) create mode 100644 src/spikeanalysis/merged_spike_analysis.py diff --git a/src/spikeanalysis/merged_spike_analysis.py b/src/spikeanalysis/merged_spike_analysis.py new file mode 100644 index 0000000..e52deef --- /dev/null +++ b/src/spikeanalysis/merged_spike_analysis.py @@ -0,0 +1,134 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional, Union + +import numpy as np + +from .stimulus_data import StimulusData + +from .spike_data import SpikeData + +from .spike_analysis import SpikeAnalysis +from .curated_spike_analysis import CuratedSpikeAnalysis + +@dataclass +class MergedSpikeAnalysis: + spikeanalysis_list: list + name_list: list | None + + def __post_init__(self): + if self.name_list is None: + pass + else: + assert len(self.spikeanalysis_list)==len(self.name_list), "each dataset needs a name value" + + def add_analysis(self, analysis: SpikeAnalysis|CuratedSpikeAnalysis, name: str|None): + self.spikeanalysis_list.append(analysis) + if self.name_list is not None: + assert len(self.spikeanalysis_list) == len(self.name_list)-1, 'must provide name if other datasets named' + self.name_list.append(name) + else: + print('other datasets were not given names ignoring naming') + + def merge(self, stim_name: str|None): + + # merge the cluster_ids for plotting + cluster_ids = [] + for idx, sa in enumerate(self.spikeanalysis_list): + if isinstance(self.name_list,list): + sub_cluster_ids = [str(self.name_list[idx])+str(cid) for cid in sa.cluster_ids] + else: + sub_cluster_ids = [str(idx)+ str(cid) for cid in sa.cluster_ids] + cluster_ids.append(sub_cluster_ids) + final_cluster_ids = [cid for cid in cluster_ids] + + self.cluster_ids = final_cluster_ids + + #merge the events for plotting + events = {} + if stim_name is not None: + events[stim_name] = self.spikeanalysis_list[0].events[stim_name] + else: + events = self.spikeanalysis_list[0].events + + + + + for idx, sa in enumerate(self.spikeanalysis_list): + + z_score_list = [] + try: + z_score_list.append(sa.z_scores) + + z_bins = sa.z_bins + z_windows = sa.z_windows + have_z_data = True + + except AttributeError: + if self.name_list is not None: + print(f'no z score data for data set {self.name_list[idx]}') + + fr_list = [] + try: + fr_list.append(sa.fr__) + + + except AttributeError: + if self.name_list is not None: + print(f'no raw firing rate data for data set {self.name_list[idx]}') + have_raw_data = False + + + + if len(z_score_list) >= 2: + z_scores = _merge(z_score_list, stim_name=stim_name) + if len(fr_list) >= 2: + fr_scores = _merge(fr_list, stim_name=stim_name) + + + + def _merge(dataset_list, stim_name): + + data_merge = {} + if stim_name is not None: + + + for stim in dataset_list[0].keys(): + data_merge[stim] = [] + for dataset in dataset_list: + data_merge[stim].append(dataset[stim]) + else: + data_merge[stim_name] = [] + for dataset in dataset_list: + data_merge[stim_name].append(dataset[stim_name]) + + for stim, array in data_merge.items(): + + data_merge[stim] = np.concatenate(array, axis=0) + + return data_merge + + + def get_merged_data(self): + + msa = MSA() + msa.cluster_ids = self.cluster_ids + + + + +class MSA(SpikeAnalysis): + + def get_raw_psth(self): + raise NotImplementedError + 
def z_score_data(self): + raise NotImplementedError + def latencies(self): + raise NotImplementedError + def set_spike_data(self): + print('data is immutable') + def set_stimulus_data(self): + print('data is immutable') + def get_raw_firing_rate(self): + raise NotImplementedError \ No newline at end of file From ac70fda2d68cc521ac6982056dcc69b256545b1b Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Mon, 16 Oct 2023 15:12:18 -0400 Subject: [PATCH 02/25] wip-- testing for z scores only --- src/spikeanalysis/merged_spike_analysis.py | 27 +++++++++++++++------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/src/spikeanalysis/merged_spike_analysis.py b/src/spikeanalysis/merged_spike_analysis.py index e52deef..1d1a623 100644 --- a/src/spikeanalysis/merged_spike_analysis.py +++ b/src/spikeanalysis/merged_spike_analysis.py @@ -1,13 +1,9 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Optional, Union -import numpy as np - -from .stimulus_data import StimulusData -from .spike_data import SpikeData +import numpy as np from .spike_analysis import SpikeAnalysis from .curated_spike_analysis import CuratedSpikeAnalysis @@ -53,8 +49,6 @@ def merge(self, stim_name: str|None): events = self.spikeanalysis_list[0].events - - for idx, sa in enumerate(self.spikeanalysis_list): z_score_list = [] @@ -83,6 +77,16 @@ def merge(self, stim_name: str|None): if len(z_score_list) >= 2: z_scores = _merge(z_score_list, stim_name=stim_name) + self.z_scores = z_scores + if stim_name is None: + self.z_bins = z_bins + self.z_windows = z_windows + else: + self.z_bins={} + self.z_bins[stim_name] = z_bins[stim_name] + self.z_windows={} + self.z_windows[stim_name] = z_windows[stim_name] + if len(fr_list) >= 2: fr_scores = _merge(fr_list, stim_name=stim_name) @@ -93,7 +97,6 @@ def _merge(dataset_list, stim_name): data_merge = {} if stim_name is not None: - for stim in dataset_list[0].keys(): data_merge[stim] = [] for dataset in dataset_list: @@ -114,6 +117,14 @@ def get_merged_data(self): msa = MSA() msa.cluster_ids = self.cluster_ids + try: + msa.z_scores = self.z_scores + msa.z_bins = self.z_bins + msa.z_windows = self.z_windows + except AttributeError: + pass + + return msa From 564f4068115f1634c2d47742c3e47e7f0a1d6a44 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Mon, 16 Oct 2023 16:11:21 -0400 Subject: [PATCH 03/25] add feature for counting prevalence of response types --- src/spikeanalysis/utils.py | 79 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 79 insertions(+) diff --git a/src/spikeanalysis/utils.py b/src/spikeanalysis/utils.py index 036ab2c..aab5a1d 100644 --- a/src/spikeanalysis/utils.py +++ b/src/spikeanalysis/utils.py @@ -74,3 +74,82 @@ def gaussian_smoothing(array: np.array, bin_size: float, std: float) -> np.array smoothed_array[row] = signal.convolve(array[row], smoothing_window, mode="same") / bin_size return smoothed_array + + +def prevalence_counts(responsive_neurons: dict | str | "Path", stim: list[str] | None = None, trial_index: dict | None=None, all_trials: bool = False, exclusive_list: list | None = None, inclusive_list: list | None = None): + + # prep responsive neurons from file or from argument + from pathlib import Path + if isinstance(responsive_neurons, (str, Path)): + responsive_neurons_path = Path(responsive_neurons) + assert responsive_neurons_path.is_file(), "responsive neuron json must exist" + with open(responsive_neurons_path, "r") as read_file: + 
responsive_neurons = json.load(read_file) + for stim in responsive_neurons.keys(): + for response in responsive_neurons[stim]: + responsive_neurons[stim][response] = np.array(responsive_neurons[stim][response], dtype=bool) + else: + assert isinstance(responsive_neurons, dict), f'responsive_neurons must be path or dict it is {type(responsive_neurons)}' + + # prep other arguments + if stim is None: + stim = list(responsive_neurons.keys()) + else: + assert isinstance(stim, list), 'stim must be a list of the desired keys' + + if trial_index is None: + trial_index = {} + for st in stim: + if all_trials: + trial_index[st] = 'all' + else: + trial_index[st] = 'any' + else: + assert isinstance(trial_index, dict), 'trial_index must be dict of which trials to use' + + if exclusive_list is None: + exclusive_list = [] + + if inclusive_list is None: + inclusive_list = [] + + # count final data + prevalence_dict = {} + for st in stim: + prevalence_dict[stim] = {} + response_types = responsive_neurons[st] + trial_indices = trial_index[st] + response_list = [] + response_labels = [] + for rt_label, rt in response_types.items(): + response_labels.append(rt_label) + if trial_indices == 'all': + response_list.append(mask = np.all(rt, axis=1)) + elif trial_indices == 'any': + response_list.append(mask = np.any(rt, axis=1)) + else: + if len(trial_indices)==2: + start, end = trial_indices[0], trial_indices[1] + response_list.append(mask = np.all(rt[:, start:end])) + else: + response_list.append(mask = np.all(rt[:, np.array(trial_indices)])) + + final_responses = np.vstack(response_list) + for response in exclusive_list: + rt_idx = response_labels.index(response) + pos_neuron_idx = np.nonzero(final_responses[rt_idx])[0] + keep_list = [] + for keep in inclusive_list: + keep_list.append(response_labels.index(keep)) + final_response_idx = np.array(keep_list.append(rt_idx)) + final_responses[pos_neuron_idx, ~final_response_idx] = False + + prevalences = np.sum(final_responses, axis=0) + prevalence_dict[stim]['labels'] = response_labels + prevalence_dict[stim]['counts'] = prevalences + + return prevalence_dict + + + + \ No newline at end of file From f854a2d8517331c7593ddb6e730ebfa9d18f1fcd Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Mon, 16 Oct 2023 16:12:50 -0400 Subject: [PATCH 04/25] linting --- src/spikeanalysis/merged_spike_analysis.py | 58 +++++++++------------- src/spikeanalysis/spike_data.py | 2 +- src/spikeanalysis/spike_plotter.py | 12 ++--- src/spikeanalysis/utils.py | 51 ++++++++++--------- 4 files changed, 59 insertions(+), 64 deletions(-) diff --git a/src/spikeanalysis/merged_spike_analysis.py b/src/spikeanalysis/merged_spike_analysis.py index 1d1a623..1cf832a 100644 --- a/src/spikeanalysis/merged_spike_analysis.py +++ b/src/spikeanalysis/merged_spike_analysis.py @@ -8,6 +8,7 @@ from .spike_analysis import SpikeAnalysis from .curated_spike_analysis import CuratedSpikeAnalysis + @dataclass class MergedSpikeAnalysis: spikeanalysis_list: list @@ -17,63 +18,57 @@ def __post_init__(self): if self.name_list is None: pass else: - assert len(self.spikeanalysis_list)==len(self.name_list), "each dataset needs a name value" + assert len(self.spikeanalysis_list) == len(self.name_list), "each dataset needs a name value" - def add_analysis(self, analysis: SpikeAnalysis|CuratedSpikeAnalysis, name: str|None): + def add_analysis(self, analysis: SpikeAnalysis | CuratedSpikeAnalysis, name: str | None): self.spikeanalysis_list.append(analysis) if self.name_list is not None: - assert 
len(self.spikeanalysis_list) == len(self.name_list)-1, 'must provide name if other datasets named' + assert len(self.spikeanalysis_list) == len(self.name_list) - 1, "must provide name if other datasets named" self.name_list.append(name) else: - print('other datasets were not given names ignoring naming') + print("other datasets were not given names ignoring naming") - def merge(self, stim_name: str|None): - + def merge(self, stim_name: str | None): # merge the cluster_ids for plotting cluster_ids = [] for idx, sa in enumerate(self.spikeanalysis_list): - if isinstance(self.name_list,list): - sub_cluster_ids = [str(self.name_list[idx])+str(cid) for cid in sa.cluster_ids] + if isinstance(self.name_list, list): + sub_cluster_ids = [str(self.name_list[idx]) + str(cid) for cid in sa.cluster_ids] else: - sub_cluster_ids = [str(idx)+ str(cid) for cid in sa.cluster_ids] + sub_cluster_ids = [str(idx) + str(cid) for cid in sa.cluster_ids] cluster_ids.append(sub_cluster_ids) final_cluster_ids = [cid for cid in cluster_ids] self.cluster_ids = final_cluster_ids - #merge the events for plotting + # merge the events for plotting events = {} if stim_name is not None: events[stim_name] = self.spikeanalysis_list[0].events[stim_name] else: events = self.spikeanalysis_list[0].events - for idx, sa in enumerate(self.spikeanalysis_list): - z_score_list = [] try: z_score_list.append(sa.z_scores) - + z_bins = sa.z_bins z_windows = sa.z_windows have_z_data = True except AttributeError: if self.name_list is not None: - print(f'no z score data for data set {self.name_list[idx]}') + print(f"no z score data for data set {self.name_list[idx]}") fr_list = [] try: fr_list.append(sa.fr__) - except AttributeError: if self.name_list is not None: - print(f'no raw firing rate data for data set {self.name_list[idx]}') + print(f"no raw firing rate data for data set {self.name_list[idx]}") have_raw_data = False - - if len(z_score_list) >= 2: z_scores = _merge(z_score_list, stim_name=stim_name) @@ -82,21 +77,17 @@ def merge(self, stim_name: str|None): self.z_bins = z_bins self.z_windows = z_windows else: - self.z_bins={} + self.z_bins = {} self.z_bins[stim_name] = z_bins[stim_name] - self.z_windows={} + self.z_windows = {} self.z_windows[stim_name] = z_windows[stim_name] if len(fr_list) >= 2: fr_scores = _merge(fr_list, stim_name=stim_name) - - def _merge(dataset_list, stim_name): - data_merge = {} if stim_name is not None: - for stim in dataset_list[0].keys(): data_merge[stim] = [] for dataset in dataset_list: @@ -107,14 +98,11 @@ def _merge(dataset_list, stim_name): data_merge[stim_name].append(dataset[stim_name]) for stim, array in data_merge.items(): - data_merge[stim] = np.concatenate(array, axis=0) return data_merge - def get_merged_data(self): - msa = MSA() msa.cluster_ids = self.cluster_ids try: @@ -123,23 +111,25 @@ def get_merged_data(self): msa.z_windows = self.z_windows except AttributeError: pass - - return msa - + return msa class MSA(SpikeAnalysis): - def get_raw_psth(self): raise NotImplementedError + def z_score_data(self): raise NotImplementedError + def latencies(self): raise NotImplementedError + def set_spike_data(self): - print('data is immutable') + print("data is immutable") + def set_stimulus_data(self): - print('data is immutable') + print("data is immutable") + def get_raw_firing_rate(self): - raise NotImplementedError \ No newline at end of file + raise NotImplementedError diff --git a/src/spikeanalysis/spike_data.py b/src/spikeanalysis/spike_data.py index 2586ff1..ce9e676 100644 --- 
a/src/spikeanalysis/spike_data.py +++ b/src/spikeanalysis/spike_data.py @@ -601,7 +601,7 @@ def set_qc(self): ) self._cids = self._cids[threshold] - print('qc metrics applied to cluster ids') + print("qc metrics applied to cluster ids") self.QC_RUN = True self._return_to_dir(current_dir) diff --git a/src/spikeanalysis/spike_plotter.py b/src/spikeanalysis/spike_plotter.py index 7bee189..a36fbbc 100644 --- a/src/spikeanalysis/spike_plotter.py +++ b/src/spikeanalysis/spike_plotter.py @@ -167,7 +167,7 @@ def plot_raw_firing( sorted_cluster_ids = self._plot_scores( data="raw-data", figsize=figsize, sorting_index=sorting_index, bar=bar, indices=indices, show_stim=show_stim ) - + self.cmap = None if indices: @@ -813,7 +813,7 @@ def plot_response_trace( response[neuron, trial, :], ebars=None, color=color, - stim=f'{stimulus}: {self.data.cluster_ids[neuron]}: {trial}', + stim=f"{stimulus}: {self.data.cluster_ids[neuron]}: {trial}", show_stim=show_stim, stim_lines=current_length, ) @@ -827,7 +827,7 @@ def plot_response_trace( avg_response, ebars=ebars, color=color, - stim=f'{stimulus}: neuron: {self.data.cluster_ids[neuron]}', + stim=f"{stimulus}: neuron: {self.data.cluster_ids[neuron]}", show_stim=show_stim, stim_lines=current_length, ) @@ -837,7 +837,7 @@ def plot_response_trace( avg_response, ebars=None, color=color, - stim=f'{stimulus}: neuron: {self.data.cluster_ids[neuron]}', + stim=f"{stimulus}: neuron: {self.data.cluster_ids[neuron]}", show_stim=show_stim, stim_lines=current_length, ) @@ -851,7 +851,7 @@ def plot_response_trace( avg_response, ebars=ebars, color=color, - stim=f'{stimulus} event number {trial}', + stim=f"{stimulus} event number {trial}", show_stim=show_stim, stim_lines=current_length, ) @@ -861,7 +861,7 @@ def plot_response_trace( avg_response, ebars=None, color=color, - stim=f'{stimulus} event number {trial}', + stim=f"{stimulus} event number {trial}", show_stim=show_stim, stim_lines=current_length, ) diff --git a/src/spikeanalysis/utils.py b/src/spikeanalysis/utils.py index aab5a1d..986ae0b 100644 --- a/src/spikeanalysis/utils.py +++ b/src/spikeanalysis/utils.py @@ -76,10 +76,17 @@ def gaussian_smoothing(array: np.array, bin_size: float, std: float) -> np.array return smoothed_array -def prevalence_counts(responsive_neurons: dict | str | "Path", stim: list[str] | None = None, trial_index: dict | None=None, all_trials: bool = False, exclusive_list: list | None = None, inclusive_list: list | None = None): - +def prevalence_counts( + responsive_neurons: dict | str | "Path", + stim: list[str] | None = None, + trial_index: dict | None = None, + all_trials: bool = False, + exclusive_list: list | None = None, + inclusive_list: list | None = None, +): # prep responsive neurons from file or from argument from pathlib import Path + if isinstance(responsive_neurons, (str, Path)): responsive_neurons_path = Path(responsive_neurons) assert responsive_neurons_path.is_file(), "responsive neuron json must exist" @@ -89,27 +96,29 @@ def prevalence_counts(responsive_neurons: dict | str | "Path", stim: list[str] | for response in responsive_neurons[stim]: responsive_neurons[stim][response] = np.array(responsive_neurons[stim][response], dtype=bool) else: - assert isinstance(responsive_neurons, dict), f'responsive_neurons must be path or dict it is {type(responsive_neurons)}' + assert isinstance( + responsive_neurons, dict + ), f"responsive_neurons must be path or dict it is {type(responsive_neurons)}" # prep other arguments if stim is None: stim = list(responsive_neurons.keys()) else: - assert 
isinstance(stim, list), 'stim must be a list of the desired keys' - + assert isinstance(stim, list), "stim must be a list of the desired keys" + if trial_index is None: trial_index = {} for st in stim: if all_trials: - trial_index[st] = 'all' + trial_index[st] = "all" else: - trial_index[st] = 'any' + trial_index[st] = "any" else: - assert isinstance(trial_index, dict), 'trial_index must be dict of which trials to use' + assert isinstance(trial_index, dict), "trial_index must be dict of which trials to use" if exclusive_list is None: exclusive_list = [] - + if inclusive_list is None: inclusive_list = [] @@ -123,16 +132,16 @@ def prevalence_counts(responsive_neurons: dict | str | "Path", stim: list[str] | response_labels = [] for rt_label, rt in response_types.items(): response_labels.append(rt_label) - if trial_indices == 'all': - response_list.append(mask = np.all(rt, axis=1)) - elif trial_indices == 'any': - response_list.append(mask = np.any(rt, axis=1)) + if trial_indices == "all": + response_list.append(mask=np.all(rt, axis=1)) + elif trial_indices == "any": + response_list.append(mask=np.any(rt, axis=1)) else: - if len(trial_indices)==2: + if len(trial_indices) == 2: start, end = trial_indices[0], trial_indices[1] - response_list.append(mask = np.all(rt[:, start:end])) + response_list.append(mask=np.all(rt[:, start:end])) else: - response_list.append(mask = np.all(rt[:, np.array(trial_indices)])) + response_list.append(mask=np.all(rt[:, np.array(trial_indices)])) final_responses = np.vstack(response_list) for response in exclusive_list: @@ -143,13 +152,9 @@ def prevalence_counts(responsive_neurons: dict | str | "Path", stim: list[str] | keep_list.append(response_labels.index(keep)) final_response_idx = np.array(keep_list.append(rt_idx)) final_responses[pos_neuron_idx, ~final_response_idx] = False - + prevalences = np.sum(final_responses, axis=0) - prevalence_dict[stim]['labels'] = response_labels - prevalence_dict[stim]['counts'] = prevalences + prevalence_dict[stim]["labels"] = response_labels + prevalence_dict[stim]["counts"] = prevalences return prevalence_dict - - - - \ No newline at end of file From 7d93447b297c0cc92e5ed6fb5974e56b922c0524 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Tue, 17 Oct 2023 08:32:22 -0400 Subject: [PATCH 05/25] add new class fn to init --- src/spikeanalysis/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/spikeanalysis/__init__.py b/src/spikeanalysis/__init__.py index 03d778e..b7a785d 100644 --- a/src/spikeanalysis/__init__.py +++ b/src/spikeanalysis/__init__.py @@ -5,8 +5,10 @@ from .intrinsic_plotter import IntrinsicPlotter from .analog_analysis import AnalogAnalysis from .curated_spike_analysis import CuratedSpikeAnalysis, read_responsive_neurons +from .merged_spike_analysis import MergedSpikeAnalysis, MSA from .stats_functions import kolmo_smir_stats from .plotting_functions import plot_piechart +from .utils import prevalence_counts import importlib.metadata From 5cb84be71368dc8bf4188e937fc0b1d08d6ec4a9 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Tue, 17 Oct 2023 08:32:28 -0400 Subject: [PATCH 06/25] docstrings --- src/spikeanalysis/merged_spike_analysis.py | 12 ++++++++- src/spikeanalysis/utils.py | 29 ++++++++++++++++++++++ 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/src/spikeanalysis/merged_spike_analysis.py b/src/spikeanalysis/merged_spike_analysis.py index 1cf832a..0d47ba5 100644 --- a/src/spikeanalysis/merged_spike_analysis.py 
+++ b/src/spikeanalysis/merged_spike_analysis.py @@ -1,6 +1,7 @@ from __future__ import annotations from dataclasses import dataclass +from typing import Optional, Union import numpy as np @@ -11,6 +12,7 @@ @dataclass class MergedSpikeAnalysis: + """class for merging neurons from separate animals for plotting""" spikeanalysis_list: list name_list: list | None @@ -85,7 +87,7 @@ def merge(self, stim_name: str | None): if len(fr_list) >= 2: fr_scores = _merge(fr_list, stim_name=stim_name) - def _merge(dataset_list, stim_name): + def _merge(dataset_list: list, stim_name: str): data_merge = {} if stim_name is not None: for stim in dataset_list[0].keys(): @@ -116,6 +118,8 @@ def get_merged_data(self): class MSA(SpikeAnalysis): + """class for plotting merged datasets, but not for analysis""" + def get_raw_psth(self): raise NotImplementedError @@ -133,3 +137,9 @@ def set_stimulus_data(self): def get_raw_firing_rate(self): raise NotImplementedError + + def trial_correlation(self): + raise NotImplementedError + + def get_interspike_intervals(self): + raise NotImplementedError diff --git a/src/spikeanalysis/utils.py b/src/spikeanalysis/utils.py index 986ae0b..d8d605d 100644 --- a/src/spikeanalysis/utils.py +++ b/src/spikeanalysis/utils.py @@ -84,6 +84,35 @@ def prevalence_counts( exclusive_list: list | None = None, inclusive_list: list | None = None, ): + """ + Function for counting number of neurons with specific response properties for each stimulus + + Parameters + ---------- + responsive_neurons: dict | str | Path + Either a dictionary to assess with format of {stim: {response:array[bool]}} or the + path given as str of Path to the json containing the same structure + stim: list[str] | None, default: None + If only wanting to analyze a single stim or group of stim give as a list + None means analyze all stim + trial_index: dict | None, default: None + A dict containing the {stim: indices | all | any} + * if indices can be given as array of [start, stop] or as the indices to use, e.g. 1, 4, 6 + * if all it will require all trial groups for a stim to be positive + * if any it will require at least one trial group of a stim to be positive + all_trials: bool, default False + Sets the trial_index to 'all' if true or 'any' if false. This is only used if trial_index is None + exclusive_list: list | None, default: None + The list of stimuli which are assessed in order.
If given a neuron can only be in one of the categories + inclusive_list: list | None, default: None + This allows a neuron to be this category even if exclusive_list is provided + + Returns + ------- + prevalence_dict: dict + Dict of prevalence counts with each key being a stim and the values being + a 'labels' of response types 'counts' the prevalence counts + """ # prep responsive neurons from file or from argument from pathlib import Path From 04fa7e806c436bc135125346ef5d0b3b9403b08f Mon Sep 17 00:00:00 2001 From: Zach McKenzie <92116279+zm711@users.noreply.github.com> Date: Wed, 18 Oct 2023 08:37:50 -0400 Subject: [PATCH 07/25] add check for multiple datasets --- src/spikeanalysis/merged_spike_analysis.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/spikeanalysis/merged_spike_analysis.py b/src/spikeanalysis/merged_spike_analysis.py index 0d47ba5..247dc76 100644 --- a/src/spikeanalysis/merged_spike_analysis.py +++ b/src/spikeanalysis/merged_spike_analysis.py @@ -32,6 +32,7 @@ def add_analysis(self, analysis: SpikeAnalysis | CuratedSpikeAnalysis, name: str def merge(self, stim_name: str | None): # merge the cluster_ids for plotting + assert len(self.spikeanalysis_list) >= 2, f"merge should only be run on multiple datasets you currently have {len(self.spikeanalysis_list} datasets" cluster_ids = [] for idx, sa in enumerate(self.spikeanalysis_list): if isinstance(self.name_list, list): From 9046b7c53e38497c724f8ad2a2550cd8b50b9829 Mon Sep 17 00:00:00 2001 From: Zach McKenzie <92116279+zm711@users.noreply.github.com> Date: Wed, 18 Oct 2023 10:22:54 -0400 Subject: [PATCH 08/25] fix check on adding analysis --- src/spikeanalysis/merged_spike_analysis.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/spikeanalysis/merged_spike_analysis.py b/src/spikeanalysis/merged_spike_analysis.py index 247dc76..ccf830c 100644 --- a/src/spikeanalysis/merged_spike_analysis.py +++ b/src/spikeanalysis/merged_spike_analysis.py @@ -23,12 +23,14 @@ def __post_init__(self): assert len(self.spikeanalysis_list) == len(self.name_list), "each dataset needs a name value" def add_analysis(self, analysis: SpikeAnalysis | CuratedSpikeAnalysis, name: str | None): - self.spikeanalysis_list.append(analysis) + if self.name_list is not None: - assert len(self.spikeanalysis_list) == len(self.name_list) - 1, "must provide name if other datasets named" + assert len(self.spikeanalysis_list) == len(self.name_list), "must provide name if other datasets named" + self.spikeanalysis_list.append(analysis) self.name_list.append(name) else: print("other datasets were not given names ignoring naming") + self.spikeanalysis_list.append(analysis) def merge(self, stim_name: str | None): # merge the cluster_ids for plotting From 220c76836d67eb92eac69eafa527b55e1707ea25 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Wed, 18 Oct 2023 10:57:11 -0400 Subject: [PATCH 09/25] add future to all functions just in case --- src/spikeanalysis/analysis_utils/histogram_functions.py | 1 + src/spikeanalysis/analysis_utils/latency_functions.py | 1 + src/spikeanalysis/plotbase.py | 1 + src/spikeanalysis/plotting_functions.py | 1 + src/spikeanalysis/spike_analysis.py | 2 ++ src/spikeanalysis/spike_data.py | 1 + src/spikeanalysis/stats_functions.py | 1 + src/spikeanalysis/utils.py | 1 + 8 files changed, 9 insertions(+) diff --git a/src/spikeanalysis/analysis_utils/histogram_functions.py b/src/spikeanalysis/analysis_utils/histogram_functions.py index c01714b..48e6e32 100644 ---
a/src/spikeanalysis/analysis_utils/histogram_functions.py +++ b/src/spikeanalysis/analysis_utils/histogram_functions.py @@ -1,3 +1,4 @@ +from __future__ import annotations import numpy as np from numba import jit import numba diff --git a/src/spikeanalysis/analysis_utils/latency_functions.py b/src/spikeanalysis/analysis_utils/latency_functions.py index 8d525e2..d0bbf2c 100644 --- a/src/spikeanalysis/analysis_utils/latency_functions.py +++ b/src/spikeanalysis/analysis_utils/latency_functions.py @@ -1,3 +1,4 @@ +from __future__ import annotations import numpy as np from numba import jit import math diff --git a/src/spikeanalysis/plotbase.py b/src/spikeanalysis/plotbase.py index 9c51ec3..e3557af 100644 --- a/src/spikeanalysis/plotbase.py +++ b/src/spikeanalysis/plotbase.py @@ -1,3 +1,4 @@ +from __future__ import annotations from typing import Optional import matplotlib.pyplot as plt diff --git a/src/spikeanalysis/plotting_functions.py b/src/spikeanalysis/plotting_functions.py index 553c652..0f20136 100644 --- a/src/spikeanalysis/plotting_functions.py +++ b/src/spikeanalysis/plotting_functions.py @@ -1,3 +1,4 @@ +from __future__ import annotations from typing import Optional, Sequence import matplotlib.pyplot as plt import numpy as np diff --git a/src/spikeanalysis/spike_analysis.py b/src/spikeanalysis/spike_analysis.py index a781f38..c767626 100644 --- a/src/spikeanalysis/spike_analysis.py +++ b/src/spikeanalysis/spike_analysis.py @@ -1,3 +1,4 @@ +from __future__ import annotations from typing import Union, Optional import numpy as np @@ -881,6 +882,7 @@ def get_responsive_neurons(self, z_parameters: Optional[dict] = None): self.responsive_neurons[stim][key] = responsive_neurons + def save_responsive_neurons(self): import json diff --git a/src/spikeanalysis/spike_data.py b/src/spikeanalysis/spike_data.py index ce9e676..316a4dd 100644 --- a/src/spikeanalysis/spike_data.py +++ b/src/spikeanalysis/spike_data.py @@ -1,3 +1,4 @@ +from __future__ import annotations from pathlib import Path from typing import Union, Optional import os diff --git a/src/spikeanalysis/stats_functions.py b/src/spikeanalysis/stats_functions.py index e81de73..48241f2 100644 --- a/src/spikeanalysis/stats_functions.py +++ b/src/spikeanalysis/stats_functions.py @@ -1,3 +1,4 @@ +from __future__ import annotations import numpy as np from scipy.stats import ks_2samp from typing import Union diff --git a/src/spikeanalysis/utils.py b/src/spikeanalysis/utils.py index d8d605d..5373de5 100644 --- a/src/spikeanalysis/utils.py +++ b/src/spikeanalysis/utils.py @@ -1,3 +1,4 @@ +from __future__ import annotations import json from typing import Union import numpy as np From 5cb1ca83ee411d6e5e6097dfc6cbd766564ac6aa Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Wed, 18 Oct 2023 10:58:54 -0400 Subject: [PATCH 10/25] fix neo deprecations --- src/spikeanalysis/stimulus_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeanalysis/stimulus_data.py b/src/spikeanalysis/stimulus_data.py index acd16f1..f87f6ad 100644 --- a/src/spikeanalysis/stimulus_data.py +++ b/src/spikeanalysis/stimulus_data.py @@ -147,7 +147,7 @@ def create_neo_reader(self, file_name: Optional[str] = None): if file_name is None: reader = neo.rawio.IntanRawIO(filename=self._filename) else: - neo_class = neo.rawio.get_rawio_class(file_name) + neo_class = neo.rawio.get_rawio(file_name) return neo_class reader.parse_header() From 82c1edf451dae4b7b6e45deef3c604c795e9a101 Mon Sep 17 00:00:00 2001 From: zm711 
<92116279+zm711@users.noreply.github.com> Date: Wed, 18 Oct 2023 10:59:43 -0400 Subject: [PATCH 11/25] linting of files --- src/spikeanalysis/spike_analysis.py | 1 - src/spikeanalysis/utils.py | 8 ++++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/spikeanalysis/spike_analysis.py b/src/spikeanalysis/spike_analysis.py index c767626..081dde9 100644 --- a/src/spikeanalysis/spike_analysis.py +++ b/src/spikeanalysis/spike_analysis.py @@ -882,7 +882,6 @@ def get_responsive_neurons(self, z_parameters: Optional[dict] = None): self.responsive_neurons[stim][key] = responsive_neurons - def save_responsive_neurons(self): import json diff --git a/src/spikeanalysis/utils.py b/src/spikeanalysis/utils.py index 5373de5..658d4ae 100644 --- a/src/spikeanalysis/utils.py +++ b/src/spikeanalysis/utils.py @@ -87,11 +87,11 @@ def prevalence_counts( ): """ Function for counting number of neurons with specific response properties for each stimulus - + Parameters ---------- responsive_neurons: dict | str | Path - Either a dictionary to assess with format of {stim: {response:array[bool]}} or the + Either a dictionary to assess with format of {stim: {response:array[bool]}} or the path given as str of Path to the json containing the same structure stim: list[str] | None, default: None @@ -107,13 +107,13 @@ def prevalence_counts( The list of stimuli which are assessed in order. If given a neuron can only be in one of the categories inclusive_list: list | None, default: None This allows a neuron to be this category even if exclusive_list is provided - + Returns ------- prevalence_dict: dict Dict of prevalence counts with each key being a stim and the values being a 'labels' of response types 'counts' the prevalence counts - """ + """ # prep responsive neurons from file or from argument from pathlib import Path From dc7d8df9346e4c0d5ac683c5f90ae1d9d2ee5254 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Wed, 18 Oct 2023 11:00:02 -0400 Subject: [PATCH 12/25] wip for merging --- src/spikeanalysis/merged_spike_analysis.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/spikeanalysis/merged_spike_analysis.py b/src/spikeanalysis/merged_spike_analysis.py index ccf830c..f263232 100644 --- a/src/spikeanalysis/merged_spike_analysis.py +++ b/src/spikeanalysis/merged_spike_analysis.py @@ -13,6 +13,7 @@ @dataclass class MergedSpikeAnalysis: """class for merging neurons from separate animals for plotting""" + spikeanalysis_list: list name_list: list | None @@ -23,7 +23,6 @@ def __post_init__(self): assert len(self.spikeanalysis_list) == len(self.name_list), "each dataset needs a name value" def add_analysis(self, analysis: SpikeAnalysis | CuratedSpikeAnalysis, name: str | None): - if self.name_list is not None: assert len(self.spikeanalysis_list) == len(self.name_list), "must provide name if other datasets named" self.spikeanalysis_list.append(analysis) @@ -32,9 +32,11 @@ def add_analysis(self, analysis: SpikeAnalysis | CuratedSpikeAnalysis, name: str print("other datasets were not given names ignoring naming") self.spikeanalysis_list.append(analysis) - def merge(self, stim_name: str | None): + def merge(self, stim_name: str | None = None): # merge the cluster_ids for plotting assert ( len(self.spikeanalysis_list)
>= 2 + ), f"merge should only be run on multiple datasets you currently have {len(self.spikeanalysis_list)} datasets" cluster_ids = [] for idx, sa in enumerate(self.spikeanalysis_list): if isinstance(self.name_list, list): @@ -42,7 +44,7 @@ def merge(self, stim_name: str | None): else: sub_cluster_ids = [str(idx) + str(cid) for cid in sa.cluster_ids] cluster_ids.append(sub_cluster_ids) - final_cluster_ids = [cid for cid in cluster_ids] + final_cluster_ids = [cid for sub_cid in cluster_ids for cid in sub_cid] self.cluster_ids = final_cluster_ids @@ -122,7 +124,7 @@ def get_merged_data(self): class MSA(SpikeAnalysis): """class for plotting merged datasets, but not for analysis""" - + def get_raw_psth(self): raise NotImplementedError @@ -140,9 +142,9 @@ def set_stimulus_data(self): def get_raw_firing_rate(self): raise NotImplementedError - + def trial_correlation(self): raise NotImplementedError - + def get_interspike_intervals(self): raise NotImplementedError From 39e272a3fd1f85ef7df175aef902f5a849066623 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Wed, 18 Oct 2023 11:00:12 -0400 Subject: [PATCH 13/25] adding testing wip --- test/test_merged_spike_analysis.py | 67 ++++++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100644 test/test_merged_spike_analysis.py diff --git a/test/test_merged_spike_analysis.py b/test/test_merged_spike_analysis.py new file mode 100644 index 0000000..b27c77e --- /dev/null +++ b/test/test_merged_spike_analysis.py @@ -0,0 +1,67 @@ +import pytest +import numpy as np +from pathlib import Path + + +from spikeanalysis.merged_spike_analysis import MergedSpikeAnalysis, MSA +from spikeanalysis.stimulus_data import StimulusData +from spikeanalysis.spike_data import SpikeData +from spikeanalysis.spike_analysis import SpikeAnalysis + + +@pytest.fixture(scope="module") +def sa(): + directory = Path(__file__).parent.resolve() / "test_data" + stimulus = StimulusData(file_path=directory) + stimulus.create_neo_reader() + stimulus.get_analog_data() + stimulus.digitize_analog_data() + spikes = SpikeData(file_path=directory) + + spiketrain = SpikeAnalysis() + spiketrain.set_stimulus_data(stimulus) + spiketrain.set_spike_data(spikes) + return spiketrain + + +def test_init(sa): + test_msa = MergedSpikeAnalysis([sa, sa], name_list=None) + + assert isinstance(test_msa, MergedSpikeAnalysis) + + +def test_init_names(sa): + test_msa = MergedSpikeAnalysis(spikeanalysis_list=[sa, sa], name_list=["test", "test1"]) + + assert isinstance(test_msa, MergedSpikeAnalysis) + + +def test_init_failure(sa): + with pytest.raises(AssertionError): + test_msa = MergedSpikeAnalysis( + spikeanalysis_list=[sa, sa], + name_list=[ + "test", + ], + ) + + +def test_merge(sa): + test_msa = MergedSpikeAnalysis([sa, sa], name_list=["test", "test1"]) + test_msa.merge() + + assert isinstance(test_msa.cluster_ids, list) + print(test_msa.cluster_ids) + assert len(test_msa.cluster_ids) == 4 + + +def test_return_msa(sa): + test_msa = MergedSpikeAnalysis([sa, sa], name_list=["test", "test1"]) + test_msa.merge() + test_merged_sa = test_msa.get_merged_data() + + assert isinstance(test_merged_sa, MSA) + assert isinstance(test_merged_sa, SpikeAnalysis) + + with pytest.raises(NotImplementedError): + test_merged_sa.get_raw_psth() From 7de44c319ccdd0929e0e9cabacdfbc66ebd26b7c Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Wed, 18 Oct 2023 16:01:09 -0400 Subject: [PATCH 14/25] fixes based on testing --- src/spikeanalysis/utils.py | 34 
+++++++++++++++--------------- test/test_merged_spike_analysis.py | 17 +++++++++++++++ 2 files changed, 34 insertions(+), 17 deletions(-) diff --git a/src/spikeanalysis/utils.py b/src/spikeanalysis/utils.py index 658d4ae..4b7a00e 100644 --- a/src/spikeanalysis/utils.py +++ b/src/spikeanalysis/utils.py @@ -84,7 +84,7 @@ def prevalence_counts( all_trials: bool = False, exclusive_list: list | None = None, inclusive_list: list | None = None, -): +) -> dict: """ Function for counting number of neurons with specific response properties for each stimulus @@ -122,9 +122,9 @@ def prevalence_counts( assert responsive_neurons_path.is_file(), "responsive neuron json must exist" with open(responsive_neurons_path, "r") as read_file: responsive_neurons = json.load(read_file) - for stim in responsive_neurons.keys(): - for response in responsive_neurons[stim]: - responsive_neurons[stim][response] = np.array(responsive_neurons[stim][response], dtype=bool) + for stimulus in responsive_neurons.keys(): + for response in responsive_neurons[stimulus]: + responsive_neurons[stimulus][response] = np.array(responsive_neurons[stimulus][response], dtype=bool) else: assert isinstance( responsive_neurons, dict @@ -155,7 +155,7 @@ def prevalence_counts( # count final data prevalence_dict = {} for st in stim: - prevalence_dict[stim] = {} + prevalence_dict[st] = {} response_types = responsive_neurons[st] trial_indices = trial_index[st] response_list = [] @@ -163,28 +163,28 @@ def prevalence_counts( for rt_label, rt in response_types.items(): response_labels.append(rt_label) if trial_indices == "all": - response_list.append(mask=np.all(rt, axis=1)) + response_list.append(np.all(rt, axis=1)) elif trial_indices == "any": - response_list.append(mask=np.any(rt, axis=1)) + response_list.append(np.any(rt, axis=1)) else: if len(trial_indices) == 2: start, end = trial_indices[0], trial_indices[1] - response_list.append(mask=np.all(rt[:, start:end])) + response_list.append(np.all(rt[:, start:end], axis=1)) else: - response_list.append(mask=np.all(rt[:, np.array(trial_indices)])) - + response_list.append(np.all(rt[:, np.array(trial_indices)], axis=1)) final_responses = np.vstack(response_list) for response in exclusive_list: rt_idx = response_labels.index(response) - pos_neuron_idx = np.nonzero(final_responses[rt_idx])[0] - keep_list = [] + pos_neuron_idx = np.array(np.nonzero(final_responses[rt_idx])[0]) + keep_list = [rt_idx] for keep in inclusive_list: keep_list.append(response_labels.index(keep)) - final_response_idx = np.array(keep_list.append(rt_idx)) - final_responses[pos_neuron_idx, ~final_response_idx] = False + final_response_idx = np.array(keep_list) + if len(final_response_idx) < np.shape(final_responses)[0] and len(pos_neuron_idx) > 0: + final_responses[~final_response_idx, pos_neuron_idx] = False - prevalences = np.sum(final_responses, axis=0) - prevalence_dict[stim]["labels"] = response_labels - prevalence_dict[stim]["counts"] = prevalences + prevalences = np.sum(final_responses, axis=1) + prevalence_dict[st]["labels"] = response_labels + prevalence_dict[st]["counts"] = prevalences return prevalence_dict diff --git a/test/test_merged_spike_analysis.py b/test/test_merged_spike_analysis.py index b27c77e..08bd231 100644 --- a/test/test_merged_spike_analysis.py +++ b/test/test_merged_spike_analysis.py @@ -45,6 +45,23 @@ def test_init_failure(sa): ], ) +def test_add_analysis(sa): + test_msa = MergedSpikeAnalysis(spikeanalysis_list=[sa, sa], name_list=["test", "test1"]) + test_msa.add_analysis(sa, 'test2') + + assert 
len(test_msa.spikeanalysis_list)==3 + + test_msa.add_analysis([sa, sa], ['test3', 'test4']) + + assert len(test_msa.spikeanalysis_list)==5 + assert 'test4' in test_msa.name_list + + test_msa_no_name = MergedSpikeAnalysis([sa, sa], name_list = None) + test_msa_no_name.add_analysis(sa, name=None) + + assert len(test_msa_no_name.spikeanalysis_list) == 3 + test_msa_no_name.add_analysis([sa, sa], name=None) + assert len(test_msa_no_name.spikeanalysis_list) == 5 def test_merge(sa): test_msa = MergedSpikeAnalysis([sa, sa], name_list=["test", "test1"]) From 38f63eead467923a6bd837dbab20860b8d41ff2f Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Wed, 18 Oct 2023 16:01:41 -0400 Subject: [PATCH 15/25] more testing of msa --- test/test_merged_spike_analysis.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/test/test_merged_spike_analysis.py b/test/test_merged_spike_analysis.py index 08bd231..99bea7e 100644 --- a/test/test_merged_spike_analysis.py +++ b/test/test_merged_spike_analysis.py @@ -45,24 +45,26 @@ def test_init_failure(sa): ], ) + def test_add_analysis(sa): test_msa = MergedSpikeAnalysis(spikeanalysis_list=[sa, sa], name_list=["test", "test1"]) - test_msa.add_analysis(sa, 'test2') + test_msa.add_analysis(sa, "test2") - assert len(test_msa.spikeanalysis_list)==3 + assert len(test_msa.spikeanalysis_list) == 3 - test_msa.add_analysis([sa, sa], ['test3', 'test4']) + test_msa.add_analysis([sa, sa], ["test3", "test4"]) - assert len(test_msa.spikeanalysis_list)==5 - assert 'test4' in test_msa.name_list + assert len(test_msa.spikeanalysis_list) == 5 + assert "test4" in test_msa.name_list - test_msa_no_name = MergedSpikeAnalysis([sa, sa], name_list = None) + test_msa_no_name = MergedSpikeAnalysis([sa, sa], name_list=None) test_msa_no_name.add_analysis(sa, name=None) assert len(test_msa_no_name.spikeanalysis_list) == 3 test_msa_no_name.add_analysis([sa, sa], name=None) assert len(test_msa_no_name.spikeanalysis_list) == 5 + def test_merge(sa): test_msa = MergedSpikeAnalysis([sa, sa], name_list=["test", "test1"]) test_msa.merge() From cf5c58d357fcd0e1bc426c36e23fe6e0d0f31e8b Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Wed, 18 Oct 2023 16:01:51 -0400 Subject: [PATCH 16/25] allow for loading of lists --- src/spikeanalysis/merged_spike_analysis.py | 24 ++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/src/spikeanalysis/merged_spike_analysis.py b/src/spikeanalysis/merged_spike_analysis.py index f263232..255ccce 100644 --- a/src/spikeanalysis/merged_spike_analysis.py +++ b/src/spikeanalysis/merged_spike_analysis.py @@ -23,14 +23,28 @@ def __post_init__(self): else: assert len(self.spikeanalysis_list) == len(self.name_list), "each dataset needs a name value" - def add_analysis(self, analysis: SpikeAnalysis | CuratedSpikeAnalysis, name: str | None): + def add_analysis( + self, + analysis: SpikeAnalysis | CuratedSpikeAnalysis | list[SpikeAnalysis, CuratedSpikeAnalysis], + name: str | None | list[str], + ): if self.name_list is not None: assert len(self.spikeanalysis_list) == len(self.name_list), "must provide name if other datasets named" - self.spikeanalysis_list.append(analysis) - self.name_list.append(name) + if isinstance(analysis, list): + assert isinstance(name, list), "if analysis is a list of analysis then name must be a list of names" + for idx, sa in enumerate(analysis): + self.spikeanalysis_list.append(sa) + self.name_list.append(name[idx]) + else: + 
self.spikeanalysis_list.append(analysis) + self.name_list.append(name) else: print("other datasets were not given names ignoring naming") - self.spikeanalysis_list.append(analysis) + if isinstance(analysis, list): + for sa in analysis: + self.spikeanalysis_list.append(sa) + else: + self.spikeanalysis_list.append(analysis) def merge(self, stim_name: str | None = None): # merge the cluster_ids for plotting @@ -55,6 +69,8 @@ def merge(self, stim_name: str | None = None): else: events = self.spikeanalysis_list[0].events + self.events = events + for idx, sa in enumerate(self.spikeanalysis_list): z_score_list = [] try: From 3c9ea4b43bc29684b39155981c8c855b9eedfdcc Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Wed, 18 Oct 2023 16:02:02 -0400 Subject: [PATCH 17/25] add test for prevalence counts --- test/test_utils.py | 90 +++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 89 insertions(+), 1 deletion(-) diff --git a/test/test_utils.py b/test/test_utils.py index fb19951..5ccd9a7 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -1,4 +1,10 @@ -from spikeanalysis.utils import verify_window_format, gaussian_smoothing, jsonify_parameters, NumpyEncoder +from spikeanalysis.utils import ( + verify_window_format, + gaussian_smoothing, + jsonify_parameters, + NumpyEncoder, + prevalence_counts, +) import json import os import pytest @@ -93,3 +99,85 @@ def test_updata_jsonify_parameters(tmp_path): for key, value in zip(["Test", "Test2"], [[1, 2, 3], [4, 5, 6]]): assert key in final_params.keys() assert value in final_params.values() + + +def test_prevalence_values(tmp_path): + resp_dict = { + "stim1": { + "sus": np.array([[True, True, True], [True, False, True], [False, False, False]]), + "inh": np.array( + [ + [False, False, False], + [ + True, + True, + True, + ], + [False, False, True], + ] + ), + }, + "stim2": { + "sus": np.array([[True, True, False], [True, False, True], [False, True, False]]), + "inh": np.array( + [ + [False, False, False], + [ + False, + False, + False, + ], + [False, False, False], + ] + ), + }, + } + + prev_dict = prevalence_counts(resp_dict) + + for key in ["stim1", "stim2"]: + assert key in prev_dict.keys(), f"{key} should be in prev_dict" + + for label in ["sus", "inh"]: + assert label in prev_dict["stim1"]["labels"], f"{label} should be a label" + print(prev_dict) + assert prev_dict["stim1"]["counts"][0] == 2 + + prev_dict = prevalence_counts(resp_dict, all_trials=True) + print(prev_dict) + assert prev_dict["stim1"]["counts"][0] == 1 + + prev_dict = prevalence_counts(resp_dict, stim=["stim2"]) + print(prev_dict) + assert "counts" in prev_dict["stim2"].keys() + try: + fail = prev_dict["stim1"] + assert False, "should have only selected stim2" + except KeyError: + pass + print(prev_dict) + prev_dict = prevalence_counts(resp_dict, trial_index={"stim1": [1], "stim2": [1]}) + print(prev_dict) + assert prev_dict["stim1"]["counts"][0] == 1 + prev_dict_slice = prevalence_counts(resp_dict, trial_index={"stim1": [1, 2], "stim2": [1, 2]}) + assert prev_dict["stim1"]["counts"][0] == prev_dict_slice["stim1"]["counts"][0] + print(prev_dict) + prev_dict = prevalence_counts(resp_dict, exclusive_list=["sus"]) + print(prev_dict) + assert prev_dict["stim1"]["counts"][1] == 1 + assert prev_dict["stim1"]["counts"][0] == 2 + + prev_dict = prevalence_counts(resp_dict, exclusive_list=["sus"], inclusive_list=["inh"]) + assert prev_dict["stim1"]["counts"][1] == 2 + assert prev_dict["stim1"]["counts"][0] == 2 + + dir = tmp_path / "test" +
dir.mkdir() + file = dir / "responsive_neurons.json" + + with open(file, "w") as write_file: + json.dump(resp_dict, write_file, cls=NumpyEncoder) + + prev_dict_json = prevalence_counts(file, exclusive_list=["sus"], inclusive_list=["inh"]) + + assert prev_dict_json["stim1"]["counts"][0] == prev_dict["stim1"]["counts"][0] From 3f0643a51afc84f3238b8d0860582f42da5cc5ff Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Thu, 19 Oct 2023 08:36:38 -0400 Subject: [PATCH 18/25] only load the psths --- src/spikeanalysis/merged_spike_analysis.py | 62 ++++------------------ 1 file changed, 11 insertions(+), 51 deletions(-) diff --git a/src/spikeanalysis/merged_spike_analysis.py b/src/spikeanalysis/merged_spike_analysis.py index 255ccce..e414013 100644 --- a/src/spikeanalysis/merged_spike_analysis.py +++ b/src/spikeanalysis/merged_spike_analysis.py @@ -51,6 +51,8 @@ def merge(self, stim_name: str | None = None): assert ( len(self.spikeanalysis_list) >= 2 ), f"merge should only be run on multiple datasets you currently have {len(self.spikeanalysis_list)} datasets" + + assert isinstance(self.spikeanalysis_list[0].psths, dict), "must have psth to merge" cluster_ids = [] for idx, sa in enumerate(self.spikeanalysis_list): if isinstance(self.name_list, list): @@ -71,42 +73,13 @@ def merge(self, stim_name: str | None = None): self.events = events + psths_list = [] for idx, sa in enumerate(self.spikeanalysis_list): - z_score_list = [] - try: - z_score_list.append(sa.z_scores) - - z_bins = sa.z_bins - z_windows = sa.z_windows - have_z_data = True - - except AttributeError: - if self.name_list is not None: - print(f"no z score data for data set {self.name_list[idx]}") - - fr_list = [] - try: - fr_list.append(sa.fr__) - - except AttributeError: - if self.name_list is not None: - print(f"no raw firing rate data for data set {self.name_list[idx]}") - have_raw_data = False - - if len(z_score_list) >= 2: - z_scores = _merge(z_score_list, stim_name=stim_name) - self.z_scores = z_scores - if stim_name is None: - self.z_bins = z_bins - self.z_windows = z_windows - else: - self.z_bins = {} - self.z_bins[stim_name] = z_bins[stim_name] - self.z_windows = {} - self.z_windows[stim_name] = z_windows[stim_name] + psths_list.append(sa.psths) + + data_merge = _merge(psths_list, stim_name) - if len(fr_list) >= 2: - fr_scores = _merge(fr_list, stim_name=stim_name) + self.data = data_merge def _merge(dataset_list: list, stim_name: str): data_merge = {} @@ -128,12 +101,8 @@ def _merge(dataset_list: list, stim_name: str): def get_merged_data(self): msa = MSA() msa.cluster_ids = self.cluster_ids - try: - msa.z_scores = self.z_scores - msa.z_bins = self.z_bins - msa.z_windows = self.z_windows - except AttributeError: - pass + msa.events = self.events + msa.psths = self.data return msa @@ -144,23 +113,14 @@ class MSA(SpikeAnalysis): def get_raw_psth(self): raise NotImplementedError - def z_score_data(self): - raise NotImplementedError - - def latencies(self): - raise NotImplementedError - def set_spike_data(self): print("data is immutable") def set_stimulus_data(self): print("data is immutable") - def get_raw_firing_rate(self): - raise NotImplementedError - - def trial_correlation(self): + def get_interspike_intervals(self): raise NotImplementedError - def get_interspike_intervals(self): + def compute_event_interspike_intervals(self, time_ms: float = 200): raise NotImplementedError From 8b96eb54a847a4efddb11ef53f040dae00ebca17 Mon Sep 17 00:00:00 2001 From: zm711 
<92116279+zm711@users.noreply.github.com> Date: Thu, 19 Oct 2023 11:53:57 -0400 Subject: [PATCH 19/25] add testing for merging data sets todo other values --- src/spikeanalysis/merged_spike_analysis.py | 159 ++++++++++++++++++--- test/test_merged_spike_analysis.py | 61 ++++++-- 2 files changed, 192 insertions(+), 28 deletions(-) diff --git a/src/spikeanalysis/merged_spike_analysis.py b/src/spikeanalysis/merged_spike_analysis.py index e414013..651c485 100644 --- a/src/spikeanalysis/merged_spike_analysis.py +++ b/src/spikeanalysis/merged_spike_analysis.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Optional, Union +from typing import Optional, Union, Literal import numpy as np @@ -10,6 +10,9 @@ from .curated_spike_analysis import CuratedSpikeAnalysis +_merge_psth_values = ("zscore", "fr", "latencies", "isi", True) + + @dataclass class MergedSpikeAnalysis: """class for merging neurons from separate animals for plotting""" @@ -46,13 +49,25 @@ def add_analysis( else: self.spikeanalysis_list.append(analysis) - def merge(self, stim_name: str | None = None): + def merge( + self, psth: bool | list[Literal["zscore", "fr", "latencies", "isi"]] = True, stim_name: str | None = None + ): # merge the cluster_ids for plotting assert ( len(self.spikeanalysis_list) >= 2 ), f"merge should only be run on multiple datasets you currently have {len(self.spikeanalysis_list)} datasets" assert isinstance(self.spikeanalysis_list[0].psths, dict), "must have psth to merge" + + if not isinstance(psth, bool): + for category in psth: + assert category in ( + "zscore", + "fr", + "latencies", + "isi", + ), f"the only values you can use for psth are {_merge_psth_values}" + cluster_ids = [] for idx, sa in enumerate(self.spikeanalysis_list): if isinstance(self.name_list, list): @@ -73,17 +88,70 @@ def merge(self, stim_name: str | None = None): self.events = events - psths_list = [] - for idx, sa in enumerate(self.spikeanalysis_list): - psths_list.append(sa.psths) - - data_merge = _merge(psths_list, stim_name) - - self.data = data_merge - - def _merge(dataset_list: list, stim_name: str): + if psth == True: + psths_list = [] + for idx, sa in enumerate(self.spikeanalysis_list): + psths = sa.psths + merge_psth_dict = {} + psth_bins = {} + for sub_stim, psth_values in psths.items(): + merge_psth = psth_values["psth"] + bins = psth_values["bins"] + psth_bins[sub_stim] = bins + merge_psth_dict[sub_stim] = merge_psth + psths_list.append(merge_psth_dict) + + data_merge = self._merge(psths_list, stim_name) + + for key in data_merge.keys(): + if key in psth_bins.keys(): + final_psth = data_merge[key] + data_merge[key] = {} + data_merge[key]["bins"] = psth_bins[key] + data_merge[key]["psth"] = final_psth + + self.data = data_merge + self.use_psth = True + else: + self.use_psth = False + z_list = [] + fr_list = [] + lat_list = [] + isi_list = [] + for idx, sa in enumerate(self.spikeanalysis_list): + if "zscore" in psth: + z_list.append(sa.z_scores) + z_bins = sa.z_bins + z_windows = sa.z_windows + if "fr" in psth: + fr_list.append(sa.mean_firing_rate) + fr_bins = sa.fr_bins + + if "latencies" in psth: + raise NotImplementedError + if "isi" in psth: + raise NotImplementedError + + if len(z_list) != 0: + z_scores = self._merge(z_list, stim_name=stim_name) + self.z_scores = z_scores + self.z_bins = z_bins + self.z_windows = z_windows + + if len(fr_list) != 0: + final_fr = self._merge(fr_list, stim_name=stim_name) + self.mean_firing_rate = final_fr + self.fr_bins = fr_bins + + if 
len(lat_list) != 0: + raise NotImplementedError + + if len(isi_list) != 0: + raise NotImplementedError + + def _merge(self, dataset_list: list, stim_name: str): data_merge = {} - if stim_name is not None: + if stim_name is None: for stim in dataset_list[0].keys(): data_merge[stim] = [] for dataset in dataset_list: @@ -100,9 +168,34 @@ def _merge(dataset_list: list, stim_name: str): def get_merged_data(self): msa = MSA() - msa.cluster_ids = self.cluster_ids - msa.events = self.events - msa.psths = self.data + msa.set_cluster_ids(self.cluster_ids) + msa.set_events(self.events) + + if self.use_psth: + msa.psths = self.data + else: + try: + msa.z_scores = self.z_scores + msa.z_bins = self.z_bins + msa.z_windows = self.z_windows + except AttributeError: + pass + try: + msa.mean_firing_rate = self.mean_firing_rate + msa.fr_bins = self.fr_bins + except AttributeError: + pass + try: + msa.latency = self.latency + except AttributeError: + pass + try: + msa.isi = self.isi + msa.isi_values = self.isi_values + except AttributeError: + pass + + msa.use_psth = self.use_psth return msa @@ -110,6 +203,16 @@ def get_merged_data(self): class MSA(SpikeAnalysis): """class for plotting merged datasets, but not for analysis""" + def __init__(self): + self.use_psth = False + super().__init__() + + def set_cluster_ids(self, cluster_ids): + self.cluster_ids = cluster_ids + + def set_events(self, events): + self.events = events + def get_raw_psth(self): raise NotImplementedError @@ -119,8 +222,32 @@ def set_spike_data(self): def set_stimulus_data(self): print("data is immutable") + def z_score_data( + self, time_bin_ms: list[float] | float, bsl_window: list | list[list], z_window: list | list[list] + ): + if self.use_psth: + return super().z_score_data(time_bin_ms, bsl_window, z_window) + else: + raise NotImplementedError + + def get_raw_firing_rate( + self, + time_bin_ms: list[float] | float, + fr_window: list | list[list], + mode: str, + bsl_window: list | list[list] | None = None, + sm_time_ms: list[float] | float | None = None, + ): + if self.use_psth: + return super().get_raw_firing_rate(time_bin_ms, fr_window, mode, bsl_window, sm_time_ms) + else: + raise NotImplementedError + def get_interspike_intervals(self): raise NotImplementedError - def compute_event_interspike_intervals(self, time_ms: float = 200): + def compute_event_interspike_intervals(self): + raise NotImplementedError + + def autocorrelogram(self): raise NotImplementedError diff --git a/test/test_merged_spike_analysis.py b/test/test_merged_spike_analysis.py index 99bea7e..7132abc 100644 --- a/test/test_merged_spike_analysis.py +++ b/test/test_merged_spike_analysis.py @@ -65,22 +65,59 @@ def test_add_analysis(sa): assert len(test_msa_no_name.spikeanalysis_list) == 5 -def test_merge(sa): +def test_merge_psth(sa): + sa.events = { + "0": {"events": np.array([100]), "lengths": np.array([200]), "trial_groups": np.array([1]), "stim": "test"} + } + sa.get_raw_psth( + window=[0, 300], + time_bin_ms=50, + ) + test_msa = MergedSpikeAnalysis([sa, sa], name_list=["test", "test1"]) - test_msa.merge() - assert isinstance(test_msa.cluster_ids, list) - print(test_msa.cluster_ids) - assert len(test_msa.cluster_ids) == 4 + test_msa.merge(psth=True) + test_merged_msa = test_msa.get_merged_data() + assert isinstance(test_merged_msa.cluster_ids, list) + print(test_merged_msa.cluster_ids) + assert len(test_merged_msa.cluster_ids) == 4 -def test_return_msa(sa): - test_msa = MergedSpikeAnalysis([sa, sa], name_list=["test", "test1"]) - test_msa.merge() - test_merged_sa = 
test_msa.get_merged_data()
+    assert isinstance(test_merged_msa.events, dict)
 
-    assert isinstance(test_merged_sa, MSA)
-    assert isinstance(test_merged_sa, SpikeAnalysis)
+    psth = test_merged_msa.psths["test"]["psth"]
+    assert np.shape(psth) == (4, 1, 6000)
+    assert isinstance(test_merged_msa, SpikeAnalysis)
+    assert isinstance(test_merged_msa, MSA)
+
+    with pytest.raises(NotImplementedError):
+        test_merged_msa.get_raw_psth()
 
     with pytest.raises(NotImplementedError):
-        test_merged_sa.get_raw_psth()
+        test_merged_msa.get_interspike_intervals()
+    with pytest.raises(NotImplementedError):
+        test_merged_msa.autocorrelogram()
+
+
+def test_merge_z_score(sa):
+    sa.events = {
+        "0": {"events": np.array([100]), "lengths": np.array([200]), "trial_groups": np.array([1]), "stim": "test"}
+    }
+    sa.get_raw_psth(
+        window=[0, 300],
+        time_bin_ms=50,
+    )
+    sa.z_score_data(time_bin_ms=1000, bsl_window=[0, 50], z_window=[0, 300])
+
+    test_msa = MergedSpikeAnalysis([sa, sa], name_list=["test", "test1"])
+
+    with pytest.raises(AssertionError):
+        test_msa.merge(psth=["zscoresa"])
+
+    test_msa.merge(psth=["zscore"])
+    test_merged_msa = test_msa.get_merged_data()
+
+    assert isinstance(test_merged_msa.z_scores, dict)
+
+    test_merged_msa.set_stimulus_data()
+    test_merged_msa.set_spike_data()

From 692f65bd791c877d85300076620cb80463cf1546 Mon Sep 17 00:00:00 2001
From: zm711 <92116279+zm711@users.noreply.github.com>
Date: Thu, 19 Oct 2023 11:58:33 -0400
Subject: [PATCH 20/25] bump version, add WIP docs

---
 docs/source/API.rst                              | 5 +++++
 docs/source/conf.py                              | 2 +-
 docs/source/submodules/index.rst                 | 1 +
 docs/source/submodules/merged_spike_analysis.rst | 4 ++++
 pyproject.toml                                   | 2 +-
 test/test_merged_spike_analysis.py               | 2 +-
 6 files changed, 13 insertions(+), 3 deletions(-)
 create mode 100644 docs/source/submodules/merged_spike_analysis.rst

diff --git a/docs/source/API.rst b/docs/source/API.rst
index 6483eda..d1e24d8 100644
--- a/docs/source/API.rst
+++ b/docs/source/API.rst
@@ -24,4 +24,9 @@ spikeanalysis
 .. autoclass:: AnalogAnalysis
    :members:
 
+.. autoclass:: MergedSpikeAnalysis
+   :members:
+
 .. autofunction:: kolmo_smir_stats
+
+.. autofunction:: prevalence_counts
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 1eee2d4..43b172e 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -24,7 +24,7 @@
 author = "Zach McKenzie"
 
 # The full version, including alpha/beta/rc tags
-release = "0.0.11"
+release = "0.1.0"
 
 # -- General configuration ---------------------------------------------------
diff --git a/docs/source/submodules/index.rst b/docs/source/submodules/index.rst
index c45a175..f1df498 100644
--- a/docs/source/submodules/index.rst
+++ b/docs/source/submodules/index.rst
@@ -13,3 +13,4 @@ Submodules
    analog_analysis
    curated_spike_analysis
    functions
+   merged_spike_analysis
diff --git a/docs/source/submodules/merged_spike_analysis.rst b/docs/source/submodules/merged_spike_analysis.rst
new file mode 100644
index 0000000..255e9a5
--- /dev/null
+++ b/docs/source/submodules/merged_spike_analysis.rst
@@ -0,0 +1,4 @@
+MergedSpikeAnalysis
+===================
+
+Module for merging datasets
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index 9dafdb4..b98f80d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "spikeanalysis"
-version = '0.0.11'
+version = '0.1.0'
 authors = [{name="Zach McKenzie", email="mineurs-torrent0x@icloud.com"}]
 description = "Analysis of Spike Trains"
 requires-python = ">=3.9"
diff --git a/test/test_merged_spike_analysis.py b/test/test_merged_spike_analysis.py
index 7132abc..8c2befc 100644
--- a/test/test_merged_spike_analysis.py
+++ b/test/test_merged_spike_analysis.py
@@ -118,6 +118,6 @@ def test_merge_z_score(sa):
     test_merged_msa = test_msa.get_merged_data()
 
     assert isinstance(test_merged_msa.z_scores, dict)
-    
+
     test_merged_msa.set_stimulus_data()
     test_merged_msa.set_spike_data()

From 25606f23e1486c27361b6af666ce46d15197ef3e Mon Sep 17 00:00:00 2001
From: zm711 <92116279+zm711@users.noreply.github.com>
Date: Thu, 19 Oct 2023 14:30:38 -0400
Subject: [PATCH 21/25] update docs

---
 .../submodules/curated_spike_analysis.rst     | 14 ++++++
 .../submodules/merged_spike_analysis.rst      | 69 +++++++++++++++++++++-
 docs/source/submodules/spike_data.rst         | 16 ++++++
 docs/source/submodules/stimulus_data.rst      |  3 +-
 4 files changed, 99 insertions(+), 3 deletions(-)

diff --git a/docs/source/submodules/curated_spike_analysis.rst b/docs/source/submodules/curated_spike_analysis.rst
index 6293be7..71081c8 100644
--- a/docs/source/submodules/curated_spike_analysis.rst
+++ b/docs/source/submodules/curated_spike_analysis.rst
@@ -61,3 +61,17 @@ To revert back to the original full set of neurons use :code:`revert_curation()`
 
     curated_st.revert_curation()
 
+Plotting the Data
+-----------------
+
+Since :code:`CuratedSpikeAnalysis` inherits from :code:`SpikeAnalysis`, it can be used with
+the :code:`SpikePlotter` class with no additional work.
+
+.. code-block:: python
+
+    plotter = sa.SpikePlotter()
+
+    plotter.set_analysis(curated_st)
+    plotter.plot_zscores()
+
+
diff --git a/docs/source/submodules/merged_spike_analysis.rst b/docs/source/submodules/merged_spike_analysis.rst
index 255e9a5..7352cd5 100644
--- a/docs/source/submodules/merged_spike_analysis.rst
+++ b/docs/source/submodules/merged_spike_analysis.rst
@@ -1,4 +1,71 @@
 MergedSpikeAnalysis
 ===================
 
-Module for merging datasets
\ No newline at end of file
+Module for merging datasets. Once data has been curated, it may be beneficial to look at a series
+of animals together. To facilitate this, the :code:`MergedSpikeAnalysis` object can be used. It is
+constructed in a similar fashion to the other classes:
+
+.. code-block:: python
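+
+    # note: each object below should already have its events set and, for
+    # psth/zscore merging, get_raw_psth()/z_score_data() already run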
+
+    import spikeanalysis as sa
+
+    # we start with two SpikeAnalysis or CuratedSpikeAnalysis objects,
+    # st1 and st2
+    merged_data = sa.MergedSpikeAnalysis(spikeanalysis_list=[st1, st2], name_list=['animal1', 'animal2'])
+
+    # if we need to add another animal, st3, we can use
+    merged_data.add_analysis(analysis=st3, name='animal3')
+
+    # or we can use lists
+    merged_data.add_analysis(analysis=[st3,st4], name=['animal3', 'animal4'])
+
+Once the datasets to be merged are ready, one can use the :code:`merge()` function. This takes
+the argument :code:`psth`, which can either be set to :code:`True`, meaning the balanced raw
+:code:`psths` values are merged, or be a list of the derived values to merge, e.g. :code:`['zscore']`
+or :code:`['fr']`.
+
+.. code-block:: python
+
+    # will attempt to merge the psths of each dataset
+    merged_data.merge(psth=True)
+
+    # will attempt to merge z scores
+    merged_data.merge(psth=['zscore'])
+
+Note that the datasets to be merged must be balanced. For example, a dataset with 5 neurons,
+10 trials, and 200 timepoints can only be merged with another dataset of :code:`x` neurons, 10
+trials, and 200 timepoints. The concatenation occurs along the neuron axis (:code:`axis 0`),
+so everything else must have the same dimensionality. A schematic sketch of this shape rule
+is given at the end of this page.
+
+Finally, the merged dataset can be returned for use in the :code:`SpikePlotter` class.
+
+.. code-block:: python
+
+    msa = merged_data.get_merged_data()
+    plotter = sa.SpikePlotter()
+    plotter.set_analysis(msa)
+
+This works because the :code:`MSA` returned is a :code:`SpikeAnalysis` object with specific
+guardrails around methods that can no longer be used. For example, if the data was merged with
+:code:`psth=True`, then z scores can be regenerated across the data with a different :code:`time_bin_ms`,
+but if :code:`psth=['zscore']` was used then new z scores cannot be generated and the :code:`MSA` will
+raise a :code:`NotImplementedError`.
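+
+As a rough illustration only (this sketch is not part of the library), the merge step is
+conceptually a concatenation along the neuron axis, which is why every other dimension
+must already agree:
+
+.. code-block:: python
+
+    import numpy as np
+
+    # hypothetical z-score arrays shaped (neurons, trial groups, time bins)
+    dataset1 = np.zeros((5, 10, 200))
+    dataset2 = np.zeros((8, 10, 200))
+
+    # merging stacks neurons; trial groups and time bins must match
+    merged = np.concatenate([dataset1, dataset2], axis=0)  # shape (13, 10, 200)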
diff --git a/docs/source/submodules/spike_data.rst b/docs/source/submodules/spike_data.rst
index dfae651..9fefcf9 100644
--- a/docs/source/submodules/spike_data.rst
+++ b/docs/source/submodules/spike_data.rst
@@ -190,6 +190,20 @@ as understandable and maintainable a weighted average is used (faster, slightly
 
     \frac{1}{\Sigma amplitudes} \sum{amplitudes * depths}
 
+Waveform amplitudes
+^^^^^^^^^^^^^^^^^^^
+
+Since a given neuron's spikes should have roughly constant amplitude, the variation in amplitudes
+can serve as a quality metric for that neuron. We expect a roughly Gaussian distribution of
+amplitudes, so the :code:`get_amplitudes()` function assesses how many spikes fall within a
+given number of standard deviations of the mean waveform amplitude.
+
+.. code-block:: python
+
+    spikes.get_amplitudes(std=2) # 2 std devs
+
+
+
 Pipeline Function
 -----------------
@@ -204,6 +218,8 @@ parameters to be provided. Example below will all values included.
     idthres=20, # isolation distance 20--need an empiric number from your data
     rpv=0.02, # 2% the amount of spikes violating the 2ms refractory period allowed
     sil=0.45, # silhouette score (-1,1) with values above 0 indicates better and better clustering
+    amp_std=2, # number of std deviations around the mean waveform amplitude to look at
+    amp_cutoff=0.98, # fraction of spikes which must fall within amp_std deviations of the mean waveform
     recurated= False, # I haven't recurated my data
     set_caching = True, # I want to save data for future use
     depth= 500, # probe inserted 500 um deep
diff --git a/docs/source/submodules/stimulus_data.rst b/docs/source/submodules/stimulus_data.rst
index 83ee7a7..59dbaa5 100644
--- a/docs/source/submodules/stimulus_data.rst
+++ b/docs/source/submodules/stimulus_data.rst
@@ -96,7 +96,7 @@ they can be returned using :code:`get_stimulus_channels`. Finally stimulus' shou
 
     stim_dict = stim.get_stimulus_channels()
     stim.set_trial_groups(trial_dictionary=trial_dictionary) # dict as explained above
-    sitm.set_stimulus_names(stim_names = name_dictionary) # same keys with str values
+    stim.set_stimulus_name(stim_names = name_dictionary) # same keys with str values
 
 
 Train-based data
@@ -143,7 +143,6 @@ generated stimulus data. To load it simply requires:
 
     stim.get_all_files()
 
-
 Convenience Pipeline
 --------------------

From 72dc802545a57cdc9057a6be6ce23037a52382d9 Mon Sep 17 00:00:00 2001
From: zm711 <92116279+zm711@users.noreply.github.com>
Date: Thu, 19 Oct 2023 14:30:50 -0400
Subject: [PATCH 22/25] add save warning

---
 src/spikeanalysis/stimulus_data.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/src/spikeanalysis/stimulus_data.py b/src/spikeanalysis/stimulus_data.py
index f87f6ad..2cb994f 100644
--- a/src/spikeanalysis/stimulus_data.py
+++ b/src/spikeanalysis/stimulus_data.py
@@ -456,8 +456,13 @@ def save_events(self):
         os.chdir(self._file_path)
 
         try:
-            _ = self.digital_events
-
+            digital_events = self.digital_events
+            for dig_channel, event_type in digital_events.items():
+                assert (
+                    "stim" in event_type.keys()
+                ), f"Must provide a name for each stim using the set_stimulus_name() function. Please do this for {dig_channel}"
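+            # by this point every digital channel carries a "stim" label, so
+            # the JSON cache below is never written with unnamed events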
             with open("digital_events.json", "w") as write_file:
                 json.dump(self.digital_events, write_file, cls=NumpyEncoder)
         except AttributeError:

From 7b88a14f24c466c15013cd8aa66c51cc25b80341 Mon Sep 17 00:00:00 2001
From: zm711 <92116279+zm711@users.noreply.github.com>
Date: Thu, 19 Oct 2023 14:30:56 -0400
Subject: [PATCH 23/25] add docstrings

---
 src/spikeanalysis/merged_spike_analysis.py | 41 ++++++++++++++++++-----
 1 file changed, 34 insertions(+), 7 deletions(-)

diff --git a/src/spikeanalysis/merged_spike_analysis.py b/src/spikeanalysis/merged_spike_analysis.py
index 651c485..2551ada 100644
--- a/src/spikeanalysis/merged_spike_analysis.py
+++ b/src/spikeanalysis/merged_spike_analysis.py
@@ -28,9 +28,18 @@ def __post_init__(self):
 
     def add_analysis(
         self,
-        analysis: SpikeAnalysis | CuratedSpikeAnalysis | list[SpikeAnalysis, CuratedSpikeAnalysis],
+        analysis: SpikeAnalysis | CuratedSpikeAnalysis | list[SpikeAnalysis | CuratedSpikeAnalysis],
         name: str | None | list[str],
     ):
+        """Function for adding an additional SpikeAnalysis to be merged
+
+        Parameters
+        ----------
+        analysis: SpikeAnalysis | CuratedSpikeAnalysis | list[SpikeAnalysis | CuratedSpikeAnalysis]
+            The analysis or analyses to add
+        name: str | None | list[str]
+            The name(s) of the animal(s)/dataset(s) being added
+        """
         if self.name_list is not None:
             assert len(self.spikeanalysis_list) == len(self.name_list), "must provide name if other datasets named"
         if isinstance(analysis, list):
@@ -52,6 +61,14 @@ def add_analysis(
     def merge(
         self, psth: bool | list[Literal["zscore", "fr", "latencies", "isi"]] = True, stim_name: str | None = None
     ):
+        """Function for merging different types of datasets between SpikeAnalysis objects
+
+        Parameters
+        ----------
+        psth: bool | list['zscore'|'fr'|'latencies'|'isi'], default: True
+            Whether to merge at the raw psth level (True) or only the listed derived values
+        stim_name: str | None, default: None
+            If given, only the named stimulus is merged; otherwise all stimuli are merged"""
         # merge the cluster_ids for plotting
         assert (
             len(self.spikeanalysis_list) >= 2
@@ -149,7 +166,8 @@ def merge(
         if len(isi_list) != 0:
             raise NotImplementedError
 
-    def _merge(self, dataset_list: list, stim_name: str):
+    def _merge(self, dataset_list: list, stim_name: str) -> dict:
+        """Internal function for merging datasets"""
         data_merge = {}
         if stim_name is None:
@@ -167,9 +185,16 @@ def _merge(self, dataset_list: list, stim_name: str):
         return data_merge
 
     def get_merged_data(self):
+        """Function for returning an
+        instantiation of an MSA class for plotting
+
+        Returns
+        -------
+        msa: spikeanalysis.MSA
+            merged spike analysis data for plotting and limited analysis"""
         msa = MSA()
-        msa.set_cluster_ids(self.cluster_ids)
-        msa.set_events(self.events)
+        msa._set_cluster_ids(self.cluster_ids)
+        msa._set_events(self.events)
@@ -201,16 +226,18 @@ def get_merged_data(self):
 
 class MSA(SpikeAnalysis):
-    """class for plotting merged datasets, but not for analysis"""
+    """Class for plotting and some analysis of merged data"""
 
     def __init__(self):
         self.use_psth = False
         super().__init__()
 
-    def set_cluster_ids(self, cluster_ids):
+    def _set_cluster_ids(self, cluster_ids: np.array | list):
+        """Used for setting the new cluster id values"""
         self.cluster_ids = cluster_ids
 
-    def set_events(self, events):
+    def _set_events(self, events: dict):
+        """Method for setting the new events"""
         self.events = events
 
     def get_raw_psth(self):

From 73a96f503ece77c0d83fdb53472812d8c3e15dc9 Mon Sep 17 00:00:00 
2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Thu, 19 Oct 2023 14:35:17 -0400 Subject: [PATCH 24/25] fix init --- src/spikeanalysis/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeanalysis/__init__.py b/src/spikeanalysis/__init__.py index b7a785d..1d8d8b9 100644 --- a/src/spikeanalysis/__init__.py +++ b/src/spikeanalysis/__init__.py @@ -5,7 +5,7 @@ from .intrinsic_plotter import IntrinsicPlotter from .analog_analysis import AnalogAnalysis from .curated_spike_analysis import CuratedSpikeAnalysis, read_responsive_neurons -from .merged_spike_analysis import MergedSpikeAnalysis, MSA +from .merged_spike_analysis import MergedSpikeAnalysis from .stats_functions import kolmo_smir_stats from .plotting_functions import plot_piechart from .utils import prevalence_counts From dc574445ad4e2aeee6e940d07de3b67ea0cc07ec Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Thu, 19 Oct 2023 14:57:13 -0400 Subject: [PATCH 25/25] add testing updates --- test/test_data/channel_map.npy | Bin 0 -> 144 bytes test/test_merged_spike_analysis.py | 18 ++++++++++++++++++ test/test_spike_data.py | 6 ++++++ 3 files changed, 24 insertions(+) create mode 100644 test/test_data/channel_map.npy diff --git a/test/test_data/channel_map.npy b/test/test_data/channel_map.npy new file mode 100644 index 0000000000000000000000000000000000000000..e5567c737aaf6e0cbd90c097bcc0675183fcfcaf GIT binary patch literal 144 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1ZlWC%^qoAIaUsO_*m=~X4l#&V(cT3DEP6dh= fXCxM+0{I#yI+{8PwF(pfE(RcA1Y%|&W&vUV3X2