Merge branch 'main' into gt_study
alejoe91 authored Oct 26, 2023
2 parents dba32ed + ef095c2 commit bf5b0a6
Showing 10 changed files with 166 additions and 141 deletions.
4 changes: 3 additions & 1 deletion src/spikeinterface/core/node_pipeline.py
@@ -213,7 +213,9 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin):
elif self.peak_sign == "both":
local_peaks[i]["channel_index"] = chans[np.argmax(np.abs(sparse_wfs))]

# TODO: "amplitude" ???
# handle amplitude
for i, peak in enumerate(local_peaks):
local_peaks["amplitude"][i] = traces[peak["sample_index"], peak["channel_index"]]

return (local_peaks,)
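
A minimal sketch (not from the commit) of what the new amplitude handling does: each peak's amplitude is read straight from the traces array at that peak's (sample_index, channel_index) position. The array shapes and the structured dtype below are illustrative only.

import numpy as np

# illustrative traces: (num_samples, num_channels)
traces = np.random.randn(1000, 4).astype("float32")

# illustrative peak vector with only the fields used by the loop above
peak_dtype = [("sample_index", "int64"), ("channel_index", "int64"), ("amplitude", "float64")]
local_peaks = np.zeros(3, dtype=peak_dtype)
local_peaks["sample_index"] = [10, 200, 750]
local_peaks["channel_index"] = [0, 3, 1]

# same indexing pattern as the added loop
for i, peak in enumerate(local_peaks):
    local_peaks["amplitude"][i] = traces[peak["sample_index"], peak["channel_index"]]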

7 changes: 4 additions & 3 deletions src/spikeinterface/core/tests/test_node_pipeline.py
@@ -72,7 +72,8 @@ def compute(self, traces, peaks, waveforms):
def test_run_node_pipeline():
recording, sorting = generate_ground_truth_recording(num_channels=10, num_units=10, durations=[10.0])

job_kwargs = dict(chunk_duration="0.5s", n_jobs=2, progress_bar=False)
# job_kwargs = dict(chunk_duration="0.5s", n_jobs=2, progress_bar=False)
job_kwargs = dict(chunk_duration="0.5s", n_jobs=1, progress_bar=False)

spikes = sorting.to_spike_vector()

@@ -104,7 +105,8 @@ def test_run_node_pipeline():
AmplitudeExtractionNode(recording, parents=[peak_source], param0=6.6),
]
step_one = run_node_pipeline(recording, nodes, job_kwargs, squeeze_output=True)
assert np.allclose(np.abs(peaks["amplitude"]), step_one["abs_amplitude"])
if loop == 0:
assert np.allclose(np.abs(peaks["amplitude"]), step_one["abs_amplitude"])

# 3 nodes, two have outputs
ms_before = 0.5
@@ -132,7 +134,6 @@ def test_run_node_pipeline():
# gather memory mode
output = run_node_pipeline(recording, nodes, job_kwargs, gather_mode="memory")
amplitudes, waveforms_rms, denoised_waveforms_rms = output
assert np.allclose(np.abs(peaks["amplitude"]), amplitudes["abs_amplitude"])

num_peaks = peaks.shape[0]
num_channels = recording.get_num_channels()
1 change: 0 additions & 1 deletion src/spikeinterface/postprocessing/__init__.py
@@ -38,7 +38,6 @@

from .isi import (
ISIHistogramsCalculator,
compute_isi_histograms_from_spiketrain,
compute_isi_histograms,
compute_isi_histograms_numpy,
compute_isi_histograms_numba,
88 changes: 14 additions & 74 deletions src/spikeinterface/postprocessing/isi.py
@@ -65,61 +65,6 @@ def get_extension_function():
WaveformExtractor.register_extension(ISIHistogramsCalculator)


def compute_isi_histograms_from_spiketrain(spike_train: np.ndarray, max_time: int, bin_size: int, sampling_f: float):
"""
Computes the Inter-Spike Intervals histogram from a given spike train.
This implementation only works if you have numba installed, to accelerate the
computation time.
Parameters
----------
spike_train: np.ndarray
The ordered spike train to compute the ISI.
max_time: int
Compute the ISI from 0 to +max_time (in sampling time).
bin_size: int
Size of a bin (in sampling time).
sampling_f: float
Sampling rate/frequency (in Hz).
Returns
-------
tuple (ISI, bins)
ISI: np.ndarray[int64]
The computed ISI histogram.
bins: np.ndarray[float64]
The bins for the ISI histogram.
"""
if not HAVE_NUMBA:
print("Error: numba is not installed.")
print("compute_ISI_from_spiketrain cannot run without numba.")
return 0

return _compute_isi_histograms_from_spiketrain(spike_train.astype(np.int64), max_time, bin_size, sampling_f)


if HAVE_NUMBA:

@numba.jit((numba.int64[::1], numba.int32, numba.int32, numba.float32), nopython=True, nogil=True, cache=True)
def _compute_isi_histograms_from_spiketrain(spike_train, max_time, bin_size, sampling_f):
n_bins = int(max_time / bin_size)

bins = np.arange(0, max_time + bin_size, bin_size) * 1e3 / sampling_f
ISI = np.zeros(n_bins, dtype=np.int64)

for i in range(1, len(spike_train)):
diff = spike_train[i] - spike_train[i - 1]

if diff >= max_time:
continue

bin = int(diff / bin_size)
ISI[bin] += 1

return ISI, bins


def compute_isi_histograms(
waveform_or_sorting_extractor,
load_if_exists=False,
@@ -140,7 +85,7 @@ def compute_isi_histograms(
bin_ms : float, optional
The bin size in ms, by default 1.0.
method : str, optional
"auto" | "numpy" | "numba". If _auto" and numba is installed, numba is used, by default "auto"
"auto" | "numpy" | "numba". If "auto" and numba is installed, numba is used, by default "auto"
Returns
-------
@@ -191,21 +136,19 @@ def compute_isi_histograms_numpy(sorting, window_ms: float = 50.0, bin_ms: float
"""
fs = sorting.get_sampling_frequency()
num_units = len(sorting.unit_ids)

assert bin_ms * 1e-3 >= 1 / fs, f"bin size must be larger than the sampling period {1e3 / fs}"
assert bin_ms <= window_ms
window_size = int(round(fs * window_ms * 1e-3))
bin_size = int(round(fs * bin_ms * 1e-3))
window_size -= window_size % bin_size
num_bins = int(window_size / bin_size)
assert num_bins >= 1

ISIs = np.zeros((num_units, num_bins), dtype=np.int64)
bins = np.arange(0, window_size + bin_size, bin_size) * 1e3 / fs
ISIs = np.zeros((num_units, len(bins) - 1), dtype=np.int64)

# TODO: There might be a better way than a double for loop?
for i, unit_id in enumerate(sorting.unit_ids):
for seg_index in range(sorting.get_num_segments()):
spike_train = sorting.get_unit_spike_train(unit_id, segment_index=seg_index)
ISI = np.histogram(np.diff(spike_train), bins=num_bins, range=(0, window_size - 1))[0]
ISI = np.histogram(np.diff(spike_train), bins=bins)[0]
ISIs[i] += ISI

return ISIs, bins
@@ -224,18 +167,18 @@ def compute_isi_histograms_numba(sorting, window_ms: float = 50.0, bin_ms: float

assert HAVE_NUMBA
fs = sorting.get_sampling_frequency()
assert bin_ms * 1e-3 >= 1 / fs, f"the bin_ms must be larger than the sampling period: {1e3 / fs}"
assert bin_ms <= window_ms
num_units = len(sorting.unit_ids)

window_size = int(round(fs * window_ms * 1e-3))
bin_size = int(round(fs * bin_ms * 1e-3))
window_size -= window_size % bin_size
num_bins = int(window_size / bin_size)
assert num_bins >= 1

bins = np.arange(0, window_size + bin_size, bin_size) * 1e3 / fs
spikes = sorting.to_spike_vector(concatenated=False)

ISIs = np.zeros((num_units, num_bins), dtype=np.int64)
ISIs = np.zeros((num_units, len(bins) - 1), dtype=np.int64)

for seg_index in range(sorting.get_num_segments()):
spike_times = spikes[seg_index]["sample_index"].astype(np.int64)
@@ -245,9 +188,7 @@ def compute_isi_histograms_numba(sorting, window_ms: float = 50.0, bin_ms: float
ISIs,
spike_times,
spike_labels,
window_size,
bin_size,
fs,
bins,
)

return ISIs, bins
@@ -256,16 +197,15 @@ def compute_isi_histograms_numba(sorting, window_ms: float = 50.0, bin_ms: float
if HAVE_NUMBA:

@numba.jit(
(numba.int64[:, ::1], numba.int64[::1], numba.int32[::1], numba.int32, numba.int32, numba.float32),
(numba.int64[:, ::1], numba.int64[::1], numba.int32[::1], numba.float64[::1]),
nopython=True,
nogil=True,
cache=True,
parallel=True,
)
def _compute_isi_histograms_numba(ISIs, spike_trains, spike_clusters, max_time, bin_size, sampling_f):
def _compute_isi_histograms_numba(ISIs, spike_trains, spike_clusters, bins):
n_units = ISIs.shape[0]

for i in numba.prange(n_units):
units_loop = numba.prange(n_units) if n_units > 300 else range(n_units)
for i in units_loop:
spike_train = spike_trains[spike_clusters == i]

ISIs[i] += _compute_isi_histograms_from_spiketrain(spike_train, max_time, bin_size, sampling_f)[0]
ISIs[i] += np.histogram(np.diff(spike_train), bins=bins)[0]
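
A hedged usage sketch (not part of the commit) of the refactored numpy backend: a small synthetic sorting comes from generate_ground_truth_recording, as in the pipeline test above, and the window/bin values are the defaults shown in the diff.

from spikeinterface.core import generate_ground_truth_recording
from spikeinterface.postprocessing import compute_isi_histograms_numpy

# small synthetic recording/sorting, mirroring the test fixtures above
recording, sorting = generate_ground_truth_recording(num_channels=4, num_units=5, durations=[10.0])

# ISIs has shape (num_units, len(bins) - 1); bins are the edges built with np.arange above
ISIs, bins = compute_isi_histograms_numpy(sorting, window_ms=50.0, bin_ms=1.0)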
55 changes: 50 additions & 5 deletions src/spikeinterface/postprocessing/spike_locations.py
@@ -5,6 +5,7 @@
from spikeinterface.core.template_tools import get_template_extremum_channel, get_template_extremum_channel_peak_shift

from spikeinterface.core.waveform_extractor import WaveformExtractor, BaseWaveformExtractorExtension
from spikeinterface.core.node_pipeline import SpikeRetriever


class SpikeLocationsCalculator(BaseWaveformExtractorExtension):
@@ -25,8 +26,21 @@ def __init__(self, waveform_extractor):
extremum_channel_inds = get_template_extremum_channel(self.waveform_extractor, outputs="index")
self.spikes = self.waveform_extractor.sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds)

def _set_params(self, ms_before=0.5, ms_after=0.5, method="center_of_mass", method_kwargs={}):
params = dict(ms_before=ms_before, ms_after=ms_after, method=method)
def _set_params(
self,
ms_before=0.5,
ms_after=0.5,
spike_retriver_kwargs=dict(
channel_from_template=True,
radius_um=50,
peak_sign="neg",
),
method="center_of_mass",
method_kwargs={},
):
params = dict(
ms_before=ms_before, ms_after=ms_after, spike_retriver_kwargs=spike_retriver_kwargs, method=method
)
params.update(**method_kwargs)
return params

@@ -44,13 +58,22 @@ def _run(self, **job_kwargs):
uses the `sortingcomponents.peak_localization.localize_peaks()` function to triangulate
spike locations.
"""
from spikeinterface.sortingcomponents.peak_localization import localize_peaks
from spikeinterface.sortingcomponents.peak_localization import _run_localization_from_peak_source

job_kwargs = fix_job_kwargs(job_kwargs)

we = self.waveform_extractor

spike_locations = localize_peaks(we.recording, self.spikes, **self._params, **job_kwargs)
extremum_channel_inds = get_template_extremum_channel(we, peak_sign="neg", outputs="index")

params = self._params.copy()
spike_retriver_kwargs = params.pop("spike_retriver_kwargs")

spike_retriever = SpikeRetriever(
we.recording, we.sorting, extremum_channel_inds=extremum_channel_inds, **spike_retriver_kwargs
)
spike_locations = _run_localization_from_peak_source(we.recording, spike_retriever, **params, **job_kwargs)

self._extension_data["spike_locations"] = spike_locations

def get_data(self, outputs="concatenated"):
@@ -101,6 +124,11 @@ def compute_spike_locations(
load_if_exists=False,
ms_before=0.5,
ms_after=0.5,
spike_retriver_kwargs=dict(
channel_from_template=True,
radius_um=50,
peak_sign="neg",
),
method="center_of_mass",
method_kwargs={},
outputs="concatenated",
@@ -119,6 +147,17 @@ def compute_spike_locations(
The left window, before a peak, in milliseconds.
ms_after : float
The right window, after a peak, in milliseconds.
spike_retriver_kwargs: dict
A dictionary to control the behavior for getting the maximum channel for each spike.
This dictionary contains:
* channel_from_template: bool, default True
Whether the maximum channel is computed from the template or re-estimated for every spike.
channel_from_template = True is the previous behavior, but less accurate;
channel_from_template = False is slower, but more accurate.
* radius_um: float, default 50
In case channel_from_template=False, this is the radius to get the true peak.
* peak_sign="neg"
In case channel_from_template=False, this is the peak sign.
method : str
'center_of_mass' / 'monopolar_triangulation' / 'grid_convolution'
method_kwargs : dict
@@ -138,7 +177,13 @@ def compute_spike_locations(
slc = waveform_extractor.load_extension(SpikeLocationsCalculator.extension_name)
else:
slc = SpikeLocationsCalculator(waveform_extractor)
slc.set_params(ms_before=ms_before, ms_after=ms_after, method=method, method_kwargs=method_kwargs)
slc.set_params(
ms_before=ms_before,
ms_after=ms_after,
spike_retriver_kwargs=spike_retriver_kwargs,
method=method,
method_kwargs=method_kwargs,
)
slc.run(**job_kwargs)

locs = slc.get_data(outputs=outputs)
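
A hedged usage sketch (not part of the commit) of the new spike_retriver_kwargs argument: `we` stands for an existing WaveformExtractor, and chunk_size / n_jobs are forwarded as job kwargs, as in the extension test below.

from spikeinterface.postprocessing import compute_spike_locations

# `we` is assumed to be a previously computed WaveformExtractor
spike_locations = compute_spike_locations(
    we,
    ms_before=0.5,
    ms_after=0.5,
    # channel_from_template=False re-estimates the peak channel per spike (slower, more accurate)
    spike_retriver_kwargs=dict(channel_from_template=False, radius_um=50, peak_sign="neg"),
    method="center_of_mass",
    chunk_size=10000,
    n_jobs=1,
)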
@@ -10,7 +10,12 @@ class SpikeLocationsExtensionTest(WaveformExtensionCommonTestSuite, unittest.Tes
extension_class = SpikeLocationsCalculator
extension_data_names = ["spike_locations"]
extension_function_kwargs_list = [
dict(method="center_of_mass", chunk_size=10000, n_jobs=1),
dict(
method="center_of_mass", chunk_size=10000, n_jobs=1, spike_retriver_kwargs=dict(channel_from_template=True)
),
dict(
method="center_of_mass", chunk_size=10000, n_jobs=1, spike_retriver_kwargs=dict(channel_from_template=False)
),
dict(method="center_of_mass", chunk_size=10000, n_jobs=1, outputs="by_unit"),
dict(method="monopolar_triangulation", chunk_size=10000, n_jobs=1, outputs="by_unit"),
dict(method="monopolar_triangulation", chunk_size=10000, n_jobs=1, outputs="by_unit"),
6 changes: 3 additions & 3 deletions src/spikeinterface/preprocessing/tests/test_whiten.py
@@ -20,13 +20,13 @@ def test_whiten():

print(rec.get_channel_locations())
random_chunk_kwargs = {}
W, M = compute_whitening_matrix(rec, "global", random_chunk_kwargs, apply_mean=False, radius_um=None, eps=1e-8)
W, M = compute_whitening_matrix(rec, "global", random_chunk_kwargs, apply_mean=False, radius_um=None)
print(W)
print(M)

with pytest.raises(AssertionError):
W, M = compute_whitening_matrix(rec, "local", random_chunk_kwargs, apply_mean=False, radius_um=None, eps=1e-8)
W, M = compute_whitening_matrix(rec, "local", random_chunk_kwargs, apply_mean=False, radius_um=25, eps=1e-8)
W, M = compute_whitening_matrix(rec, "local", random_chunk_kwargs, apply_mean=False, radius_um=None)
W, M = compute_whitening_matrix(rec, "local", random_chunk_kwargs, apply_mean=False, radius_um=25)
# W must be sparse
np.sum(W == 0) == 6
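
A hedged sketch (not part of the commit) mirroring the updated test: eps is no longer passed explicitly, so compute_whitening_matrix falls back to its internal default. The recording is a made-up synthetic one from generate_recording.

from spikeinterface.core import generate_recording
from spikeinterface.preprocessing.whiten import compute_whitening_matrix

# small synthetic recording; channel count and duration are illustrative
rec = generate_recording(num_channels=4, durations=[5.0])

# "global" mode with an empty random_chunk_kwargs dict and no explicit eps
W, M = compute_whitening_matrix(rec, "global", {}, apply_mean=False, radius_um=None)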

Expand Down