Cleaner code in utils
DradeAW committed Jan 10, 2024
1 parent 0db759f commit 5d17c9f
Showing 2 changed files with 44 additions and 22 deletions.
64 changes: 43 additions & 21 deletions src/lussac/utils/misc.py
@@ -5,12 +5,13 @@
import scipy.stats
import numba
import numpy as np
import numpy.typing as npt
from .variables import Utils
from spikeinterface.curation.auto_merge import get_unit_adaptive_window, normalize_correlogram
from spikeinterface.postprocessing.correlograms import _compute_crosscorr_numba


def filter_kwargs(kwargs: dict[str, Any], function: Callable):
def filter_kwargs(kwargs: dict[str, Any], function: Callable) -> dict[str, Any]:
"""
Filters the kwargs to only keep the keys that are accepted by the function.
@@ -117,8 +118,18 @@ def merge_dict(d1: dict, d2: dict) -> dict:

def binom_sf(x: int, n: float, p: float) -> float:
"""
TODO
sf = survival function (1 - cdf).
Computes the survival function (sf = 1 - cdf) of the binomial distribution.
For values where the cdf is very close to 1.0, the survival function gives more precise results.
Allows for a non-integer n (uses interpolation).
@param x: int
The number of successes.
@param n: float
The number of trials.
@param p: float
The probability of success.
@return sf: float
The survival function of the binomial distribution.
"""

n_array = np.arange(math.floor(n-2), math.ceil(n+3), 1)
@@ -130,7 +141,7 @@ def binom_sf(x: int, n: float, p: float) -> float:
return f(n)

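For illustration, a minimal numpy/scipy sketch of the idea described above: evaluate the exact survival function on the integer n values surrounding the target, then interpolate to the non-integer n. The use of scipy.interpolate and the quadratic interpolation kind are assumptions, not necessarily what the collapsed lines do.

import math
import numpy as np
import scipy.interpolate
import scipy.stats

def binom_sf_sketch(x: int, n: float, p: float) -> float:
    # Exact survival function on the integer n values around the target n.
    n_array = np.arange(math.floor(n - 2), math.ceil(n + 3), 1)
    sf_values = scipy.stats.binom.sf(x, n_array, p)

    # Interpolate between those exact values to reach the non-integer n.
    f = scipy.interpolate.interp1d(n_array, sf_values, kind="quadratic")
    return float(f(n))
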

def gaussian_histogram(events: np.ndarray, t_axis: np.ndarray, sigma: float, truncate: float = 5., margin_reflect: bool = False) -> np.ndarray:
def gaussian_histogram(events: np.ndarray, t_axis: np.ndarray, sigma: float, truncate: float = 5., margin_reflect: bool = False) -> npt.NDArray[np.float32]:
"""
Computes a gaussian histogram for the given events.
For each point in time, take all the nearby events and compute the sum of their gaussian kernels.
@@ -172,7 +183,7 @@ def gaussian_histogram(events: np.ndarray, t_axis: np.ndarray, sigma: float, tru


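As a rough numpy-only reading of what gaussian_histogram computes (a sketch based on the description above; the truncate and margin_reflect options are ignored and the normalisation is an assumption), before the fast numba kernel below:

import numpy as np

def gaussian_histogram_sketch(events: np.ndarray, t_axis: np.ndarray, sigma: float) -> np.ndarray:
    # For each point of t_axis, sum the gaussian kernels centred on every event.
    # O(len(t_axis) * len(events)): fine for a sketch, too slow for real data.
    diffs = (t_axis[:, None] - events[None, :]) / sigma
    kernels = np.exp(-0.5 * diffs**2) / (sigma * np.sqrt(2 * np.pi))
    return kernels.sum(axis=1).astype(np.float32)
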
@numba.jit((numba.float32[:], numba.float32[:], numba.float32, numba.float32), nopython=True, nogil=True, cache=True, parallel=True)
def _gaussian_kernel(events, t_axis, sigma, truncate):
def _gaussian_kernel(events, t_axis, sigma, truncate) -> npt.NDArray[np.float32]:
"""
Numba function for gaussian_histogram.
@@ -267,15 +278,15 @@ def estimate_cross_contamination(spike_train1: np.ndarray, spike_train2: np.ndar
# n and p for the binomial law for the number of coincidence (under the hypothesis of cross-contamination = limit).
n = N1 * N2 * ((1 - C1) * limit + C1)
p = 2 * t_r / Utils.t_max
p_value = binom_sf(n_violations - 1, n, p)
p_value = binom_sf(int(n_violations - 1), n, p)
if np.isnan(p_value): # pragma: no cover (should be unreachable).
raise ValueError(f"Could not compute p-value for cross-contamination:\n\tn_violations = {n_violations}\n\tn = {n}\n\tp = {p}")

return estimation, p_value

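To make the hypothesis test above concrete, here is a toy numerical reading of those three lines. All values are invented; N1, N2, C1 and t_r come from parts of the function that are collapsed here, and Utils.t_max is the recording duration in samples.

N1, N2 = 2_000, 1_500        # number of spikes in each train (invented)
C1 = 0.02                    # estimated contamination of the first train (invented)
limit = 0.10                 # cross-contamination level under the null hypothesis (invented)
t_r, t_max = 30, 18_000_000  # refractory period and recording duration, in samples (invented)
n_violations = 150           # observed number of coincident spikes (invented)

n = N1 * N2 * ((1 - C1) * limit + C1)  # binomial "number of trials" under the null hypothesis
p = 2 * t_r / t_max                    # probability that a given pair of spikes is coincident
p_value = binom_sf(int(n_violations - 1), n, p)  # P(X >= n_violations) when cross-contamination = limit
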

@numba.jit((numba.float32, ), nopython=True, nogil=True, cache=True)
def _get_border_probabilities(max_time):
def _get_border_probabilities(max_time) -> tuple[int, int, float, float]:
"""
Computes the integer borders, and the probability that 2 spikes separated by this border are closer than max_time.
@@ -297,7 +308,7 @@


@numba.jit((numba.int64[:], numba.float32), nopython=True, nogil=True, cache=True)
def compute_nb_violations(spike_train, max_time):
def compute_nb_violations(spike_train, max_time) -> float:
"""
Computes the number of refractory period violations in a spike train.
@@ -334,7 +345,7 @@ def compute_nb_violations(spike_train, max_time):

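A naive numpy equivalent of what this function counts (a sketch only: the real numba version returns a float because it also applies the fractional border probabilities from _get_border_probabilities):

import numpy as np

def nb_violations_naive(spike_train: np.ndarray, max_time: float) -> int:
    # Count every pair of spikes closer than max_time, with a purely integer criterion.
    spike_train = np.sort(spike_train)
    count = 0
    for i, t in enumerate(spike_train[:-1]):
        # Number of later spikes within max_time of spike t.
        count += np.searchsorted(spike_train[i + 1:], t + max_time, side="right")
    return int(count)
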

@numba.jit((numba.int64[:], numba.int64[:], numba.float32), nopython=True, nogil=True, cache=True)
def compute_nb_coincidence(spike_train1, spike_train2, max_time):
def compute_nb_coincidence(spike_train1, spike_train2, max_time) -> float:
"""
Computes the number of coincident spikes between two spike trains.
Spike timings are integers, so their real timing follows a uniform distribution between t - dt/2 and t + dt/2.
@@ -382,7 +393,7 @@ def compute_nb_coincidence(spike_train1, spike_train2, max_time):
return n_coincident + p_high*n_coincident_high + p_low*n_coincident_low

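Likewise, a sketch of the integer-only coincidence count, without the p_low / p_high border corrections visible in the return line above:

import numpy as np

def nb_coincidence_naive(spike_train1: np.ndarray, spike_train2: np.ndarray, max_time: float) -> int:
    # For each spike of train 1, count the spikes of train 2 within max_time of it.
    t2 = np.sort(spike_train2)
    count = 0
    for t in spike_train1:
        lo = np.searchsorted(t2, t - max_time, side="left")
        hi = np.searchsorted(t2, t + max_time, side="right")
        count += hi - lo
    return int(count)
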

def compute_coincidence_matrix_from_vector(spike_vector1: np.ndarray, spike_vector2: np.ndarray, window: int, cross_shifts: np.ndarray | None = None) -> np.ndarray:
def compute_coincidence_matrix_from_vector(spike_vector1: np.ndarray, spike_vector2: np.ndarray, window: int, cross_shifts: np.ndarray | None = None) -> npt.NDArray[np.int64]:
"""
Computes the number of coincident spikes between two sortings (given their spike vector).
@@ -393,6 +404,9 @@ def compute_coincidence_matrix_from_vector(spike_vector1: np.ndarray, spike_vect
@param window: int
The coincidence window (in number of samples).
Two spikes separated by exactly window are considered as coincident.
@param cross_shifts: None | array[int32] (n_units1, n_units2)
If not None, cross_shifts[i, j] is the shift between the spike times of the i-th unit of the first sorting
and the j-th unit of the second sorting.
@return coincidence_matrix: np.ndarray[int64] (n_units1, n_units2)
The coincidence matrix containing the number of coincident spikes between each pair of units.
"""
@@ -406,7 +420,7 @@ def compute_coincidence_matrix_from_vector(spike_vector1: np.ndarray, spike_vect

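A hypothetical usage sketch, assuming the spike vectors are the structured arrays returned by SpikeInterface's BaseSorting.to_spike_vector() (with "sample_index" and "unit_index" fields) and that sorting1 and sorting2 are existing sorting objects:

spike_vector1 = sorting1.to_spike_vector()
spike_vector2 = sorting2.to_spike_vector()

# coincidence[i, j] = number of spikes of unit i (sorting 1) falling within
# 8 samples of a spike of unit j (sorting 2).
coincidence = compute_coincidence_matrix_from_vector(spike_vector1, spike_vector2, window=8)
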
@numba.jit((numba.int64[:], numba.int64[:], numba.int64[:], numba.int64[:], numba.int32, numba.optional(numba.int32[:, :])),
nopython=True, nogil=True, cache=True)
def compute_coincidence_matrix(spike_times1, spike_labels1, spike_times2, spike_labels2, max_time, cross_shifts=None):
def compute_coincidence_matrix(spike_times1, spike_labels1, spike_times2, spike_labels2, max_time, cross_shifts=None) -> npt.NDArray[np.int64]:
"""
Computes the number of coincident spikes between all units in two sortings.
@@ -476,7 +490,7 @@ def compute_similarity_matrix(coincidence_matrix: np.ndarray, n_spikes1: np.ndar
return (similarity_matrix - expected_matrix) / (1 - expected_matrix)

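The return line above rescales the raw similarity so that chance level maps to 0 and a perfect match to 1. A toy example with invented numbers:

similarity, expected = 0.40, 0.10                     # raw and chance-level similarity (invented)
corrected = (similarity - expected) / (1 - expected)  # = 1/3: well above chance, far from a perfect match
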

def compute_cross_shift_from_vector(spike_vector1: np.ndarray, spike_vector2: np.ndarray, max_shift: int, gaussian_std: float = 1.5):
def compute_cross_shift_from_vector(spike_vector1: np.ndarray, spike_vector2: np.ndarray, max_shift: int, gaussian_std: float = 1.5) -> npt.NDArray[np.int32]:
"""
Computes the shift between units pairwise between 2 sortings (given their spike vector).
Looks at their spike times and creates a cross-correlogram to look for a central peak.
@@ -498,7 +512,7 @@ def compute_cross_shift_from_vector(spike_vector1: np.ndarray, spike_vector2: np

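A sketch of the core idea described above: once a cross-correlogram over lags -max_shift .. +max_shift has been built, the estimated shift is the lag of its peak relative to the centre (the real function also takes a gaussian_std, presumably to smooth the correlogram before picking the peak):

import numpy as np

def shift_from_cross_correlogram(cross_corr: np.ndarray, max_shift: int) -> int:
    # cross_corr is assumed to have 2*max_shift + 1 bins, one per lag.
    assert len(cross_corr) == 2 * max_shift + 1
    return int(np.argmax(cross_corr)) - max_shift
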
@numba.jit((numba.int64[:], numba.int64[:], numba.int64[:], numba.int64[:], numba.int32, numba.float32),
nopython=True, nogil=True, cache=True, parallel=True)
def compute_cross_shift(spike_times1, spike_labels1, spike_times2, spike_labels2, max_shift, gaussian_std):
def compute_cross_shift(spike_times1, spike_labels1, spike_times2, spike_labels2, max_shift, gaussian_std) -> npt.NDArray[np.int32]:
"""
Computes the shift between units pairwise between 2 sortings.
Looks at their spike times and creates a cross-correlogram to look for a central peak.
@@ -604,14 +618,22 @@ def _create_fft_gaussian(N: int, cutoff_freq: float) -> np.ndarray:

def compute_correlogram_difference(auto_corr1: np.ndarray, auto_corr2: np.ndarray, cross_corr: np.ndarray, n1: int, n2: int) -> float:
"""
TODO.
@param auto_corr1:
@param auto_corr2:
@param cross_corr:
@param n1:
@param n2:
@return:
Computes the correlogram difference between two units.
The idea is to compare both auto-correlograms to the cross-correlogram,
weighted by the number of spikes in each unit (the unit with more spikes dominates the result).
@param auto_corr1: np.ndarray
The auto-correlogram of the first unit.
@param auto_corr2: np.ndarray
The auto-correlogram of the second unit.
@param cross_corr: np.ndarray
The cross-correlogram between the two units.
@param n1: int
The number of spikes in the first unit.
@param n2: int
The number of spikes in the second unit.
@return difference: float
The computed correlogram difference between both units (0.0 = they are similar).
"""

auto_corr1 = normalize_correlogram(auto_corr1)
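A loose sketch of the comparison described in the new docstring. Only the call to normalize_correlogram is visible in this hunk; the distance measure and the exact weighting below are assumptions, not the actual implementation.

import numpy as np

def correlogram_difference_sketch(auto_corr1, auto_corr2, cross_corr, n1: int, n2: int) -> float:
    # Compare each (already normalized) auto-correlogram to the cross-correlogram,
    # weighting by the units' spike counts so that the unit with more spikes dominates.
    diff1 = np.mean(np.abs(auto_corr1 - cross_corr))
    diff2 = np.mean(np.abs(auto_corr2 - cross_corr))
    w1, w2 = n1 / (n1 + n2), n2 / (n1 + n2)
    return float(w1 * diff1 + w2 * diff2)
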
2 changes: 1 addition & 1 deletion src/lussac/utils/plotting.py
@@ -208,7 +208,7 @@ def plot_units(wvf_extractor: si.WaveformExtractor, filepath: str, n_channels: i

template = wvf_extractor.get_template(unit_id, mode="average")
best_channels = np.argsort(np.max(np.abs(template), axis=0))[::-1]
for i in range(n_channels): # TODO: share y axis for all templates in a unit.
for i in range(n_channels): # TODO: share y axis for all templates in a unit.
channel = best_channels[i]
fig.add_trace(go.Scatter(
x=xaxis,
