
Commit

Further optimization
DradeAW committed Jan 15, 2024
1 parent 51d1941 commit 3a1b717
Showing 6 changed files with 77 additions and 20 deletions.
7 changes: 3 additions & 4 deletions src/lussac/modules/merge_sortings.py
@@ -41,9 +41,9 @@ def default_params(self) -> dict[str, Any]:
 	'wvf_extraction': {
 		'ms_before': 1.0,
 		'ms_after': 2.0,
-		'max_spikes_per_unit': 1_000
+		'max_spikes_per_unit': 1_000,
+		'filter': [250.0, 6_000.0]
 	},
-	'filter': [250.0, 6_000.0],
 	'num_channels': 5
 },
 'merge_check': {
@@ -67,7 +67,7 @@ def update_params(self, params: dict[str, Any]) -> dict[str, Any]:
 
 @override
 def run(self, params: dict[str, Any]) -> dict[str, si.BaseSorting]:
-	self.aggregated_wvf_extractor = self.extract_waveforms(filter=params['waveform_validation']['filter'], sparse=False, **params['waveform_validation']['wvf_extraction'])
+	self.aggregated_wvf_extractor = self.extract_waveforms(sparse=False, **params['waveform_validation']['wvf_extraction'])
 	cross_shifts = self.compute_cross_shifts(params['max_shift'])
 
 	similarity_matrices = self._compute_similarity_matrices(cross_shifts, params)
@@ -163,7 +163,6 @@ def _compute_graph(self, similarity_matrices: dict[str, dict[str, np.ndarray]],
 	"""
 
 	censored_period, refractory_period = params['refractory_period']
-	min_f, max_f = params['waveform_validation']['filter']
 
 	# Populating the graph with all the nodes (i.e. all the units) with properties.
 	graph = nx.Graph()
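Note on the merge_sortings.py change: the 'filter' entry moves inside 'wvf_extraction', so it now reaches extract_waveforms through the ** expansion rather than an explicit filter= argument, and the corresponding min_f/max_f unpacking in _compute_graph is dropped. A minimal, self-contained sketch of the resulting call path, using a stand-in function instead of the module method:

params = {
	'waveform_validation': {
		'wvf_extraction': {
			'ms_before': 1.0,
			'ms_after': 2.0,
			'max_spikes_per_unit': 1_000,
			'filter': [250.0, 6_000.0],  # previously a sibling of 'wvf_extraction'
		},
		'num_channels': 5,
	},
}

def extract_waveforms(sparse=True, **kwargs):  # stand-in for the module method, illustration only
	return sparse, kwargs

sparse, kwargs = extract_waveforms(sparse=False, **params['waveform_validation']['wvf_extraction'])
assert kwargs['filter'] == [250.0, 6_000.0]  # the filter band now arrives via the ** expansion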
3 changes: 1 addition & 2 deletions src/lussac/modules/merge_units.py
@@ -26,7 +26,6 @@ def default_params(self) -> dict[str, Any]:
 	'ms_before': 1.0,
 	'ms_after': 1.5,
 	'max_spikes_per_unit': 2_000,
-	'sparse': False,
 	'filter': [100, 9000]
 },
 'auto_merge_params': {
@@ -49,7 +48,7 @@ def update_params(self, params: dict[str, Any]) -> dict[str, Any]:
 
 @override
 def run(self, params: dict[str, Any]) -> si.BaseSorting:
-	wvf_extractor = self.extract_waveforms(**params['wvf_extraction'])
+	wvf_extractor = self.extract_waveforms(sparse=False, **params['wvf_extraction'])
 	potential_merges, extra_outputs = scur.get_potential_auto_merge(wvf_extractor, extra_outputs=True, **params['auto_merge_params'])
 
 	sorting = self._remove_splits(self.sorting, extra_outputs, params)
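For merge_units.py, sparse=False is now hard-coded in run(), so the 'sparse' key is dropped from the default wvf_extraction params; leaving it there would make the ** expansion collide with the explicit keyword. A tiny, self-contained illustration of that failure mode (hypothetical function, not from the repo):

def extract(sparse=True, **kwargs):  # hypothetical stand-in
	return sparse, kwargs

params = {'sparse': False, 'ms_before': 1.0}
try:
	extract(sparse=False, **params)
except TypeError as e:
	print(e)  # extract() got multiple values for keyword argument 'sparse'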
2 changes: 1 addition & 1 deletion src/lussac/utils/__init__.py
@@ -1,4 +1,4 @@
-from .misc import (filter_kwargs, flatten_dict, unflatten_dict, merge_dict, binom_sf, gaussian_histogram, estimate_contamination, estimate_cross_contamination, compute_nb_coincidence,
+from .misc import (gaussian_pdf, filter_kwargs, flatten_dict, unflatten_dict, merge_dict, binom_sf, gaussian_histogram, estimate_contamination, estimate_cross_contamination, compute_nb_coincidence,
 	compute_coincidence_matrix, compute_coincidence_matrix_from_vector, compute_similarity_matrix, filter, compute_cross_shift_from_vector,
 	compute_cross_shift, compute_correlogram_difference)
 from .plotting import get_path_to_plotlyJS, export_figure, plot_sliders, plot_units, create_gt_annotations, create_graph_plot
71 changes: 61 additions & 10 deletions src/lussac/utils/misc.py
@@ -1,6 +1,6 @@
 import inspect
 import math
-from typing import Any, Callable
+from typing import Any, Callable, TypeVar
 import scipy.interpolate
 import scipy.stats
 import numba
@@ -11,6 +11,27 @@
 from spikeinterface.postprocessing.correlograms import _compute_crosscorr_numba
 
 
+T = TypeVar("T", bound=npt.ArrayLike)
+
+
+def gaussian_pdf(x: T, mu: float = 0.0, sigma: float = 1.0) -> T:
+	"""
+	Computes the pdf of a Normal distribution.
+	On my machine, is ~8x faster than scipy.stats.norm.pdf.
+	@param x: ArrayLike
+		The number or array on which to compute the pdf.
+	@param mu: float
+		The mean of the Normal distribution.
+	@param sigma: float
+		The standard deviation of the Normal distribution.
+	@return gaussian_pdf: ArrayLike
+		The computed pdf.
+	"""
+
+	return 1/(sigma * np.sqrt(2*np.pi)) * np.exp(-(x - mu)**2 / (2 * sigma**2))
+
+
 def filter_kwargs(kwargs: dict[str, Any], function: Callable) -> dict[str, Any]:
 	"""
 	Filters the kwargs to only keep the keys that are accepted by the function.
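Since gaussian_pdf is also re-exported from lussac.utils (see the __init__.py change above), a quick sanity check against SciPy mirrors the test added at the bottom of this commit. A minimal sketch:

import numpy as np
import scipy.stats
from lussac.utils import gaussian_pdf

x = np.linspace(-5.0, 5.0, 1_001)
ours = gaussian_pdf(x, mu=0.3, sigma=1.2)
ref = scipy.stats.norm.pdf(x, loc=0.3, scale=1.2)
assert np.allclose(ours, ref, atol=0, rtol=1e-10)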
@@ -215,6 +236,39 @@ def _gaussian_kernel(events, t_axis, sigma, truncate) -> npt.NDArray[np.float32]
 	return histogram / (sigma * np.sqrt(2*np.pi))
 
 
+@numba.jit((numba.int64[:], numba.int64[:]), nopython=True, nogil=True, cache=True)
+def spike_vector_to_spike_trains(sample_indices, unit_indices) -> list[np.ndarray[np.int64]]:
+	"""
+	Converts a spike vector to a list of spike trains in a really fast manner.
+	@param sample_indices: array[int64] (n_spikes1)
+		All the spike timings.
+	@param unit_indices: array[int64] (n_spikes1)
+		The unit labels (i.e. unit index of each spike).
+	@return spike_trains: list[array[int64]]
+		The list of spike trains.
+	"""
+
+	num_units = 1 + np.max(unit_indices)
+	num_spikes = sample_indices.size
+
+	num_spikes_per_unit = np.zeros(num_units, dtype=np.int32)
+	for s in range(num_spikes):
+		num_spikes_per_unit[unit_indices[s]] += 1
+
+	spike_trains = []
+	for u in range(num_units):
+		spike_trains.append(np.empty(num_spikes_per_unit[u], dtype=np.int64))
+
+	current_x = np.zeros(num_units, dtype=np.int32)
+	for s in range(num_spikes):
+		unit_index = unit_indices[s]
+		spike_trains[unit_index][current_x[unit_index]] = sample_indices[s]
+		current_x[unit_index] += 1
+
+	return spike_trains
+
+
 def estimate_contamination(spike_train: np.ndarray, refractory_period: tuple[float, float]) -> float:
 	"""
 	Estimates the contamination of a spike train by looking at the number of refractory period violations.
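The new kernel splits a spike vector into per-unit trains with a two-pass counting approach (count spikes per unit, preallocate, then fill), instead of one boolean-mask pass per unit as in the loop it replaces further down. Both inputs must be int64, matching the numba signature. A minimal usage sketch (importing straight from lussac.utils.misc, since the function is not re-exported in __init__.py):

import numpy as np
from lussac.utils.misc import spike_vector_to_spike_trains

sample_indices = np.array([10, 25, 30, 42, 57], dtype=np.int64)  # spike times, in samples
unit_indices = np.array([0, 1, 0, 2, 1], dtype=np.int64)         # unit index of each spike

spike_trains = spike_vector_to_spike_trains(sample_indices, unit_indices)
# spike_trains[0] -> [10 30], spike_trains[1] -> [25 57], spike_trains[2] -> [42]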
@@ -507,7 +561,8 @@ def compute_cross_shift_from_vector(spike_vector1: np.ndarray, spike_vector2: np
 		The cross-shift matrix containing the shift between each pair of units.
 	"""
 
-	return compute_cross_shift(spike_vector1['sample_index'], spike_vector1['unit_index'], spike_vector2['sample_index'], spike_vector2['unit_index'], max_shift, gaussian_std)
+	return compute_cross_shift(spike_vector1['sample_index'].astype(np.int64, copy=False), spike_vector1['unit_index'].astype(np.int64, copy=False),
+		spike_vector2['sample_index'].astype(np.int64, copy=False), spike_vector2['unit_index'].astype(np.int64, copy=False), max_shift, gaussian_std)
 
 
 @numba.jit((numba.int64[:], numba.int64[:], numba.int64[:], numba.int64[:], numba.int32, numba.float32),
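The .astype(np.int64, copy=False) calls exist because the numba kernels are compiled for int64[:] only; with copy=False the cast allocates nothing when the spike-vector fields are already int64. A short sketch with a hypothetical structured array:

import numpy as np

spike_vector = np.zeros(4, dtype=[('sample_index', np.int64), ('unit_index', np.int64)])
samples = spike_vector['sample_index'].astype(np.int64, copy=False)
# Same dtype and layout, so no new buffer is allocated.
assert samples.dtype == np.int64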
@@ -540,12 +595,8 @@ def compute_cross_shift(spike_times1, spike_labels1, spike_times2, spike_labels2
 	N = math.ceil(5 * gaussian_std)
 	gaussian = np.exp(-np.arange(-N, N+1)**2 / (2 * gaussian_std**2)) / (gaussian_std * math.sqrt(2*math.pi))
 
-	spike_trains1 = numba.typed.List()
-	spike_trains2 = numba.typed.List()
-	for unit1 in range(n_units1):
-		spike_trains1.append(spike_times1[spike_labels1 == unit1])
-	for unit2 in range(n_units2):
-		spike_trains2.append(spike_times2[spike_labels2 == unit2])
+	spike_trains1 = spike_vector_to_spike_trains(spike_times1, spike_labels1)
+	spike_trains2 = spike_vector_to_spike_trains(spike_times2, spike_labels2)
 
 	for unit1 in numba.prange(n_units1):
 		for unit2 in range(n_units2):
@@ -609,11 +660,11 @@ def _create_fft_gaussian(N: int, cutoff_freq: float) -> np.ndarray:
 		sigma = Utils.sampling_frequency / (2 * math.pi * cutoff_freq)
 		limit = int(round(6*sigma)) + 1
 		xaxis = np.arange(-limit, limit+1) / sigma
-		gaussian = scipy.stats.norm.pdf(xaxis) / sigma
+		gaussian = gaussian_pdf(xaxis) / sigma
 		return np.abs(np.fft.fft(gaussian, n=N))
 	else:
 		freq_axis = np.fft.fftfreq(N, d=1/Utils.sampling_frequency)
-		return scipy.stats.norm.pdf(freq_axis / cutoff_freq) * math.sqrt(2 * math.pi)
+		return gaussian_pdf(freq_axis / cutoff_freq) * math.sqrt(2 * math.pi)
 
 
 def compute_correlogram_difference(auto_corr1: np.ndarray, auto_corr2: np.ndarray, cross_corr: np.ndarray, n1: int, n2: int) -> float:
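Both scipy.stats.norm.pdf calls in _create_fft_gaussian are swapped for the new helper; the docstring above claims roughly an 8x speedup on the author's machine. A rough, illustrative way to check that locally (timings will vary):

import timeit

import numpy as np
import scipy.stats
from lussac.utils import gaussian_pdf

x = np.random.default_rng(0).normal(size=100_000)
t_scipy = timeit.timeit(lambda: scipy.stats.norm.pdf(x), number=100)
t_lussac = timeit.timeit(lambda: gaussian_pdf(x), number=100)
print(f"scipy.stats.norm.pdf: {t_scipy:.3f}s   gaussian_pdf: {t_lussac:.3f}s")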
6 changes: 3 additions & 3 deletions tests/modules/test_merge_sortings.py
@@ -71,10 +71,10 @@ def test_compute_graph(data: LussacData) -> None:
 		'refractory_period': (0.2, 1.0),
 		'similarity': {'min_similarity': 0.4},
 		'require_multiple_sortings_match': False,
-		'waveform_validation': {'filter': [150.0, 9_000.0], 'wvf_extraction': {}}
+		'waveform_validation': {'wvf_extraction': {'filter': [150.0, 9_000.0]}}
 	}
 
-	module.aggregated_wvf_extractor = module.extract_waveforms(sub_folder="graph", filter=p['waveform_validation']['filter'], sparse=False, **p['waveform_validation']['wvf_extraction'])
+	module.aggregated_wvf_extractor = module.extract_waveforms(sub_folder="graph", sparse=False, **p['waveform_validation']['wvf_extraction'])
 
 	graph = module._compute_graph(similarity_matrices, p)
 	assert graph.number_of_nodes() == 8
@@ -145,7 +145,7 @@ def test_compute_difference(merge_sortings_module: MergeSortings) -> None:
 		for name2, sorting2 in sortings.items()} for name1, sorting1 in sortings.items()}
 	params = merge_sortings_module.update_params({})
 	params['waveform_validation']['wvf_extraction']['max_spikes_per_unit'] = 200
-	merge_sortings_module.aggregated_wvf_extractor = merge_sortings_module.extract_waveforms(sub_folder="compute_differences", filter=params['waveform_validation']['filter'], sparse=False, **params['waveform_validation']['wvf_extraction'])
+	merge_sortings_module.aggregated_wvf_extractor = merge_sortings_module.extract_waveforms(sub_folder="compute_differences", sparse=False, **params['waveform_validation']['wvf_extraction'])
 
 	# Test with empty graph
 	merge_sortings_module.compute_correlogram_difference(graph, cross_shifts, params['correlogram_validation'])
8 changes: 8 additions & 0 deletions tests/utils/test_misc.py
@@ -8,6 +8,14 @@
 from spikeinterface.curation.curation_tools import find_duplicated_spikes
 
 
+def test_gaussian_pdf() -> None:
+	xaxis = np.arange(-10, 10, 0.01)
+	mu = 0.3
+	sigma = 1.2
+
+	assert np.allclose(utils.gaussian_pdf(xaxis, mu, sigma), scipy.stats.norm.pdf(xaxis, loc=mu, scale=sigma), atol=0, rtol=1e-10)
+
+
 def test_filter_kwargs() -> None:
 	assert utils.filter_kwargs({}, test_flatten_dict) == {}
 	assert utils.filter_kwargs({'t_r': 2.0, 't_c': 1.0}, generate_spike_train) == {'t_r': 2.0}
