Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Patch to skip chunks without peaks or if enough peaks have been detected #2011

Closed
yger wants to merge 32 commits into SpikeInterface:main from yger:skip_no_peaks
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
1720e15
Patch to skip chunks without peaks
yger Sep 19, 2023
2760279
Merge branch 'main' into skip_no_peaks
yger Sep 19, 2023
7f72b21
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 19, 2023
40b092d
Patch
yger Sep 19, 2023
c7aef09
Patch
yger Sep 19, 2023
4bee4b1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 19, 2023
f4accb7
Merge branch 'main' into skip_no_peaks
yger Sep 19, 2023
226b73a
Only for Spike and PeakRetriever
yger Sep 19, 2023
68d6017
Merge branch 'main' into skip_no_peaks
yger Sep 20, 2023
b226335
Merge branch 'main' into skip_no_peaks
yger Nov 15, 2023
d564309
Merge branch 'SpikeInterface:main' into skip_no_peaks
yger Nov 24, 2023
386a23e
Merge branch 'SpikeInterface:main' into skip_no_peaks
yger Nov 29, 2023
da5f816
Merge branch 'main' into skip_no_peaks
yger Jan 2, 2024
87a6606
Skipping the loop if not peaks
yger Jan 2, 2024
5272d2b
Merge branch 'SpikeInterface:main' into skip_no_peaks
yger Jan 3, 2024
ad54cfa
Extension to deal with max_peaks
yger Jan 3, 2024
8c589f9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 3, 2024
e9d456b
Merge branch 'SpikeInterface:main' into skip_no_peaks
yger Jan 3, 2024
f9c3274
Merge branch 'SpikeInterface:main' into skip_no_peaks
yger Jan 10, 2024
7250653
Merge branch 'SpikeInterface:main' into skip_no_peaks
yger Jan 10, 2024
c4c147c
Merge branch 'SpikeInterface:main' into skip_no_peaks
yger Jan 12, 2024
5950998
Merge branch 'SpikeInterface:main' into skip_no_peaks
yger Jan 12, 2024
5b5a080
Merge branch 'SpikeInterface:main' into skip_no_peaks
yger Jan 12, 2024
d4a7c0f
Merge branch 'SpikeInterface:main' into skip_no_peaks
yger Jan 15, 2024
6d1838e
Merge branch 'SpikeInterface:main' into skip_no_peaks
yger Jan 19, 2024
f80a1e9
Merge branch 'SpikeInterface:main' into skip_no_peaks
yger Jan 19, 2024
d679506
Merge branch 'SpikeInterface:main' into skip_no_peaks
yger Jan 22, 2024
971f4fd
Merge branch 'SpikeInterface:main' into skip_no_peaks
yger Jan 29, 2024
5e2eae0
Merge branch 'SpikeInterface:main' into skip_no_peaks
yger Jan 30, 2024
03945d5
Merge branch 'SpikeInterface:main' into skip_no_peaks
yger Mar 11, 2024
54ced06
Merge branch 'SpikeInterface:main' into skip_no_peaks
yger Mar 12, 2024
6a6bcd7
Merge branch 'main' into skip_no_peaks
yger Jul 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/spikeinterface/core/job_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,11 +385,14 @@ def __init__(
f"chunk_duration={chunk_duration_str}",
)

def run(self):
def run(self, seed=None):
"""
Runs the defined jobs.
"""
all_chunks = divide_recording_into_chunks(self.recording, self.chunk_size)
if seed is not None:
rng = np.random.default_rng(seed)
all_chunks = rng.permutation(all_chunks)

if self.handle_returns:
returns = []
Expand Down
89 changes: 68 additions & 21 deletions src/spikeinterface/core/node_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,12 @@ def __init__(self, recording, peaks):
def get_trace_margin(self):
return 0

def has_peaks(self, start_frame, end_frame, segment_index):
    """Return True if at least one stored peak falls in [start_frame, end_frame) of the given segment."""
    # Restrict to the peaks belonging to the requested segment.
    segment_peaks = self.peaks[self.segment_slices[segment_index]]
    sample_indices = segment_peaks["sample_index"]
    # sample_index is sorted within a segment, so binary-search the chunk bounds.
    first = np.searchsorted(sample_indices, start_frame)
    last = np.searchsorted(sample_indices, end_frame)
    return first < last

def get_dtype(self):
return base_peak_dtype

