Skip to content

Commit

Permalink
Merge pull request #2003 from samuelgarcia/fix_detect_bad_channels
Browse files Browse the repository at this point in the history
Improve detect_bad_channels
  • Loading branch information
alejoe91 authored Sep 19, 2023
2 parents 855a264 + 73395fb commit fc95465
Showing 1 changed file with 10 additions and 12 deletions.
22 changes: 10 additions & 12 deletions src/spikeinterface/preprocessing/detect_bad_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def detect_bad_channels(
nyquist_threshold=0.8,
direction="y",
chunk_duration_s=0.3,
num_random_chunks=10,
num_random_chunks=100,
welch_window_ms=10.0,
highpass_filter_cutoff=300,
neighborhood_r2_threshold=0.9,
Expand Down Expand Up @@ -81,9 +81,10 @@ def detect_bad_channels(
highpass_filter_cutoff : float
If the recording is not filtered, the cutoff frequency of the highpass filter, by default 300
chunk_duration_s : float
Duration of each chunk, by default 0.3
Duration of each chunk, by default 0.5
num_random_chunks : int
Number of random chunks, by default 10
Number of random chunks, by default 100
Having many chunks is important for reproducibility.
welch_window_ms : float
Window size for the scipy.signal.welch that will be converted to nperseg, by default 10ms
neighborhood_r2_threshold : float, default 0.95
Expand Down Expand Up @@ -174,20 +175,18 @@ def detect_bad_channels(
channel_locations = recording.get_channel_locations()
dim = ["x", "y", "z"].index(direction)
assert dim < channel_locations.shape[1], f"Direction {direction} is wrong"
locs_depth = channel_locations[:, dim]
if np.array_equal(np.sort(locs_depth), locs_depth):
order_f, order_r = order_channels_by_depth(recording=recording, dimensions=("x", "y"))
if np.all(np.diff(order_f) == 1):
# already ordered
order_f = None
order_r = None
else:
# sort by x, y to avoid ambiguity
order_f, order_r = order_channels_by_depth(recording=recording, dimensions=("x", "y"))

# Create empty channel labels and fill with bad-channel detection estimate for each chunk
chunk_channel_labels = np.zeros((recording.get_num_channels(), len(random_data)), dtype=np.int8)

for i, random_chunk in enumerate(random_data):
random_chunk_sorted = random_chunk[order_f] if order_f is not None else random_chunk
chunk_channel_labels[:, i] = detect_bad_channels_ibl(
random_chunk_sorted = random_chunk[:, order_f] if order_f is not None else random_chunk
chunk_labels = detect_bad_channels_ibl(
raw=random_chunk_sorted,
fs=recording.sampling_frequency,
psd_hf_threshold=psd_hf_threshold,
Expand All @@ -198,11 +197,10 @@ def detect_bad_channels(
nyquist_threshold=nyquist_threshold,
welch_window_ms=welch_window_ms,
)
chunk_channel_labels[:, i] = chunk_labels[order_r] if order_r is not None else chunk_labels

# Take the mode of the chunk estimates as final result. Convert to binary good / bad channel output.
mode_channel_labels, _ = scipy.stats.mode(chunk_channel_labels, axis=1, keepdims=False)
if order_r is not None:
mode_channel_labels = mode_channel_labels[order_r]

(bad_inds,) = np.where(mode_channel_labels != 0)
bad_channel_ids = recording.channel_ids[bad_inds]
Expand Down

0 comments on commit fc95465

Please sign in to comment.