Skip to content

Commit

Permalink
Merge pull request #2250 from alejoe91/tip-bottom-option-rm-channels
Browse files Browse the repository at this point in the history
Add `outside_channels_location` option in `detect_bad_channels`
  • Loading branch information
samuelgarcia authored Dec 1, 2023
2 parents 78761bc + 933c160 commit 217eec1
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 29 deletions.
75 changes: 51 additions & 24 deletions src/spikeinterface/preprocessing/detect_bad_channels.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,32 @@
from __future__ import annotations
import warnings

import numpy as np
from typing import Literal

from .filter import highpass_filter
from ..core import get_random_data_chunks, order_channels_by_depth
from ..core import get_random_data_chunks, order_channels_by_depth, BaseRecording


def detect_bad_channels(
recording,
method="coherence+psd",
std_mad_threshold=5,
psd_hf_threshold=0.02,
dead_channel_threshold=-0.5,
noisy_channel_threshold=1.0,
outside_channel_threshold=-0.75,
n_neighbors=11,
nyquist_threshold=0.8,
direction="y",
chunk_duration_s=0.3,
num_random_chunks=100,
welch_window_ms=10.0,
highpass_filter_cutoff=300,
neighborhood_r2_threshold=0.9,
neighborhood_r2_radius_um=30.0,
seed=None,
recording: BaseRecording,
method: str = "coherence+psd",
std_mad_threshold: float = 5,
psd_hf_threshold: float = 0.02,
dead_channel_threshold: float = -0.5,
noisy_channel_threshold: float = 1.0,
outside_channel_threshold: float = -0.75,
outside_channels_location: Literal["top", "bottom", "both"] = "top",
n_neighbors: int = 11,
nyquist_threshold: float = 0.8,
direction: Literal["x", "y", "z"] = "y",
chunk_duration_s: float = 0.3,
num_random_chunks: int = 100,
welch_window_ms: float = 10.0,
highpass_filter_cutoff: float = 300,
neighborhood_r2_threshold: float = 0.9,
neighborhood_r2_radius_um: float = 30.0,
seed: int | None = None,
):
"""
Perform bad channel detection.
Expand Down Expand Up @@ -65,6 +68,11 @@ def detect_bad_channels(
outside_channel_threshold (coeherence+psd) : float, default: -0.75
Threshold for channel coherence above which channels at the edge of the recording are marked as outside
of the brain
outside_channels_location (coeherence+psd) : "top" | "bottom" | "both", default: "top"
Location of the outside channels. If "top", only the channels at the top of the probe can be
marked as outside channels. If "bottom", only the channels at the bottom of the probe can be
marked as outside channels. If "both", both the channels at the top and bottom of the probe can be
marked as outside channels
n_neighbors (coeherence+psd) : int, default: 11
Number of channel neighbors to compute median filter (needs to be odd)
nyquist_threshold (coeherence+psd) : float, default: 0.8
Expand Down Expand Up @@ -190,6 +198,7 @@ def detect_bad_channels(
n_neighbors=n_neighbors,
nyquist_threshold=nyquist_threshold,
welch_window_ms=welch_window_ms,
outside_channels_location=outside_channels_location,
)
chunk_channel_labels[:, i] = chunk_labels[order_r] if order_r is not None else chunk_labels

Expand Down Expand Up @@ -275,6 +284,7 @@ def detect_bad_channels_ibl(
n_neighbors=11,
nyquist_threshold=0.8,
welch_window_ms=0.3,
outside_channels_location="top",
):
"""
Bad channels detection for Neuropixel probes developed by IBL
Expand All @@ -300,6 +310,11 @@ def detect_bad_channels_ibl(
Threshold on Nyquist frequency to calculate HF noise band
welch_window_ms: float, default: 0.3
Window size for the scipy.signal.welch that will be converted to nperseg
outside_channels_location : "top" | "bottom" | "both", default: "top"
Location of the outside channels. If "top", only the channels at the top of the probe can be
marked as outside channels. If "bottom", only the channels at the bottom of the probe can be
marked as outside channels. If "both", both the channels at the top and bottom of the probe can be
marked as outside channels
Returns
-------
Expand Down Expand Up @@ -332,12 +347,24 @@ def detect_bad_channels_ibl(
ichannels[inoisy] = 2

# the channels outside of the brains are the contiguous channels below the threshold on the trend coherency
# the chanels outide need to be at either extremes of the probe
ioutside = np.where(xcorr_distant < outside_channel_thr)[0]
if ioutside.size > 0 and (ioutside[-1] == (nc - 1) or ioutside[0] == 0):
a = np.cumsum(np.r_[0, np.diff(ioutside) - 1])
ioutside = ioutside[a == np.max(a)]
ichannels[ioutside] = 3
# the chanels outside need to be at the extreme of the probe
(ioutside,) = np.where(xcorr_distant < outside_channel_thr)
a = np.cumsum(np.r_[0, np.diff(ioutside) - 1])
if ioutside.size > 0:
if outside_channels_location == "top":
# channels are sorted bottom to top, so the last channel needs to be (nc - 1)
if ioutside[-1] == (nc - 1):
ioutside = ioutside[(a == np.max(a)) & (a > 0)]
ichannels[ioutside] = 3
elif outside_channels_location == "bottom":
# outside channels are at the bottom of the probe, so the first channel needs to be 0
if ioutside[0] == 0:
ioutside = ioutside[(a == np.min(a)) & (a < np.max(a))]
ichannels[ioutside] = 3
else: # both extremes are considered
if ioutside[-1] == (nc - 1) or ioutside[0] == 0:
ioutside = ioutside[(a == np.max(a)) | (a == np.min(a))]
ichannels[ioutside] = 3

return ichannels

Expand Down
56 changes: 51 additions & 5 deletions src/spikeinterface/preprocessing/tests/test_detect_bad_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
HAVE_NPIX = False


def test_remove_bad_channels_std_mad():
def test_detect_bad_channels_std_mad():
num_channels = 4
sampling_frequency = 30000.0
durations = [10.325, 3.5]
Expand Down Expand Up @@ -60,9 +60,48 @@ def test_remove_bad_channels_std_mad():
), "wrong channels locations."


@pytest.mark.parametrize("outside_channels_location", ["bottom", "top", "both"])
def test_detect_bad_channels_extremes(outside_channels_location):
num_channels = 64
sampling_frequency = 30000.0
durations = [20]
num_out_channels = 10

num_segments = len(durations)
num_timepoints = [int(sampling_frequency * d) for d in durations]

traces_list = []
for i in range(num_segments):
traces = np.random.randn(num_timepoints[i], num_channels).astype("float32")
# extreme channels are "out"
traces[:, :num_out_channels] *= 0.05
traces[:, -num_out_channels:] *= 0.05
traces_list.append(traces)

rec = NumpyRecording(traces_list, sampling_frequency)
rec.set_channel_gains(1)
rec.set_channel_offsets(0)

probe = generate_linear_probe(num_elec=num_channels)
probe.set_device_channel_indices(np.arange(num_channels))
rec.set_probe(probe, in_place=True)

bad_channel_ids, bad_labels = detect_bad_channels(
rec, method="coherence+psd", outside_channels_location=outside_channels_location
)
if outside_channels_location == "top":
assert np.array_equal(bad_channel_ids, rec.channel_ids[-num_out_channels:])
elif outside_channels_location == "bottom":
assert np.array_equal(bad_channel_ids, rec.channel_ids[:num_out_channels])
elif outside_channels_location == "both":
assert np.array_equal(
bad_channel_ids, np.concatenate((rec.channel_ids[:num_out_channels], rec.channel_ids[-num_out_channels:]))
)


@pytest.mark.skipif(not HAVE_NPIX, reason="ibl-neuropixel is not installed")
@pytest.mark.parametrize("num_channels", [32, 64, 384])
def test_remove_bad_channels_ibl(num_channels):
def test_detect_bad_channels_ibl(num_channels):
"""
Cannot test against DL datasets because they are too short
and need to control the PSD scaling. Here generate a dataset
Expand Down Expand Up @@ -121,7 +160,9 @@ def test_remove_bad_channels_ibl(num_channels):
traces_uV = random_chunk.T
traces_V = traces_uV * 1e-6
channel_flags, _ = neurodsp.voltage.detect_bad_channels(
traces_V, recording.get_sampling_frequency(), psd_hf_threshold=psd_cutoff
traces_V,
recording.get_sampling_frequency(),
psd_hf_threshold=psd_cutoff,
)
channel_flags_ibl[:, i] = channel_flags

Expand Down Expand Up @@ -209,5 +250,10 @@ def add_dead_channels(recording, is_dead):


if __name__ == "__main__":
test_remove_bad_channels_std_mad()
test_remove_bad_channels_ibl(num_channels=384)
# test_detect_bad_channels_std_mad()
test_detect_bad_channels_ibl(num_channels=32)
test_detect_bad_channels_ibl(num_channels=64)
test_detect_bad_channels_ibl(num_channels=384)
# test_detect_bad_channels_extremes("top")
# test_detect_bad_channels_extremes("bottom")
# test_detect_bad_channels_extremes("both")

0 comments on commit 217eec1

Please sign in to comment.