Expand Down Expand Up @@ -209,6 +215,12 @@ def __init__(
def get_trace_margin(self):
return 0

def has_peaks(self, start_frame, end_frame, segment_index):
    """Whether any stored peak lies within [start_frame, end_frame) for this segment."""
    # Peaks are stored sorted by sample index per segment, so a single pair of
    # binary searches locates the chunk boundaries directly.
    samples = self.peaks[self.segment_slices[segment_index]]["sample_index"]
    bounds = np.searchsorted(samples, (start_frame, end_frame))
    return bounds[0] < bounds[1]

def get_dtype(self):
return self._dtype

Expand Down Expand Up @@ -467,12 +479,14 @@ def run_node_pipeline(
nodes,
job_kwargs,
job_name="pipeline",
max_peaks=None,
mp_context=None,
gather_mode="memory",
gather_kwargs={},
squeeze_output=True,
folder=None,
names=None,
seed=None,
verbose=False,
):
"""
Expand All @@ -491,7 +505,7 @@ def run_node_pipeline(
else:
raise ValueError(f"wrong gather_mode : {gather_mode}")

init_args = (recording, nodes)
init_args = (recording, nodes, max_peaks)

processor = ChunkRecordingExecutor(
recording,
Expand All @@ -504,30 +518,51 @@ def run_node_pipeline(
**job_kwargs,
)

processor.run()
processor.run(seed=seed)

outs = gather_func.finalize_buffers(squeeze_output=squeeze_output)
return outs


def _init_peak_pipeline(recording, nodes):
def _init_peak_pipeline(recording, nodes, max_peaks):
# create a local dict per worker
worker_ctx = {}
worker_ctx["recording"] = recording
worker_ctx["nodes"] = nodes
worker_ctx["num_peaks"] = 0
worker_ctx["max_peaks"] = max_peaks
worker_ctx["max_margin"] = max(node.get_trace_margin() for node in nodes)
return worker_ctx


def _compute_peak_pipeline_chunk(segment_index, start_frame, end_frame, worker_ctx):
recording = worker_ctx["recording"]
max_margin = worker_ctx["max_margin"]
max_peaks = worker_ctx["max_peaks"]
nodes = worker_ctx["nodes"]
num_peaks = worker_ctx["num_peaks"]

recording_segment = recording._recording_segments[segment_index]
traces_chunk, left_margin, right_margin = get_chunk_with_margin(
recording_segment, start_frame, end_frame, None, max_margin, add_zeros=True
)

if max_peaks is not None:
search_for_peaks = num_peaks < max_peaks
else:
search_for_peaks = True

if isinstance(nodes[0], PeakRetriever) or isinstance(nodes[0], SpikeRetriever):
chunk_has_peaks = nodes[0].has_peaks(start_frame, end_frame, segment_index)
else:
chunk_has_peaks = True

if search_for_peaks and chunk_has_peaks:
traces_chunk, left_margin, right_margin = get_chunk_with_margin(
recording_segment, start_frame, end_frame, None, max_margin, add_zeros=True
)
traces_loaded = True
else:
peak_output = np.zeros(0, base_peak_dtype)
waveform_output, left_margin, right_margin = np.zeros((0, 0), dtype=recording.dtype), 0, 0
traces_loaded = False

# compute the graph
pipeline_outputs = {}
Expand All @@ -538,23 +573,32 @@ def _compute_peak_pipeline_chunk(segment_index, start_frame, end_frame, worker_c
parent_output = pipeline_outputs[parent]
parent_outputs_tuple = parent_output if isinstance(parent_output, tuple) else (parent_output,)
node_input_args += parent_outputs_tuple
if isinstance(node, PeakDetector):
# to handle compatibility peak detector is a special case
# with specific margin
# TODO later when in master: change this later
extra_margin = max_margin - node.get_trace_margin()
if extra_margin:
trace_detection = traces_chunk[extra_margin:-extra_margin]
if traces_loaded:
if isinstance(node, PeakDetector):
# to handle compatibility peak detector is a special case
# with specific margin
# TODO later when in master: change this later
extra_margin = max_margin - node.get_trace_margin()
if extra_margin:
trace_detection = traces_chunk[extra_margin:-extra_margin]
else:
trace_detection = traces_chunk
node_output = node.compute(trace_detection, start_frame, end_frame, segment_index, max_margin)
# set sample index to local
node_output[0]["sample_index"] += extra_margin
worker_ctx["num_peaks"] += len(node_output[0])
elif isinstance(node, PeakSource):
node_output = node.compute(traces_chunk, start_frame, end_frame, segment_index, max_margin)
worker_ctx["num_peaks"] += len(node_output[0])
else:
trace_detection = traces_chunk
node_output = node.compute(trace_detection, start_frame, end_frame, segment_index, max_margin)
# set sample index to local
node_output[0]["sample_index"] += extra_margin
elif isinstance(node, PeakSource):
node_output = node.compute(traces_chunk, start_frame, end_frame, segment_index, max_margin)
# TODO later when in master: change the signature of all nodes (or maybe not!)
node_output = node.compute(traces_chunk, *node_input_args)
else:
# TODO later when in master: change the signature of all nodes (or maybe not!)
node_output = node.compute(traces_chunk, *node_input_args)
if isinstance(node, PeakSource):
node_output = peak_output
else:
node_output = waveform_output

pipeline_outputs[node] = node_output

# propagate the output
Expand Down Expand Up @@ -582,6 +626,9 @@ def _compute_peak_pipeline_chunk(segment_index, start_frame, end_frame, worker_c
# we need to go back to absolut sample index
pipeline_outputs_tuple[0]["sample_index"] += start_frame - left_margin

else:
pipeline_outputs_tuple = tuple()

return pipeline_outputs_tuple


Expand Down
12 changes: 11 additions & 1 deletion src/spikeinterface/sortingcomponents/peak_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,15 @@


def detect_peaks(
recording, method="by_channel", pipeline_nodes=None, gather_mode="memory", folder=None, names=None, **kwargs
recording,
method="by_channel",
pipeline_nodes=None,
gather_mode="memory",
folder=None,
names=None,
seed=None,
max_peaks=None,
**kwargs,
):
"""Peak detection based on threshold crossing in term of k x MAD.

Expand Down Expand Up @@ -123,6 +131,8 @@ def detect_peaks(
gather_mode=gather_mode,
squeeze_output=squeeze_output,
folder=folder,
seed=seed,
max_peaks=max_peaks,
names=names,
)
return outs
Expand Down
Loading