
Commit

Merge branch 'main' into add-filter-auc
zm711 authored Dec 11, 2024
2 parents 538d4cb + 0497e55 commit 6e72d88
Showing 8 changed files with 162 additions and 13 deletions.
4 changes: 2 additions & 2 deletions src/spikeanalysis/analysis_utils/histogram_functions.py
@@ -187,10 +187,10 @@ def convert_bins(bins: np.array, bin_number: np.int32) -> np.array:


@jit(nopython=True, cache=True)
def z_score_values(z_trial: numba.float32[:, :, :], mean_fr: numba.float32[:], std_fr: numba.float32[:]) -> np.array:
def z_score_values(z_trial: numba.float32[:, :, :], mean_fr: numba.float32[:,:], std_fr: numba.float32[:,:]) -> np.array:
z_trials = np.zeros(np.shape(z_trial))
for idx in range(len(mean_fr)):
for idy in range(np.shape(z_trial)[1]):
z_trials[idx, idy, :] = (z_trial[idx, idy] - mean_fr[idx]) / std_fr[idx]
z_trials[idx, idy, :] = (z_trial[idx, idy] - mean_fr[idx, idy]) / std_fr[idx, idy]

return z_trials
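
z_score_values now takes per-trial baseline statistics: mean_fr and std_fr are (neurons x trials) arrays instead of per-neuron vectors, so every neuron/trial pair is normalized by its own baseline. A minimal sketch of a call with the new shapes, assuming the package layout shown above and made-up array sizes:

    import numpy as np
    import spikeanalysis.analysis_utils.histogram_functions as hf

    # hypothetical sizes: 3 neurons x 4 trials x 10 time bins
    z_trial = np.ones((3, 4, 10), dtype=np.float32)
    mean_fr = np.zeros((3, 4), dtype=np.float32)  # per-neuron, per-trial baseline mean
    std_fr = np.ones((3, 4), dtype=np.float32)    # per-neuron, per-trial baseline std

    z_trials = hf.z_score_values(z_trial, mean_fr, std_fr)
    assert z_trials.shape == (3, 4, 10)
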
2 changes: 1 addition & 1 deletion src/spikeanalysis/curated_spike_analysis.py
@@ -126,7 +126,7 @@ def set_spike_analysis(self, st: SpikeAnalysis):

def set_mask(self, mask: list[bool]):

if len(mask) == len(self.cluster_ids):
if len(mask) != len(self.cluster_ids):
raise ValueError(
f"mask len {len(mask)} must be same as cluster ids len {len(self.cluster_ids)}. Maybe you need to revert curation."
)
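
With the corrected comparison, set_mask raises only when the mask length differs from the number of cluster ids. A small usage sketch mirroring the new test_set_mask further down, assuming a CuratedSpikeAnalysis instance csa with two clusters:

    csa.set_mask([True, False])  # accepted: length matches len(csa.cluster_ids)

    try:
        csa.set_mask([True, False, True])  # rejected: length mismatch raises ValueError
    except ValueError as err:
        print(err)
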
1 change: 1 addition & 0 deletions src/spikeanalysis/plotbase.py
@@ -146,3 +146,4 @@ def _save_fig(self, fig, cluster_number, extra_title="", format="png"):

title = f"{cluster_number}_{extra_title}"
fig.savefig(title + "." + format, format=format)

109 changes: 103 additions & 6 deletions src/spikeanalysis/spike_analysis.py
@@ -414,7 +414,9 @@ def get_raw_firing_rate(
self.fr_bins[stim] = bins[fr_window_values]
self.mean_firing_rate = final_fr

def zscore_data(self, time_bin_ms, bsl_window, z_window, eps):


def zscore_data(self, time_bin_ms, bsl_window, z_window, eps:float=0):

self.z_score_data(time_bin_ms=time_bin_ms, bsl_window=bsl_window, z_window=z_window, eps=eps)

@@ -472,6 +474,8 @@ def z_score_data(
self.z_windows = {}
self.z_bins = {}
self.raw_zscores = {}
self.keep_trials = {}
self.raw_baselines = {}
for idx, stim in enumerate(self.psths.keys()):
if self._verbose:
print(stim)
@@ -500,16 +504,61 @@ def z_score_data(
z_psth = psth[:, :, z_window_values]
z_scores[stim] = np.zeros(np.shape(z_psth))
self.raw_zscores[stim] = np.zeros(np.shape(z_psth))
self.keep_trials[stim] = {}

final_z_scores[stim] = np.zeros((np.shape(z_psth)[0], len(trial_set), np.shape(z_psth)[2]))

# use median instead for determining good trials
bsl_mean_global = np.median(
np.sum(bsl_psth, axis=2) / (bsl_current[1] - bsl_current[0]), axis=1
) # test median

bsl_std_global = (
np.median(
np.abs(np.sum(bsl_psth, axis=2) / (bsl_current[1] - bsl_current[0]) - bsl_mean_global[:, None]),
axis=1,
)
/ 0.6744897501960817
)

self.raw_baselines[stim] = np.sum(bsl_psth, axis=2) / (bsl_current[1] - bsl_current[0])
# to get baseline firing we do a per trial baseline for the neuron. To get an estimate
# we divide the baseline into 3 periods and iterate through those chunks of data to get
# the sub firing rate. Then we average those.
n_chunks = sum(bsl_values) // 3
for trial_number, trial in enumerate(tqdm(trial_set)):
self.keep_trials[stim][trial] = np.zeros((z_psth.shape[0], sum(trials == trial)), dtype=bool)
bsl_trial = bsl_psth[:, trials == trial, :]
mean_fr = np.mean(np.sum(bsl_trial, axis=2), axis=1) / ((bsl_current[1] - bsl_current[0]))
bsl_chunks = []
# iterate over baseline chunks and do sum to get point firing rate
for bsl_chunk_index in range(3):
bsl_chunk = bsl_trial[:, :, (bsl_chunk_index * n_chunks) : (bsl_chunk_index + 1) * n_chunks]
# neuron x trial x value
bsl_chunk_sum = np.sum(bsl_chunk, axis=2) / ((bsl_current[1] - bsl_current[0]) / 3)
bsl_chunks.append(bsl_chunk_sum)

# stack chunks in order to take the mean of the chunks
bsl_chunks = np.stack(bsl_chunks, axis=1)
mean_fr = np.mean(bsl_chunks, axis=1)
# for future computations may be beneficial to have small eps to std to prevent divide by 0
std_fr = np.std(np.sum(bsl_trial, axis=2), axis=1) / ((bsl_current[1] - bsl_current[0])) + eps
std_fr = np.std(bsl_chunks, axis=1) + eps

z_trial = z_psth[:, trials == trial, :] / time_bin_current
z_trials = hf.z_score_values(z_trial, mean_fr, std_fr)
z_scores[stim][:, trials == trial, :] = z_trials[:, :, :]
final_z_scores[stim][:, trial_number, :] = np.nanmean(z_trials, axis=1)
# if we are > 3 mads away from the tg mean then we eliminate a trial.
for neuron_bsl_idx in range(bsl_mean_global.shape[0]):
keep_trials = np.logical_and(
mean_fr[neuron_bsl_idx]
< (bsl_mean_global[neuron_bsl_idx] + (3 * bsl_std_global[neuron_bsl_idx])),
mean_fr[neuron_bsl_idx]
> (bsl_mean_global[neuron_bsl_idx] - (3 * bsl_std_global[neuron_bsl_idx])),
)
final_z_scores[stim][neuron_bsl_idx, trial_number, :] = np.nanmean(
z_trials[neuron_bsl_idx, keep_trials, :], axis=0
)

self.keep_trials[stim][trial][neuron_bsl_idx, :] = keep_trials
self.raw_zscores[stim][:, trials == trial, :] = z_trials[:, :, :]
self.z_bins[stim] = bins[z_window_values]
self.z_scores = final_z_scores
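
Two things changed inside z_score_data. First, the per-trial baseline is now estimated by splitting the baseline window into three chunks, computing a firing rate per chunk, and averaging those rates; the new eps argument (default 0) is added to the resulting standard deviation to guard against division by zero. Second, trials whose baseline rate sits more than three scaled MADs from the unit's median baseline are dropped from the trial-group average; the constant 0.6744897501960817 is the 75th percentile of the standard normal, so dividing the MAD by it approximates a standard deviation. A stand-alone sketch of that keep/drop rule with made-up per-trial rates:

    import numpy as np

    # hypothetical baseline firing rates (Hz) for one neuron across 8 trials
    trial_rates = np.array([4.0, 5.0, 4.5, 5.5, 4.8, 5.2, 30.0, 4.9])

    bsl_mean_global = np.median(trial_rates)
    # median absolute deviation rescaled to approximate a standard deviation
    bsl_std_global = np.median(np.abs(trial_rates - bsl_mean_global)) / 0.6744897501960817

    keep_trials = np.logical_and(
        trial_rates < bsl_mean_global + 3 * bsl_std_global,
        trial_rates > bsl_mean_global - 3 * bsl_std_global,
    )
    print(keep_trials)  # only the 30 Hz trial falls outside the +/- 3 MAD band
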
@@ -836,8 +885,39 @@ def autocorrelogram(self, time_ms: float = 500):

self.acg = acg

def calculate_baseline_values(self, mode: str = "mean"):

if not hasattr(self, "raw_baselines"):
raise ValueError("must run zscore_data in order to collect trial baselines")

if mode == "mean":
func = np.mean
elif mode == "median":
func = np.median
elif mode == "max":
func = np.max
elif callable(mode):
func = mode
else:
raise ValueError("enter a function or one of ['mean', 'median', 'max']")

baselines = {}
for stim, baseline in self.raw_baselines.items():
baselines[stim] = func(baseline, axis=1)

self.baselines = baselines

def return_value(self, value: str):
_values = ("z_scores", "raw_zscores", "mean_firing_rate", "raw_firing_rate", "correlations", "latency", "psths")
_values = (
"z_scores",
"raw_zscores",
"mean_firing_rate",
"raw_firing_rate",
"correlations",
"latency",
"psths",
"baselines",
)

if hasattr(self, value):
return getattr(self, value)
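
calculate_baseline_values collapses the per-trial baseline rates stored in raw_baselines into one value per neuron, using "mean", "median", "max", or any callable, and return_value now exposes the result under "baselines". A usage sketch, assuming sa is a SpikeAnalysis object with PSTHs already computed and reusing the hypothetical windows from test_baselines further down:

    sa.zscore_data(time_bin_ms=1000, bsl_window=[0, 50], z_window=[0, 300])  # populates raw_baselines
    sa.calculate_baseline_values("median")    # a callable such as np.median is also accepted
    baselines = sa.return_value("baselines")  # dict mapping stimulus -> per-neuron baseline value
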
@@ -895,7 +975,7 @@ def save_z_parameters(self, z_parameters: dict, overwrite: bool = False):
with open(self._file_path / "z_parameters.json", "w") as write_file:
json.dump(z_parameters, write_file)

def get_responsive_neurons(self, z_parameters: Optional[dict] = None):
def get_responsive_neurons(self, z_parameters: Optional[dict] = None, latency_threshold_ms: Optional[dict] = None):
"""
function for assessing only responsive neurons based on z scored parameters.
@@ -940,12 +1020,16 @@ def get_responsive_neurons(self, z_parameters: Optional[dict] = None):
else:
same_params = False

if latency_threshold_ms is None:
latency_threshold_ms = {k: None for k in self.z_scores.keys()}

self.responsive_neurons = {}
for stim in self.z_scores.keys():
self.responsive_neurons[stim] = {}
bins = self.z_bins[stim]
current_z_scores = self.z_scores[stim]

current_latency_threshold = latency_threshold_ms[stim]
if same_params:
current_z_params = z_parameters["all"]
else:
@@ -967,13 +1051,26 @@ def get_responsive_neurons(self, z_parameters: Optional[dict] = None):
f"Not implemented for window of size {len(current_window)} possible lengths are 2 or 4"
)

current_bin_size = bins[1] - bins[0] # likely in ms
if current_latency_threshold is not None:
bins_to_threshold = int(current_latency_threshold // current_bin_size)
else:
bins_to_threshold = -1

current_z_scores_sub = current_z_scores[:, :, window_index]
bin_threshold_z_score = current_z_scores_sub[:, :, :bins_to_threshold]

# final should not be any. Need to think about how we want latency to be incorporated....
if current_score > 0 or "inhib" not in key.lower():
z_above_threshold = np.sum(np.where(current_z_scores_sub > current_score, 1, 0), axis=2)
latency_resp_neurons = np.any(np.where(bin_threshold_z_score > current_score, True, False), axis=2)
else:
z_above_threshold = np.sum(np.where(current_z_scores_sub < current_score, 1, 0), axis=2)
latency_resp_neurons = np.any(np.where(bin_threshold_z_score < current_score, True, False), axis=2)

responsive_neurons = np.where(z_above_threshold > current_n_bins, True, False)

responsive_neurons = np.logical_and(responsive_neurons, latency_resp_neurons)
self.responsive_neurons[stim][key] = responsive_neurons

def save_responsive_neurons(self, overwrite: bool = False):
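
get_responsive_neurons also gains an optional per-stimulus latency_threshold_ms. The threshold is converted into a bin count by integer division with the z-score bin width, and a unit must additionally cross the z-score threshold within those leading bins of the response window to stay responsive; when no threshold is supplied, bins_to_threshold defaults to -1, so the latency slice covers all but the last bin of the window. A worked example of the conversion with assumed numbers:

    import numpy as np

    bins = np.arange(0.0, 500.0, 50.0)    # hypothetical z-score bin edges (ms)
    current_bin_size = bins[1] - bins[0]  # 50 ms per bin
    latency_threshold_ms = 120.0

    bins_to_threshold = int(latency_threshold_ms // current_bin_size)  # 120 // 50 -> 2 bins
    # only the first two bins of the windowed z-scores can satisfy the latency criterion
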
4 changes: 4 additions & 0 deletions src/spikeanalysis/spike_plotter.py
@@ -323,6 +323,7 @@ def _plot_scores(
sorted_z_scores = np.expand_dims(sorted_z_scores, axis=1)

# at baseline we need to eliminate cases of nan's, infinities, and 0's (if all the way across a stimulus)

nan_mask = np.any(
np.any(np.isnan(sorted_z_scores) | np.isinf(sorted_z_scores), axis=2)
| np.all(np.equal(sorted_z_scores, 0), axis=2),
@@ -452,6 +453,7 @@ def _plot_scores(
format=plot_kwargs.format,
)
elif plot_kwargs.save and plot_kwargs.title is None:

print("give title to save heat map")
plt.show()

@@ -610,6 +612,7 @@ def plot_raster(
plt.figure(dpi=plot_kwargs.dpi)
if plot_kwargs.save:
self._save_fig(fig, title, extra_title=plot_kwargs.extra_title, format=plot_kwargs.format)

plt.show()

def plot_sm_fr(
@@ -793,6 +796,7 @@ def plot_sm_fr(

if plot_kwargs.save:
self._save_fig(fig, title, extra_title=plot_kwargs.extra_title, format=plot_kwargs.format)

plt.show()

def plot_zscores_ind(self, z_bar: Optional[list[int]] = None, show_stim: bool = True):
8 changes: 4 additions & 4 deletions test/histogram_test.py
@@ -209,8 +209,8 @@ def test_z_score_values():
z_trial = np.ones((3, 4, 10), dtype=np.float32)
z_trial[0, 0, 9] = 10
z_trial[2, 2, 2] = -5
mean_fr = np.zeros(3, dtype=np.float32)
std_fr = np.ones(3, dtype=np.float32)
mean_fr = np.zeros((3,4), dtype=np.float32)
std_fr = np.ones((3,4), dtype=np.float32)

z_trials = hf.z_score_values(z_trial, mean_fr, std_fr)

@@ -219,8 +219,8 @@ def test_z_score_values():
assert z_trials[0, 0, 9] == 10
assert z_trials[2, 2, 2] == -5

mean_fr2 = np.ones(3, dtype=np.float32)
std_fr2 = 0.5 * np.ones(3, dtype=np.float32)
mean_fr2 = np.ones((3,4), dtype=np.float32)
std_fr2 = 0.5 * np.ones((3,4), dtype=np.float32)

z_trials_2 = hf.z_score_values(z_trial, mean_fr2, std_fr2)
assert z_trials_2[0, 0, 9] == 18
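
The expected value follows directly from the z-score formula: with a baseline mean of 1 and standard deviation of 0.5, the raw value 10 maps to (10 - 1) / 0.5 = 18.
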
20 changes: 20 additions & 0 deletions test/test_curated_spike_analysis.py
@@ -83,3 +83,23 @@ def test_curation_both_trial(csa):
def test_curation_wrong_value(csa):
with pytest.raises(Exception):
csa.curate(criteria="test", by_stim=False, by_respone=False, by_trial=False)

def test_set_mask(csa):
csa.revert_curation()
csa.set_mask([True, True])

with pytest.raises(ValueError):
csa.set_mask([True, True, True])


def test_auc_filter(csa):
z_scores = {'test' :np.vstack((np.ones((1,2,100)), 2*np.ones((1,2,100))))}
z_bins = {'test': np.linspace(0,100,100)}
csa.z_scores = z_scores
csa.z_bins = z_bins

csa.filter_mask(window=[20,40], filter_params={'test': {'min':-40, 'max': 30}})

assert sum(csa.mask) ==1
csa.filter_mask(window=[20,40], filter_params={'test': {'min':-40, 'max': 80}})
assert sum(csa.mask) == 0
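
The new filter_mask test suggests an area-under-the-curve criterion: each unit's z-score trace is integrated over the requested window and compared against the per-stimulus min/max bounds. A back-of-envelope check of the numbers above, assuming a simple trapezoidal integral (the integration actually used by filter_mask may differ):

    import numpy as np

    z_bins = np.linspace(0, 100, 100)
    window = (20, 40)
    in_window = (z_bins >= window[0]) & (z_bins <= window[1])

    for unit_z in (np.ones(100), 2 * np.ones(100)):           # trial-averaged z-scores per unit
        auc = np.trapz(unit_z[in_window], z_bins[in_window])  # roughly 19 and 38
        print(auc)

Under that reading an upper bound of 30 separates the two units while a bound of 80 does not, which lines up with the asserted mask sums.
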
27 changes: 27 additions & 0 deletions test/test_spike_analysis.py
@@ -430,3 +430,30 @@ def test_autocorrelogram(sa):
assert np.shape(sa.acg) == (2, 24)

nptest.assert_array_equal(sa.acg[0], np.zeros((24,)))

def test_baselines(sa):

sa.events = {
"0": {
"events": np.array([100, 200]),
"lengths": np.array([100, 100]),
"trial_groups": np.array([1, 1]),
"stim": "test",
}
}
sa.get_raw_psth(window=[0, 300], time_bin_ms=50)

psths = sa.psths
psths["test"]["psth"][0, 0, 0:200] = 1
psths["test"]["psth"][0, 1, 100:300] = 2
psths["test"]["psth"][0, 0, 3000:4000] = 5
sa.psths = psths
# print(f"PSTH {sa.psths}")
sa.zscore_data(time_bin_ms=1000, bsl_window=[0, 50], z_window=[0, 300])

for func in ['mean', 'median', 'max']:
sa.calculate_baseline_values(func)

bsls = sa.baselines
assert isinstance(bsls, dict)
nptest.assert_allclose(bsls['test'],np.array([8.04,.1]))
