Skip to content

Commit

Permalink
add testing for merging data sets todo other values
Browse files Browse the repository at this point in the history
  • Loading branch information
zm711 committed Oct 19, 2023
1 parent 3f0643a commit 8b96eb5
Show file tree
Hide file tree
Showing 2 changed files with 192 additions and 28 deletions.
159 changes: 143 additions & 16 deletions src/spikeanalysis/merged_spike_analysis.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional, Union
from typing import Optional, Union, Literal


import numpy as np
Expand All @@ -10,6 +10,9 @@
from .curated_spike_analysis import CuratedSpikeAnalysis


_merge_psth_values = ("zscore", "fr", "latencies", "isi", True)


@dataclass
class MergedSpikeAnalysis:
"""class for merging neurons from separate animals for plotting"""
Expand Down Expand Up @@ -46,13 +49,25 @@ def add_analysis(
else:
self.spikeanalysis_list.append(analysis)

def merge(self, stim_name: str | None = None):
def merge(
self, psth: bool | list[Literal["zscore", "fr", "latencies", "isi"]] = True, stim_name: str | None = None
):
# merge the cluster_ids for plotting
assert (
len(self.spikeanalysis_list) >= 2
), f"merge should only be run on multiple datasets you currently have {len(self.spikeanalysis_list)} datasets"

assert isinstance(self.spikeanalysis_list[0].psths, dict), "must have psth to merge"

if not isinstance(psth, bool):
for category in psth:
assert category in (
"zscore",
"fr",
"latencies",
"isi",
), f"the only values you can use for psth are {_merge_psth_values}"

cluster_ids = []
for idx, sa in enumerate(self.spikeanalysis_list):
if isinstance(self.name_list, list):
Expand All @@ -73,17 +88,70 @@ def merge(self, stim_name: str | None = None):

self.events = events

psths_list = []
for idx, sa in enumerate(self.spikeanalysis_list):
psths_list.append(sa.psths)

data_merge = _merge(psths_list, stim_name)

self.data = data_merge

def _merge(dataset_list: list, stim_name: str):
if psth == True:
psths_list = []
for idx, sa in enumerate(self.spikeanalysis_list):
psths = sa.psths
merge_psth_dict = {}
psth_bins = {}
for sub_stim, psth_values in psths.items():
merge_psth = psth_values["psth"]
bins = psth_values["bins"]
psth_bins[sub_stim] = bins
merge_psth_dict[sub_stim] = merge_psth
psths_list.append(merge_psth_dict)

data_merge = self._merge(psths_list, stim_name)

for key in data_merge.keys():
if key in psth_bins.keys():
final_psth = data_merge[key]
data_merge[key] = {}
data_merge[key]["bins"] = psth_bins[key]
data_merge[key]["psth"] = final_psth

self.data = data_merge
self.use_psth = True
else:
self.use_psth = False
z_list = []
fr_list = []
lat_list = []
isi_list = []
for idx, sa in enumerate(self.spikeanalysis_list):
if "zscore" in psth:
z_list.append(sa.z_scores)
z_bins = sa.z_bins
z_windows = sa.z_windows
if "fr" in psth:
fr_list.append(sa.mean_firing_rate)
fr_bins = sa.fr_bins

if "latencies" in psth:
raise NotImplementedError
if "isi" in psth:
raise NotImplementedError

if len(z_list) != 0:
z_scores = self._merge(z_list, stim_name=stim_name)
self.z_scores = z_scores
self.z_bins = z_bins
self.z_windows = z_windows

if len(fr_list) != 0:
final_fr = self._merge(fr_list, stim_name=stim_name)
self.mean_firing_rate = final_fr
self.fr_bins = fr_bins

if len(lat_list) != 0:
raise NotImplementedError

if len(isi_list) != 0:
raise NotImplementedError

def _merge(self, dataset_list: list, stim_name: str):
data_merge = {}
if stim_name is not None:
if stim_name is None:
for stim in dataset_list[0].keys():
data_merge[stim] = []
for dataset in dataset_list:
Expand All @@ -100,16 +168,51 @@ def _merge(dataset_list: list, stim_name: str):

def get_merged_data(self):
msa = MSA()
msa.cluster_ids = self.cluster_ids
msa.events = self.events
msa.psths = self.data
msa.set_cluster_ids(self.cluster_ids)
msa.set_events(self.events)

if self.use_psth:
msa.psths = self.data
else:
try:
msa.z_scores = self.z_scores
msa.z_bins = self.z_bins
msa.z_windows = self.z_windows
except AttributeError:
pass
try:
msa.mean_firing_rate = self.mean_firing_rate
msa.fr_bins = self.fr_bins
except AttributeError:
pass
try:
msa.latency = self.latency
except AttributeError:
pass
try:
msa.isi = self.isi
msa.isi_values = self.isi_values
except AttributeError:
pass

msa.use_psth = self.use_psth

return msa


class MSA(SpikeAnalysis):
"""class for plotting merged datasets, but not for analysis"""

def __init__(self):
self.use_psth = False
super().__init__()

def set_cluster_ids(self, cluster_ids):
self.cluster_ids = cluster_ids

def set_events(self, events):
self.events = events

def get_raw_psth(self):
raise NotImplementedError

Expand All @@ -119,8 +222,32 @@ def set_spike_data(self):
def set_stimulus_data(self):
print("data is immutable")

def z_score_data(
self, time_bin_ms: list[float] | float, bsl_window: list | list[list], z_window: list | list[list]
):
if self.use_psth:
return super().z_score_data(time_bin_ms, bsl_window, z_window)
else:
raise NotImplementedError

def get_raw_firing_rate(
self,
time_bin_ms: list[float] | float,
fr_window: list | list[list],
mode: str,
bsl_window: list | list[list] | None = None,
sm_time_ms: list[float] | float | None = None,
):
if self.use_psth:
return super().get_raw_firing_rate(time_bin_ms, fr_window, mode, bsl_window, sm_time_ms)
else:
raise NotImplementedError

def get_interspike_intervals(self):
raise NotImplementedError

def compute_event_interspike_intervals(self, time_ms: float = 200):
def compute_event_interspike_intervals(self):
raise NotImplementedError

def autocorrelogram(self):
raise NotImplementedError
61 changes: 49 additions & 12 deletions test/test_merged_spike_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,22 +65,59 @@ def test_add_analysis(sa):
assert len(test_msa_no_name.spikeanalysis_list) == 5


def test_merge(sa):
def test_merge_psth(sa):
sa.events = {
"0": {"events": np.array([100]), "lengths": np.array([200]), "trial_groups": np.array([1]), "stim": "test"}
}
sa.get_raw_psth(
window=[0, 300],
time_bin_ms=50,
)

test_msa = MergedSpikeAnalysis([sa, sa], name_list=["test", "test1"])
test_msa.merge()

assert isinstance(test_msa.cluster_ids, list)
print(test_msa.cluster_ids)
assert len(test_msa.cluster_ids) == 4
test_msa.merge(psth=True)
test_merged_msa = test_msa.get_merged_data()

assert isinstance(test_merged_msa.cluster_ids, list)
print(test_merged_msa.cluster_ids)
assert len(test_merged_msa.cluster_ids) == 4

def test_return_msa(sa):
test_msa = MergedSpikeAnalysis([sa, sa], name_list=["test", "test1"])
test_msa.merge()
test_merged_sa = test_msa.get_merged_data()
assert isinstance(test_merged_msa.events, dict)

assert isinstance(test_merged_sa, MSA)
assert isinstance(test_merged_sa, SpikeAnalysis)
psth = test_merged_msa.psths["test"]["psth"]
assert np.shape(psth) == (4, 1, 6000)

assert isinstance(test_merged_msa, SpikeAnalysis)
assert isinstance(test_merged_msa, MSA)

with pytest.raises(NotImplementedError):
test_merged_msa.get_raw_psth()
with pytest.raises(NotImplementedError):
test_merged_sa.get_raw_psth()
test_merged_msa.get_interspike_intervals()
with pytest.raises(NotImplementedError):
test_merged_msa.autocorrelogram()


def test_merge_z_score(sa):
sa.events = {
"0": {"events": np.array([100]), "lengths": np.array([200]), "trial_groups": np.array([1]), "stim": "test"}
}
sa.get_raw_psth(
window=[0, 300],
time_bin_ms=50,
)
sa.z_score_data(time_bin_ms=1000, bsl_window=[0, 50], z_window=[0, 300])

test_msa = MergedSpikeAnalysis([sa, sa], name_list=["test", "test1"])

with pytest.raises(AssertionError):
test_msa.merge(psth=["zscoresa"])

test_msa.merge(psth=["zscore"])
test_merged_msa = test_msa.get_merged_data()

assert isinstance(test_merged_msa.z_scores, dict)

test_merged_msa.set_stimulus_data()
test_merged_msa.set_spike_data()

0 comments on commit 8b96eb5

Please sign in to comment.