From 5d17c9f31f7b52e235b1a142563e99fe561f7d17 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Aur=C3=A9lien=20WYNGAARD?=
Date: Wed, 10 Jan 2024 16:01:21 +0100
Subject: [PATCH] Cleaner code in utils

---
 src/lussac/utils/misc.py     | 64 ++++++++++++++++++++++++------------
 src/lussac/utils/plotting.py |  2 +-
 2 files changed, 44 insertions(+), 22 deletions(-)

diff --git a/src/lussac/utils/misc.py b/src/lussac/utils/misc.py
index d9dd0a2..6e5e036 100644
--- a/src/lussac/utils/misc.py
+++ b/src/lussac/utils/misc.py
@@ -5,12 +5,13 @@ import scipy.stats
 import numba
 import numpy as np
+import numpy.typing as npt
 from .variables import Utils
 from spikeinterface.curation.auto_merge import get_unit_adaptive_window, normalize_correlogram
 from spikeinterface.postprocessing.correlograms import _compute_crosscorr_numba
 
 
-def filter_kwargs(kwargs: dict[str, Any], function: Callable):
+def filter_kwargs(kwargs: dict[str, Any], function: Callable) -> dict[str, Any]:
 	"""
 	Filters the kwargs to only keep the keys that are accepted by the function.
 
@@ -117,8 +118,18 @@ def merge_dict(d1: dict, d2: dict) -> dict:
 
 def binom_sf(x: int, n: float, p: float) -> float:
 	"""
-	TODO
-	sf = survival function (1 - cdf).
+	Computes the survival function (sf = 1 - cdf) of the binomial distribution.
+	For values where the cdf is really close to 1.0, the survival function gives more precise results.
+	Allows for a non-integer n (uses interpolation).
+
+	@param x: int
+		The number of successes.
+	@param n: float
+		The number of trials.
+	@param p: float
+		The probability of success.
+	@return sf: float
+		The survival function of the binomial distribution.
 	"""
 
 	n_array = np.arange(math.floor(n-2), math.ceil(n+3), 1)
@@ -130,7 +141,7 @@ def binom_sf(x: int, n: float, p: float) -> float:
 	return f(n)
 
 
-def gaussian_histogram(events: np.ndarray, t_axis: np.ndarray, sigma: float, truncate: float = 5., margin_reflect: bool = False) -> np.ndarray:
+def gaussian_histogram(events: np.ndarray, t_axis: np.ndarray, sigma: float, truncate: float = 5., margin_reflect: bool = False) -> npt.NDArray[np.float32]:
 	"""
 	Computes a gaussian histogram for the given events.
 	For each point in time, take all the nearby events and compute the sum of their gaussian kernel.
@@ -172,7 +183,7 @@ def gaussian_histogram(events: np.ndarray, t_axis: np.ndarray, sigma: float, tru
 
 
 @numba.jit((numba.float32[:], numba.float32[:], numba.float32, numba.float32), nopython=True, nogil=True, cache=True, parallel=True)
-def _gaussian_kernel(events, t_axis, sigma, truncate):
+def _gaussian_kernel(events, t_axis, sigma, truncate) -> npt.NDArray[np.float32]:
 	"""
 	Numba function for gaussian_histogram.
 
@@ -267,7 +278,7 @@ def estimate_cross_contamination(spike_train1: np.ndarray, spike_train2: np.ndar
 	# n and p for the binomial law for the number of coincidence (under the hypothesis of cross-contamination = limit).
 	n = N1 * N2 * ((1 - C1) * limit + C1)
 	p = 2 * t_r / Utils.t_max
-	p_value = binom_sf(n_violations - 1, n, p)
+	p_value = binom_sf(int(n_violations - 1), n, p)
 
 	if np.isnan(p_value):  # pragma: no cover (should be unreachable).
 		raise ValueError(f"Could not compute p-value for cross-contamination:\n\tn_violations = {n_violations}\n\tn = {n}\n\tp = {p}")
@@ -275,7 +286,7 @@ def estimate_cross_contamination(spike_train1: np.ndarray, spike_train2: np.ndar
 
 
 @numba.jit((numba.float32, ), nopython=True, nogil=True, cache=True)
-def _get_border_probabilities(max_time):
+def _get_border_probabilities(max_time) -> tuple[int, int, float, float]:
 	"""
 	Computes the integer borders, and the probability of 2 spikes distant by this border to be closer than max_time.
 
@@ -297,7 +308,7 @@ def _get_border_probabilities(max_time):
 
 
 @numba.jit((numba.int64[:], numba.float32), nopython=True, nogil=True, cache=True)
-def compute_nb_violations(spike_train, max_time):
+def compute_nb_violations(spike_train, max_time) -> float:
 	"""
 	Computes the number of refractory period violations in a spike train.
 
@@ -334,7 +345,7 @@ def compute_nb_violations(spike_train, max_time):
 
 
 @numba.jit((numba.int64[:], numba.int64[:], numba.float32), nopython=True, nogil=True, cache=True)
-def compute_nb_coincidence(spike_train1, spike_train2, max_time):
+def compute_nb_coincidence(spike_train1, spike_train2, max_time) -> float:
 	"""
 	Computes the number of coincident spikes between two spike trains.
 	Spike timings are integers, so their real timing follows a uniform distribution between t - dt/2 and t + dt/2.
@@ -382,7 +393,7 @@ def compute_nb_coincidence(spike_train1, spike_train2, max_time):
 	return n_coincident + p_high*n_coincident_high + p_low*n_coincident_low
 
 
-def compute_coincidence_matrix_from_vector(spike_vector1: np.ndarray, spike_vector2: np.ndarray, window: int, cross_shifts: np.ndarray | None = None) -> np.ndarray:
+def compute_coincidence_matrix_from_vector(spike_vector1: np.ndarray, spike_vector2: np.ndarray, window: int, cross_shifts: np.ndarray | None = None) -> npt.NDArray[np.int64]:
 	"""
 	Computes the number of coincident spikes between two sortings (given their spike vector).
 
@@ -393,6 +404,9 @@ def compute_coincidence_matrix_from_vector(spike_vector1: np.ndarray, spike_vec
 	@param window: int
 		The coincidence window (in number of samples).
 		Two spikes separated by exactly window are considered as coincident.
+	@param cross_shifts: None | array[int32] (n_units1, n_units2)
+		If not None, cross_shifts[i, j] is the shift between the spike times of the i-th unit of the first sorting
+		and the j-th unit of the second sorting.
 	@return coincidence_matrix: np.ndarray[int64] (n_units1, n_units2)
 		The coincidence matrix containing the number of coincident spikes between each pair of units.
 	"""
@@ -406,7 +420,7 @@ def compute_coincidence_matrix_from_vector(spike_vector1: np.ndarray, spike_vec
 
 
 @numba.jit((numba.int64[:], numba.int64[:], numba.int64[:], numba.int64[:], numba.int32, numba.optional(numba.int32[:, :])), nopython=True, nogil=True, cache=True)
-def compute_coincidence_matrix(spike_times1, spike_labels1, spike_times2, spike_labels2, max_time, cross_shifts=None):
+def compute_coincidence_matrix(spike_times1, spike_labels1, spike_times2, spike_labels2, max_time, cross_shifts=None) -> npt.NDArray[np.int64]:
 	"""
 	Computes the number of coincident spikes between all units in two sortings.
 
@@ -476,7 +490,7 @@ def compute_similarity_matrix(coincidence_matrix: np.ndarray, n_spikes1: np.ndar
 	return (similarity_matrix - expected_matrix) / (1 - expected_matrix)
 
 
-def compute_cross_shift_from_vector(spike_vector1: np.ndarray, spike_vector2: np.ndarray, max_shift: int, gaussian_std: float = 1.5):
+def compute_cross_shift_from_vector(spike_vector1: np.ndarray, spike_vector2: np.ndarray, max_shift: int, gaussian_std: float = 1.5) -> npt.NDArray[np.int32]:
 	"""
 	Computes the shift between units pairwise between 2 sortings (given their spike vector).
 	Looks at their spike times and creates a cross-correlogram to look for a central peak.
@@ -498,7 +512,7 @@ def compute_cross_shift_from_vector(spike_vector1: np.ndarray, spike_vector2: np
 
 
 @numba.jit((numba.int64[:], numba.int64[:], numba.int64[:], numba.int64[:], numba.int32, numba.float32), nopython=True, nogil=True, cache=True, parallel=True)
-def compute_cross_shift(spike_times1, spike_labels1, spike_times2, spike_labels2, max_shift, gaussian_std):
+def compute_cross_shift(spike_times1, spike_labels1, spike_times2, spike_labels2, max_shift, gaussian_std) -> npt.NDArray[np.int32]:
 	"""
 	Computes the shift between units pairwise between 2 sortings.
 	Looks at their spike times and creates a cross-correlogram to look for a central peak.
@@ -604,14 +618,22 @@ def _create_fft_gaussian(N: int, cutoff_freq: float) -> np.ndarray:
 
 def compute_correlogram_difference(auto_corr1: np.ndarray, auto_corr2: np.ndarray, cross_corr: np.ndarray, n1: int, n2: int) -> float:
 	"""
-	TODO.
-
-	@param auto_corr1:
-	@param auto_corr2:
-	@param cross_corr:
-	@param n1:
-	@param n2:
-	@return:
+	Computes the correlogram difference between two units.
+	The idea is to compare both auto-correlograms to the cross-correlogram,
+	weighted by the number of spikes in each unit (the unit with more spikes imposes its result).
+
+	@param auto_corr1: np.ndarray
+		The auto-correlogram of the first unit.
+	@param auto_corr2: np.ndarray
+		The auto-correlogram of the second unit.
+	@param cross_corr: np.ndarray
+		The cross-correlogram between the two units.
+	@param n1: int
+		The number of spikes in the first unit.
+	@param n2: int
+		The number of spikes in the second unit.
+	@return difference: float
+		The computed correlogram difference between both units (0.0 = they are similar).
 	"""
 
 	auto_corr1 = normalize_correlogram(auto_corr1)
diff --git a/src/lussac/utils/plotting.py b/src/lussac/utils/plotting.py
index 72e7b9f..9b0aefc 100644
--- a/src/lussac/utils/plotting.py
+++ b/src/lussac/utils/plotting.py
@@ -208,7 +208,7 @@ def plot_units(wvf_extractor: si.WaveformExtractor, filepath: str, n_channels: i
 		template = wvf_extractor.get_template(unit_id, mode="average")
 		best_channels = np.argsort(np.max(np.abs(template), axis=0))[::-1]
 
-		for i in range(n_channels): # TODO: share y axis for all templates in a unit.
+		for i in range(n_channels):  # TODO: share y axis for all templates in a unit.
 			channel = best_channels[i]
 			fig.add_trace(go.Scatter(
 				x=xaxis,
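
Illustrative sketches for reviewers (not part of the diff).

The new binom_sf docstring says a non-integer number of trials n is handled by interpolation. A minimal sketch of that idea, assuming scipy.stats.binom and scipy.interpolate.interp1d; only the n_array construction and the final f(n) call appear in the patch, so the interpolation kind below is an assumption:

import math

import numpy as np
import scipy.interpolate
import scipy.stats


def binom_sf_sketch(x: int, n: float, p: float) -> float:
	# Evaluate the exact survival function at integer trial counts around n...
	n_array = np.arange(math.floor(n-2), math.ceil(n+3), 1)
	sf_values = scipy.stats.binom.sf(x, n_array, p)

	# ...then interpolate between those exact values to reach the non-integer n.
	f = scipy.interpolate.interp1d(n_array, sf_values, kind="quadratic")  # 'quadratic' is an assumption.
	return float(f(n))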
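
The gaussian_histogram docstring states that, for each point of t_axis, all nearby events contribute the value of their gaussian kernel. A NumPy-only sketch of that definition, ignoring the margin_reflect option and the numba parallelism; the 1/(sigma*sqrt(2*pi)) density normalization is an assumption:

import numpy as np


def gaussian_histogram_sketch(events: np.ndarray, t_axis: np.ndarray, sigma: float, truncate: float = 5.0) -> np.ndarray:
	histogram = np.zeros(t_axis.shape, dtype=np.float32)

	for i, t in enumerate(t_axis):
		# Events further than 'truncate' standard deviations from t are ignored.
		nearby = events[np.abs(events - t) <= truncate * sigma]
		histogram[i] = np.sum(np.exp(-0.5 * ((nearby - t) / sigma) ** 2))

	return histogram / (sigma * np.sqrt(2 * np.pi))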
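
The compute_coincidence_matrix_from_vector docstring specifies that two spikes separated by exactly 'window' samples still count as coincident, i.e. the bound is inclusive. A brute-force reference for that pairwise count between two spike trains; the function name and the O(N1*N2) loop are illustrative only, the patched code uses a jitted implementation:

import numpy as np


def count_coincidences_sketch(spike_times1: np.ndarray, spike_times2: np.ndarray, window: int) -> int:
	# |t1 - t2| <= window counts as a coincidence (inclusive bound).
	count = 0
	for t in spike_times1:
		count += int(np.sum(np.abs(spike_times2 - t) <= window))

	return count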
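
The new compute_correlogram_difference docstring describes comparing both auto-correlograms to the cross-correlogram, with the unit that has more spikes imposing its result. A hedged sketch of that weighting; the distance metric and the mean normalization below are assumptions, the real function relies on normalize_correlogram and windowing helpers from spikeinterface:

import numpy as np


def correlogram_difference_sketch(auto_corr1: np.ndarray, auto_corr2: np.ndarray, cross_corr: np.ndarray, n1: int, n2: int) -> float:
	# Normalize so that shapes, not firing rates, are compared (stand-in for normalize_correlogram).
	auto_corr1 = auto_corr1 / np.mean(auto_corr1)
	auto_corr2 = auto_corr2 / np.mean(auto_corr2)
	cross_corr = cross_corr / np.mean(cross_corr)

	# Distance of the cross-correlogram to each auto-correlogram.
	diff1 = float(np.mean(np.abs(cross_corr - auto_corr1)))
	diff2 = float(np.mean(np.abs(cross_corr - auto_corr2)))

	# Weight by spike counts: the unit with more spikes imposes its result.
	return (n1 * diff1 + n2 * diff2) / (n1 + n2)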