
Commit

Merge pull request #16 from SpikeInterface/viz-modifications
Visualization modifications
alejoe91 authored Aug 27, 2024
2 parents d316ba8 + 742d9f1 commit 5e02a30
Showing 4 changed files with 27 additions and 21 deletions.
26 changes: 13 additions & 13 deletions src/spikeinterface_pipelines/pipeline.py
@@ -136,6 +136,19 @@ def run_pipeline(
             )
         else:
             logger.info("Skipping curation")
+
+        # Visualization
+        visualization_output = None
+        if run_visualization:
+            logger.info("Visualizing results")
+            visualization_output = visualize(
+                recording=recording_preprocessed,
+                sorting_curated=sorting_curated,
+                waveform_extractor=waveform_extractor,
+                visualization_params=visualization_params,
+                scratch_folder=scratch_folder,
+                results_folder=results_folder_visualization,
+            )
     else:
         logger.info("Skipping postprocessing")
         waveform_extractor = None
@@ -146,17 +159,4 @@ def run_pipeline(
         waveform_extractor = None
         sorting_curated = None

-    # Visualization
-    visualization_output = None
-    if run_visualization:
-        logger.info("Visualizing results")
-        visualization_output = visualize(
-            recording=recording_preprocessed,
-            sorting_curated=sorting_curated,
-            waveform_extractor=waveform_extractor,
-            visualization_params=visualization_params,
-            scratch_folder=scratch_folder,
-            results_folder=results_folder_visualization,
-        )
-
     return (recording_preprocessed, sorting, waveform_extractor, sorting_curated, visualization_output)
12 changes: 10 additions & 2 deletions src/spikeinterface_pipelines/postprocessing/postprocessing.py
@@ -47,7 +47,12 @@ def postprocess(
     # first extract some raw waveforms in memory to deduplicate based on peak alignment
     wf_dedup_folder = tmp_folder / "waveforms_dense"
     waveform_extractor_raw = si.extract_waveforms(
-        recording, sorting, folder=wf_dedup_folder, sparse=False, **postprocessing_params.waveforms_raw.model_dump()
+        recording,
+        sorting,
+        folder=wf_dedup_folder,
+        sparse=False,
+        overwrite=True,
+        **postprocessing_params.waveforms_raw.model_dump()
     )

     # de-duplication
@@ -66,7 +71,10 @@

     # this is a trick to make the postprocessed folder "self-contained
     sorting_folder = results_folder / "sorting"
-    sorting_deduplicated = sorting_deduplicated.save(folder=sorting_folder)
+    sorting_deduplicated = sorting_deduplicated.save(
+        folder=sorting_folder,
+        overwrite=True,
+    )

     # now extract waveforms on de-duplicated units
     logger.info("[Postprocessing] \tSaving sparse de-duplicated waveform extractor folder")
3 changes: 0 additions & 3 deletions src/spikeinterface_pipelines/visualization/params.py
@@ -1,9 +1,6 @@
-from cProfile import label
 from pydantic import BaseModel, Field
 from typing import Literal, Union

-from spikeinterface.widgets import sorting_summary
-

 class TracesParams(BaseModel):
     """
7 changes: 4 additions & 3 deletions src/spikeinterface_pipelines/visualization/visualization.py
@@ -61,7 +61,7 @@ def visualize(
         logger.info(
             "[Visualization] \tKachery client not found. Use `kachery-cloud-init` to initialize kachery client."
         )
-        return
+        # return
     visualization_params_dict = visualization_params.model_dump()
     recording_params = visualization_params_dict["recording"]

@@ -80,6 +80,7 @@
         peak_locations = waveform_extractor.load_extension("spike_locations").get_data()
         peak_amps = np.concatenate(waveform_extractor.load_extension("spike_amplitudes").get_data())
         spike_locations_available = True
+
     # otherwise detect peaks
     if not spike_locations_available:
         from spikeinterface.core.node_pipeline import ExtractDenseWaveforms, run_node_pipeline
@@ -207,7 +208,7 @@ def visualize(
         for prop in unit_table_properties:
             if prop not in waveform_extractor.sorting.get_property_keys():
                 logger.info(
-                    f"[Visualization] \tProperty {prop} not found in sorting object. " "Not adding to unit table"
+                    f"[Visualization] \tProperty {prop} not found in sorting object. Not adding to unit table"
                 )
                 unit_table_properties.remove(prop)
         v_sorting = sw.plot_sorting_summary(
@@ -225,7 +226,7 @@
             print(f"\n{url}\n")
             visualization_output["sorting_summary"] = url
         except Exception as e:
-            logger.info("[Visualization] \tSortingview visualization failed with error:\n{e}")
+            logger.info(f"[Visualization] \tSortingview visualization failed with error:\n{e}")
     else:
         logger.info("[Visualization] \tNo units to visualize")

Expand Down
