diff --git a/src/spikeinterface_pipelines/pipeline.py b/src/spikeinterface_pipelines/pipeline.py index 2e22b1c..92cf3ac 100644 --- a/src/spikeinterface_pipelines/pipeline.py +++ b/src/spikeinterface_pipelines/pipeline.py @@ -138,6 +138,19 @@ def run_pipeline( ) else: logger.info("Skipping curation") + + # Visualization + visualization_output = None + if run_visualization: + logger.info("Visualizing results") + visualization_output = visualize( + recording=recording_preprocessed, + sorting_curated=sorting_curated, + waveform_extractor=waveform_extractor, + visualization_params=visualization_params, + scratch_folder=scratch_folder, + results_folder=results_folder_visualization, + ) else: logger.info("Skipping postprocessing") waveform_extractor = None @@ -148,17 +161,4 @@ def run_pipeline( waveform_extractor = None sorting_curated = None - # Visualization - visualization_output = None - if run_visualization: - logger.info("Visualizing results") - visualization_output = visualize( - recording=recording_preprocessed, - sorting_curated=sorting_curated, - waveform_extractor=waveform_extractor, - visualization_params=visualization_params, - scratch_folder=scratch_folder, - results_folder=results_folder_visualization, - ) - return (recording_preprocessed, sorting, waveform_extractor, sorting_curated, visualization_output) diff --git a/src/spikeinterface_pipelines/postprocessing/postprocessing.py b/src/spikeinterface_pipelines/postprocessing/postprocessing.py index 9502d33..cb5ef26 100644 --- a/src/spikeinterface_pipelines/postprocessing/postprocessing.py +++ b/src/spikeinterface_pipelines/postprocessing/postprocessing.py @@ -47,7 +47,12 @@ def postprocess( # first extract some raw waveforms in memory to deduplicate based on peak alignment wf_dedup_folder = tmp_folder / "waveforms_dense" waveform_extractor_raw = si.extract_waveforms( - recording, sorting, folder=wf_dedup_folder, sparse=False, **postprocessing_params.waveforms_raw.model_dump() + recording, + sorting, + folder=wf_dedup_folder, + sparse=False, + overwrite=True, + **postprocessing_params.waveforms_raw.model_dump() ) # de-duplication @@ -66,7 +71,10 @@ def postprocess( # this is a trick to make the postprocessed folder "self-contained sorting_folder = results_folder / "sorting" - sorting_deduplicated = sorting_deduplicated.save(folder=sorting_folder) + sorting_deduplicated = sorting_deduplicated.save( + folder=sorting_folder, + overwrite=True, + ) # now extract waveforms on de-duplicated units logger.info("[Postprocessing] \tSaving sparse de-duplicated waveform extractor folder") diff --git a/src/spikeinterface_pipelines/visualization/params.py b/src/spikeinterface_pipelines/visualization/params.py index a6b1e9d..708d20c 100644 --- a/src/spikeinterface_pipelines/visualization/params.py +++ b/src/spikeinterface_pipelines/visualization/params.py @@ -1,9 +1,6 @@ -from cProfile import label from pydantic import BaseModel, Field from typing import Literal, Union -from spikeinterface.widgets import sorting_summary - class TracesParams(BaseModel): """ diff --git a/src/spikeinterface_pipelines/visualization/visualization.py b/src/spikeinterface_pipelines/visualization/visualization.py index f228cff..07995c1 100644 --- a/src/spikeinterface_pipelines/visualization/visualization.py +++ b/src/spikeinterface_pipelines/visualization/visualization.py @@ -61,7 +61,7 @@ def visualize( logger.info( "[Visualization] \tKachery client not found. Use `kachery-cloud-init` to initialize kachery client." ) - return + # return visualization_params_dict = visualization_params.model_dump() recording_params = visualization_params_dict["recording"] @@ -80,6 +80,7 @@ def visualize( peak_locations = waveform_extractor.load_extension("spike_locations").get_data() peak_amps = np.concatenate(waveform_extractor.load_extension("spike_amplitudes").get_data()) spike_locations_available = True + # otherwise detect peaks if not spike_locations_available: from spikeinterface.core.node_pipeline import ExtractDenseWaveforms, run_node_pipeline @@ -207,7 +208,7 @@ def visualize( for prop in unit_table_properties: if prop not in waveform_extractor.sorting.get_property_keys(): logger.info( - f"[Visualization] \tProperty {prop} not found in sorting object. " "Not adding to unit table" + f"[Visualization] \tProperty {prop} not found in sorting object. Not adding to unit table" ) unit_table_properties.remove(prop) v_sorting = sw.plot_sorting_summary( @@ -225,7 +226,7 @@ def visualize( print(f"\n{url}\n") visualization_output["sorting_summary"] = url except Exception as e: - logger.info("[Visualization] \tSortingview visualization failed with error:\n{e}") + logger.info(f"[Visualization] \tSortingview visualization failed with error:\n{e}") else: logger.info("[Visualization] \tNo units to visualize")