nodepipeline add skip_after_n_peaks option
samuelgarcia committed Sep 6, 2024
1 parent b1d726a commit 6590e0f
Showing 2 changed files with 53 additions and 6 deletions.
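In short, this commit adds a `skip_after_n_peaks` argument to `run_node_pipeline`: once the workers have gathered roughly that many peaks, the remaining chunks are no longer loaded or processed, which is useful when only a subset of peaks is needed. The sketch below is an editor's illustration rather than part of the commit; it mirrors the new test but swaps the test-only AmplitudeExtractionNode for ExtractDenseWaveforms from the same node_pipeline module, and the exact parameter values and signatures should be treated as assumptions.

from spikeinterface.core import (
    generate_ground_truth_recording,
    create_sorting_analyzer,
    get_template_extremum_channel,
)
from spikeinterface.core.node_pipeline import (
    run_node_pipeline,
    PeakRetriever,
    ExtractDenseWaveforms,
    sorting_to_peaks,
    spike_peak_dtype,
)

# small synthetic recording and sorting, as in the tests
recording, sorting = generate_ground_truth_recording(num_channels=10, num_units=10, durations=[10.0])

# turn the ground-truth sorting into a peak vector
sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory")
sorting_analyzer.compute(["random_spikes", "templates"])
extremum_channel_inds = get_template_extremum_channel(sorting_analyzer, peak_sign="neg", outputs="index")
peaks = sorting_to_peaks(sorting, extremum_channel_inds, spike_peak_dtype)

# pipeline: retrieve peaks, then extract dense waveforms around them
peak_retriever = PeakRetriever(recording, peaks)
waveforms = ExtractDenseWaveforms(recording, parents=[peak_retriever], ms_before=1.0, ms_after=1.0, return_output=True)

# stop loading new chunks once roughly 100 peaks have been gathered
# (the budget is split across workers, so the final count is approximate)
some_waveforms = run_node_pipeline(
    recording,
    [peak_retriever, waveforms],
    dict(chunk_duration="0.5s", n_jobs=1, progress_bar=False),
    gather_mode="memory",
    skip_after_n_peaks=100,
)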
22 changes: 19 additions & 3 deletions src/spikeinterface/core/node_pipeline.py
@@ -497,6 +497,7 @@ def run_node_pipeline(
    folder=None,
    names=None,
    verbose=False,
    skip_after_n_peaks=None,
):
    """
    Common function to run pipeline with peak detector or already detected peak.
@@ -507,14 +508,19 @@
    job_kwargs = fix_job_kwargs(job_kwargs)
    assert all(isinstance(node, PipelineNode) for node in nodes)

    if skip_after_n_peaks is not None:
        skip_after_n_peaks_per_worker = skip_after_n_peaks / job_kwargs["n_jobs"]
    else:
        skip_after_n_peaks_per_worker = None

    if gather_mode == "memory":
        gather_func = GatherToMemory()
    elif gather_mode == "npy":
        gather_func = GatherToNpy(folder, names, **gather_kwargs)
    else:
        raise ValueError(f"wrong gather_mode : {gather_mode}")

    init_args = (recording, nodes)
    init_args = (recording, nodes, skip_after_n_peaks_per_worker)

    processor = ChunkRecordingExecutor(
        recording,
@@ -533,19 +539,22 @@
    return outs


def _init_peak_pipeline(recording, nodes):
def _init_peak_pipeline(recording, nodes, skip_after_n_peaks_per_worker):
    # create a local dict per worker
    worker_ctx = {}
    worker_ctx["recording"] = recording
    worker_ctx["nodes"] = nodes
    worker_ctx["max_margin"] = max(node.get_trace_margin() for node in nodes)
    worker_ctx["skip_after_n_peaks_per_worker"] = skip_after_n_peaks_per_worker
    worker_ctx["num_peaks"] = 0
    return worker_ctx


def _compute_peak_pipeline_chunk(segment_index, start_frame, end_frame, worker_ctx):
    recording = worker_ctx["recording"]
    max_margin = worker_ctx["max_margin"]
    nodes = worker_ctx["nodes"]
    skip_after_n_peaks_per_worker = worker_ctx["skip_after_n_peaks_per_worker"]

    recording_segment = recording._recording_segments[segment_index]
    node0 = nodes[0]
@@ -557,7 +566,11 @@ def _compute_peak_pipeline_chunk(segment_index, start_frame, end_frame, worker_ctx):
    else:
        # PeakDetector always need traces
        load_trace_and_compute = True

    if skip_after_n_peaks_per_worker is not None:
        if worker_ctx["num_peaks"] > skip_after_n_peaks_per_worker:
            load_trace_and_compute = False

    if load_trace_and_compute:
        traces_chunk, left_margin, right_margin = get_chunk_with_margin(
            recording_segment, start_frame, end_frame, None, max_margin, add_zeros=True
@@ -590,6 +603,9 @@ def _compute_peak_pipeline_chunk(segment_index, start_frame, end_frame, worker_ctx):
                node_output = node.compute(traces_chunk, *node_input_args)
            pipeline_outputs[node] = node_output

            if skip_after_n_peaks_per_worker is not None and isinstance(node, PeakSource):
                worker_ctx["num_peaks"] += node_output[0].size

        # propagate the output
        pipeline_outputs_tuple = tuple()
        for node in nodes:
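A note on how the new option behaves (summarizing the hunks above): the requested budget is divided by `n_jobs`, each worker keeps its own `num_peaks` counter, and a worker only stops loading traces when its counter already exceeds its share at the start of a chunk. With `skip_after_n_peaks=1000` and `n_jobs=4`, for instance, each worker keeps processing until it has gathered more than 250 peaks, so the total can overshoot 1000 by up to one chunk's worth of peaks per worker; this is why the test below only asserts a lower bound on the number of gathered peaks.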
37 changes: 34 additions & 3 deletions src/spikeinterface/core/tests/test_node_pipeline.py
@@ -83,7 +83,7 @@ def test_run_node_pipeline(cache_folder_creation):
    extremum_channel_inds = get_template_extremum_channel(sorting_analyzer, peak_sign="neg", outputs="index")

    peaks = sorting_to_peaks(sorting, extremum_channel_inds, spike_peak_dtype)
    print(peaks.size)
    # print(peaks.size)

    peak_retriever = PeakRetriever(recording, peaks)
    # this test when no spikes in last chunks
@@ -191,6 +191,37 @@ def test_run_node_pipeline(cache_folder_creation):
    unpickled_node = pickle.loads(pickled_node)


def test_skip_after_n_peaks():
    recording, sorting = generate_ground_truth_recording(num_channels=10, num_units=10, durations=[10.0])

    # job_kwargs = dict(chunk_duration="0.5s", n_jobs=2, progress_bar=False)
    job_kwargs = dict(chunk_duration="0.5s", n_jobs=1, progress_bar=False)

    spikes = sorting.to_spike_vector()

    # create peaks from spikes
    sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory")
    sorting_analyzer.compute(["random_spikes", "templates"], **job_kwargs)
    extremum_channel_inds = get_template_extremum_channel(sorting_analyzer, peak_sign="neg", outputs="index")

    peaks = sorting_to_peaks(sorting, extremum_channel_inds, spike_peak_dtype)
    # print(peaks.size)

    node0 = PeakRetriever(recording, peaks)
    node1 = AmplitudeExtractionNode(recording, parents=[node0], param0=6.6, return_output=True)
    nodes = [node0, node1]

    skip_after_n_peaks = 30
    some_amplitudes = run_node_pipeline(recording, nodes, job_kwargs, gather_mode="memory", skip_after_n_peaks=skip_after_n_peaks)

    assert some_amplitudes.size >= skip_after_n_peaks
    assert some_amplitudes.size < spikes.size


if __name__ == "__main__":
    folder = Path("./cache_folder/core")
    test_run_node_pipeline(folder)
    # folder = Path("./cache_folder/core")
    # test_run_node_pipeline(folder)

    test_skip_after_n_peaks()
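To run just the new test (standard pytest selection by node id):

pytest src/spikeinterface/core/tests/test_node_pipeline.py::test_skip_after_n_peaks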

1 comment on commit 6590e0f

@yger yger commented on 6590e0f Sep 6, 2024

This is cool, but you need to expose the chunks to the node pipeline, such that we can explore them randomly. Thanks a lot !!!
