Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve detect_bad_channels #2003

Merged
merged 4 commits into from
Sep 19, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 10 additions & 12 deletions src/spikeinterface/preprocessing/detect_bad_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def detect_bad_channels(
nyquist_threshold=0.8,
direction="y",
chunk_duration_s=0.3,
num_random_chunks=10,
num_random_chunks=100,
welch_window_ms=10.0,
highpass_filter_cutoff=300,
neighborhood_r2_threshold=0.9,
Expand Down Expand Up @@ -81,9 +81,10 @@ def detect_bad_channels(
highpass_filter_cutoff : float
If the recording is not filtered, the cutoff frequency of the highpass filter, by default 300
chunk_duration_s : float
Duration of each chunk, by default 0.3
Duration of each chunk, by default 0.5
num_random_chunks : int
Number of random chunks, by default 10
Number of random chunks, by default 100
Having many chunks is important for reproducibility.
welch_window_ms : float
Window size for the scipy.signal.welch that will be converted to nperseg, by default 10ms
neighborhood_r2_threshold : float, default 0.95
Expand Down Expand Up @@ -174,20 +175,18 @@ def detect_bad_channels(
channel_locations = recording.get_channel_locations()
dim = ["x", "y", "z"].index(direction)
assert dim < channel_locations.shape[1], f"Direction {direction} is wrong"
locs_depth = channel_locations[:, dim]
if np.array_equal(np.sort(locs_depth), locs_depth):
order_f, order_r = order_channels_by_depth(recording=recording, dimensions=("x", "y"))
if np.all(np.diff(order_f) == 1):
# already ordered
order_f = None
order_r = None
else:
# sort by x, y to avoid ambiguity
order_f, order_r = order_channels_by_depth(recording=recording, dimensions=("x", "y"))

# Create empty channel labels and fill with bad-channel detection estimate for each chunk
chunk_channel_labels = np.zeros((recording.get_num_channels(), len(random_data)), dtype=np.int8)

for i, random_chunk in enumerate(random_data):
random_chunk_sorted = random_chunk[order_f] if order_f is not None else random_chunk
chunk_channel_labels[:, i] = detect_bad_channels_ibl(
random_chunk_sorted = random_chunk[:, order_f] if order_f is not None else random_chunk
chunk_labels = detect_bad_channels_ibl(
raw=random_chunk_sorted,
fs=recording.sampling_frequency,
psd_hf_threshold=psd_hf_threshold,
Expand All @@ -198,11 +197,10 @@ def detect_bad_channels(
nyquist_threshold=nyquist_threshold,
welch_window_ms=welch_window_ms,
)
chunk_channel_labels[:, i] = chunk_labels[order_r] if order_r is not None else chunk_labels

# Take the mode of the chunk estimates as final result. Convert to binary good / bad channel output.
mode_channel_labels, _ = scipy.stats.mode(chunk_channel_labels, axis=1, keepdims=False)
if order_r is not None:
mode_channel_labels = mode_channel_labels[order_r]

(bad_inds,) = np.where(mode_channel_labels != 0)
bad_channel_ids = recording.channel_ids[bad_inds]
Expand Down