Skip to content

Commit

Permalink
Merge pull request #14 from SpikeInterface/update-sorting-models
Browse files Browse the repository at this point in the history
update sorting models
  • Loading branch information
luiztauffer authored Mar 18, 2024
2 parents fb029c1 + f98541f commit 9f9ca21
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 18 deletions.
12 changes: 7 additions & 5 deletions src/spikeinterface_pipelines/pipeline.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations
from pathlib import Path
import re
from typing import Tuple
import spikeinterface as si

Expand All @@ -19,7 +18,7 @@ def run_pipeline(
results_folder: Path | str = Path("./results/"),
job_kwargs: JobKwargs | dict = JobKwargs(),
preprocessing_params: PreprocessingParams | dict = PreprocessingParams(),
spikesorting_params: SpikeSortingParams | dict = SpikeSortingParams(),
spikesorting_params: SpikeSortingParams | dict = dict(),
postprocessing_params: PostprocessingParams | dict = PostprocessingParams(),
curation_params: CurationParams | dict = CurationParams(),
visualization_params: VisualizationParams | dict = VisualizationParams(),
Expand Down Expand Up @@ -54,7 +53,10 @@ def run_pipeline(
if isinstance(preprocessing_params, dict):
preprocessing_params = PreprocessingParams(**preprocessing_params)
if isinstance(spikesorting_params, dict):
spikesorting_params = SpikeSortingParams(**spikesorting_params)
spikesorting_params = SpikeSortingParams(
sorter_name=spikesorting_params['sorter_name'],
sorter_kwargs=spikesorting_params['sorter_kwargs']
)
if isinstance(postprocessing_params, dict):
postprocessing_params = PostprocessingParams(**postprocessing_params)
if isinstance(curation_params, dict):
Expand Down Expand Up @@ -117,13 +119,13 @@ def run_pipeline(
else:
logger.info("Skipping postprocessing")
waveform_extractor = None

else:
logger.info("Skipping spike sorting")
sorting = None
waveform_extractor = None
sorting_curated = None


# Visualization
visualization_output = None
Expand Down
38 changes: 33 additions & 5 deletions src/spikeinterface_pipelines/spikesorting/params.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, ConfigDict
from typing import Union, List
from enum import Enum

Expand All @@ -11,6 +11,7 @@ class SorterName(str, Enum):


class Kilosort25Model(BaseModel):
model_config = ConfigDict(extra='forbid')
detect_threshold: float = Field(default=6, description="Threshold for spike detection")
projection_threshold: List[float] = Field(default=[10, 4], description="Threshold on projections")
preclust_threshold: float = Field(
Expand Down Expand Up @@ -49,20 +50,47 @@ class Kilosort25Model(BaseModel):


class Kilosort3Model(BaseModel):
    """Parameters for the Kilosort3 sorter.

    No kwargs are exposed yet; the sorter runs with its library defaults.
    """

    # Reject unknown keys so typos in user-supplied sorter kwargs fail fast
    # at validation time instead of being silently ignored.
    model_config = ConfigDict(extra='forbid')


class IronClustModel(BaseModel):
    """Parameters for the IronClust sorter.

    No kwargs are exposed yet; the sorter runs with its library defaults.
    """

    # Reject unknown keys so typos in user-supplied sorter kwargs fail fast
    # at validation time instead of being silently ignored.
    model_config = ConfigDict(extra='forbid')


class MountainSort5Model(BaseModel):
    """Parameters for the MountainSort5 sorter.

    Field names and defaults mirror the MountainSort5 sorter's own kwargs;
    scheme-prefixed fields only apply when the matching ``scheme`` is selected.
    """

    # Reject unknown keys so typos in user-supplied sorter kwargs fail fast
    # at validation time instead of being silently ignored.
    model_config = ConfigDict(extra='forbid')
    scheme: str = Field(
        default='2',
        description="Sorting scheme",
        json_schema_extra={'options': ["1", "2", "3"]}
    )
    detect_threshold: float = Field(default=5.5, description="Threshold for spike detection")
    detect_sign: int = Field(default=-1, description="Sign of the peak")
    detect_time_radius_msec: float = Field(default=0.5, description="Time radius in milliseconds")
    snippet_T1: int = Field(default=20, description="Snippet T1")
    snippet_T2: int = Field(default=20, description="Snippet T2")
    npca_per_channel: int = Field(default=3, description="Number of PCA per channel")
    npca_per_subdivision: int = Field(default=10, description="Number of PCA per subdivision")
    snippet_mask_radius: int = Field(default=250, description="Snippet mask radius")
    scheme1_detect_channel_radius: int = Field(default=150, description="Scheme 1 detect channel radius")
    scheme2_phase1_detect_channel_radius: int = Field(default=200, description="Scheme 2 phase 1 detect channel radius")
    scheme2_detect_channel_radius: int = Field(default=50, description="Scheme 2 detect channel radius")
    scheme2_max_num_snippets_per_training_batch: int = Field(default=200, description="Scheme 2 max number of snippets per training batch")
    scheme2_training_duration_sec: int = Field(default=300, description="Scheme 2 training duration in seconds")
    scheme2_training_recording_sampling_mode: str = Field(default='uniform', description="Scheme 2 training recording sampling mode")
    scheme3_block_duration_sec: int = Field(default=1800, description="Scheme 3 block duration in seconds")
    freq_min: int = Field(default=300, description="High-pass filter cutoff frequency")
    freq_max: int = Field(default=6000, description="Low-pass filter cutoff frequency")
    filter: bool = Field(default=True, description="Enable or disable filter")
    whiten: bool = Field(default=True, description="Enable or disable whiten")


class SpikeSortingParams(BaseModel):
    """Top-level spike-sorting configuration: which sorter to run and its kwargs.

    ``sorter_name`` and ``sorter_kwargs`` have no defaults and must be supplied
    by the caller; the displayed span merged old and new diff lines (duplicate
    ``sorter_name``/``spikesort_by_group`` fields and a duplicate ``description``
    keyword), reconstructed here into the intended post-merge definition.
    """

    sorter_name: SorterName = Field(description="Name of the sorter to use.")
    sorter_kwargs: Union[Kilosort25Model, Kilosort3Model, IronClustModel, MountainSort5Model] = Field(
        description="Sorter specific kwargs.",
        # Validate union members in declaration order rather than pydantic's
        # default "smart" mode, so a kwargs dict resolves deterministically to
        # the first model it satisfies.
        union_mode='left_to_right',
    )
    spikesort_by_group: bool = Field(default=False, description="If True, spike sorting is run for each group separately.")
9 changes: 1 addition & 8 deletions src/spikeinterface_pipelines/spikesorting/spikesorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

def spikesort(
recording: si.BaseRecording,
spikesorting_params: SpikeSortingParams = SpikeSortingParams(),
spikesorting_params: SpikeSortingParams,
scratch_folder: Path = Path("./scratch/"),
results_folder: Path = Path("./results/spikesorting/"),
) -> si.BaseSorting | None:
Expand All @@ -39,13 +39,6 @@ def spikesort(

try:
logger.info(f"[Spikesorting] \tStarting {spikesorting_params.sorter_name} spike sorter")


## TEST ONLY - REMOVE LATER ##
# si.get_default_sorter_params('kilosort2_5')
# params_kilosort2_5 = {'do_correction': False}
## --------------------------##

if spikesorting_params.spikesort_by_group and len(np.unique(recording.get_channel_groups())) > 1:
logger.info(f"[Spikesorting] \tSorting by channel groups")
sorting = si.run_sorter_by_property(
Expand Down

0 comments on commit 9f9ca21

Please sign in to comment.