diff --git a/src/spikeanalysis/merged_spike_analysis.py b/src/spikeanalysis/merged_spike_analysis.py index e414013..651c485 100644 --- a/src/spikeanalysis/merged_spike_analysis.py +++ b/src/spikeanalysis/merged_spike_analysis.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Optional, Union +from typing import Optional, Union, Literal import numpy as np @@ -10,6 +10,9 @@ from .curated_spike_analysis import CuratedSpikeAnalysis +_merge_psth_values = ("zscore", "fr", "latencies", "isi", True) + + @dataclass class MergedSpikeAnalysis: """class for merging neurons from separate animals for plotting""" @@ -46,13 +49,25 @@ def add_analysis( else: self.spikeanalysis_list.append(analysis) - def merge(self, stim_name: str | None = None): + def merge( + self, psth: bool | list[Literal["zscore", "fr", "latencies", "isi"]] = True, stim_name: str | None = None + ): # merge the cluster_ids for plotting assert ( len(self.spikeanalysis_list) >= 2 ), f"merge should only be run on multiple datasets you currently have {len(self.spikeanalysis_list)} datasets" assert isinstance(self.spikeanalysis_list[0].psths, dict), "must have psth to merge" + + if not isinstance(psth, bool): + for category in psth: + assert category in ( + "zscore", + "fr", + "latencies", + "isi", + ), f"the only values you can use for psth are {_merge_psth_values}" + cluster_ids = [] for idx, sa in enumerate(self.spikeanalysis_list): if isinstance(self.name_list, list): @@ -73,17 +88,70 @@ def merge(self, stim_name: str | None = None): self.events = events - psths_list = [] - for idx, sa in enumerate(self.spikeanalysis_list): - psths_list.append(sa.psths) - - data_merge = _merge(psths_list, stim_name) - - self.data = data_merge - - def _merge(dataset_list: list, stim_name: str): + if psth == True: + psths_list = [] + for idx, sa in enumerate(self.spikeanalysis_list): + psths = sa.psths + merge_psth_dict = {} + psth_bins = {} + for sub_stim, psth_values in psths.items(): + merge_psth = psth_values["psth"] + bins = psth_values["bins"] + psth_bins[sub_stim] = bins + merge_psth_dict[sub_stim] = merge_psth + psths_list.append(merge_psth_dict) + + data_merge = self._merge(psths_list, stim_name) + + for key in data_merge.keys(): + if key in psth_bins.keys(): + final_psth = data_merge[key] + data_merge[key] = {} + data_merge[key]["bins"] = psth_bins[key] + data_merge[key]["psth"] = final_psth + + self.data = data_merge + self.use_psth = True + else: + self.use_psth = False + z_list = [] + fr_list = [] + lat_list = [] + isi_list = [] + for idx, sa in enumerate(self.spikeanalysis_list): + if "zscore" in psth: + z_list.append(sa.z_scores) + z_bins = sa.z_bins + z_windows = sa.z_windows + if "fr" in psth: + fr_list.append(sa.mean_firing_rate) + fr_bins = sa.fr_bins + + if "latencies" in psth: + raise NotImplementedError + if "isi" in psth: + raise NotImplementedError + + if len(z_list) != 0: + z_scores = self._merge(z_list, stim_name=stim_name) + self.z_scores = z_scores + self.z_bins = z_bins + self.z_windows = z_windows + + if len(fr_list) != 0: + final_fr = self._merge(fr_list, stim_name=stim_name) + self.mean_firing_rate = final_fr + self.fr_bins = fr_bins + + if len(lat_list) != 0: + raise NotImplementedError + + if len(isi_list) != 0: + raise NotImplementedError + + def _merge(self, dataset_list: list, stim_name: str): data_merge = {} - if stim_name is not None: + if stim_name is None: for stim in dataset_list[0].keys(): data_merge[stim] = [] for dataset in dataset_list: @@ -100,9 +168,34 @@ def _merge(dataset_list: list, stim_name: str): def get_merged_data(self): msa = MSA() - msa.cluster_ids = self.cluster_ids - msa.events = self.events - msa.psths = self.data + msa.set_cluster_ids(self.cluster_ids) + msa.set_events(self.events) + + if self.use_psth: + msa.psths = self.data + else: + try: + msa.z_scores = self.z_scores + msa.z_bins = self.z_bins + msa.z_windows = self.z_windows + except AttributeError: + pass + try: + msa.mean_firing_rate = self.mean_firing_rate + msa.fr_bins = self.fr_bins + except AttributeError: + pass + try: + msa.latency = self.latency + except AttributeError: + pass + try: + msa.isi = self.isi + msa.isi_values = self.isi_values + except AttributeError: + pass + + msa.use_psth = self.use_psth return msa @@ -110,6 +203,16 @@ def get_merged_data(self): class MSA(SpikeAnalysis): """class for plotting merged datasets, but not for analysis""" + def __init__(self): + self.use_psth = False + super().__init__() + + def set_cluster_ids(self, cluster_ids): + self.cluster_ids = cluster_ids + + def set_events(self, events): + self.events = events + def get_raw_psth(self): raise NotImplementedError @@ -119,8 +222,32 @@ def set_spike_data(self): def set_stimulus_data(self): print("data is immutable") + def z_score_data( + self, time_bin_ms: list[float] | float, bsl_window: list | list[list], z_window: list | list[list] + ): + if self.use_psth: + return super().z_score_data(time_bin_ms, bsl_window, z_window) + else: + raise NotImplementedError + + def get_raw_firing_rate( + self, + time_bin_ms: list[float] | float, + fr_window: list | list[list], + mode: str, + bsl_window: list | list[list] | None = None, + sm_time_ms: list[float] | float | None = None, + ): + if self.use_psth: + return super().get_raw_firing_rate(time_bin_ms, fr_window, mode, bsl_window, sm_time_ms) + else: + raise NotImplementedError + def get_interspike_intervals(self): raise NotImplementedError - def compute_event_interspike_intervals(self, time_ms: float = 200): + def compute_event_interspike_intervals(self): + raise NotImplementedError + + def autocorrelogram(self): raise NotImplementedError diff --git a/test/test_merged_spike_analysis.py b/test/test_merged_spike_analysis.py index 99bea7e..7132abc 100644 --- a/test/test_merged_spike_analysis.py +++ b/test/test_merged_spike_analysis.py @@ -65,22 +65,59 @@ def test_add_analysis(sa): assert len(test_msa_no_name.spikeanalysis_list) == 5 -def test_merge(sa): +def test_merge_psth(sa): + sa.events = { + "0": {"events": np.array([100]), "lengths": np.array([200]), "trial_groups": np.array([1]), "stim": "test"} + } + sa.get_raw_psth( + window=[0, 300], + time_bin_ms=50, + ) + test_msa = MergedSpikeAnalysis([sa, sa], name_list=["test", "test1"]) - test_msa.merge() - assert isinstance(test_msa.cluster_ids, list) - print(test_msa.cluster_ids) - assert len(test_msa.cluster_ids) == 4 + test_msa.merge(psth=True) + test_merged_msa = test_msa.get_merged_data() + assert isinstance(test_merged_msa.cluster_ids, list) + print(test_merged_msa.cluster_ids) + assert len(test_merged_msa.cluster_ids) == 4 -def test_return_msa(sa): - test_msa = MergedSpikeAnalysis([sa, sa], name_list=["test", "test1"]) - test_msa.merge() - test_merged_sa = test_msa.get_merged_data() + assert isinstance(test_merged_msa.events, dict) - assert isinstance(test_merged_sa, MSA) - assert isinstance(test_merged_sa, SpikeAnalysis) + psth = test_merged_msa.psths["test"]["psth"] + assert np.shape(psth) == (4, 1, 6000) + assert isinstance(test_merged_msa, SpikeAnalysis) + assert isinstance(test_merged_msa, MSA) + + with pytest.raises(NotImplementedError): + test_merged_msa.get_raw_psth() with pytest.raises(NotImplementedError): - test_merged_sa.get_raw_psth() + test_merged_msa.get_interspike_intervals() + with pytest.raises(NotImplementedError): + test_merged_msa.autocorrelogram() + + +def test_merge_z_score(sa): + sa.events = { + "0": {"events": np.array([100]), "lengths": np.array([200]), "trial_groups": np.array([1]), "stim": "test"} + } + sa.get_raw_psth( + window=[0, 300], + time_bin_ms=50, + ) + sa.z_score_data(time_bin_ms=1000, bsl_window=[0, 50], z_window=[0, 300]) + + test_msa = MergedSpikeAnalysis([sa, sa], name_list=["test", "test1"]) + + with pytest.raises(AssertionError): + test_msa.merge(psth=["zscoresa"]) + + test_msa.merge(psth=["zscore"]) + test_merged_msa = test_msa.get_merged_data() + + assert isinstance(test_merged_msa.z_scores, dict) + + test_merged_msa.set_stimulus_data() + test_merged_msa.set_spike_data()