Skip to content

Commit

Permalink
Merge pull request #15 from SpikeInterface/spikingcircus
Browse files Browse the repository at this point in the history
pass extra arg to mountainsort5
  • Loading branch information
luiztauffer authored Mar 21, 2024
2 parents 22976c5 + 519cf05 commit 8a6681f
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 12 deletions.
26 changes: 24 additions & 2 deletions src/spikeinterface_pipelines/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,29 @@
from .logger import logger
from .global_params import JobKwargs
from .preprocessing import preprocess, PreprocessingParams
from .spikesorting import spikesort, SpikeSortingParams
from .spikesorting import (
spikesort,
SpikeSortingParams,
Kilosort25Model,
Kilosort3Model,
IronClustModel,
MountainSort5Model,
# SpykingCircus2Model,
)
from .postprocessing import postprocess, PostprocessingParams
from .curation import curate, CurationParams
from .visualization import visualize, VisualizationParams


sorter_model_map = {
"kilosort25": Kilosort25Model,
"kilosort3": Kilosort3Model,
"mountainsort5": MountainSort5Model,
# "spykingcircus2": SpykingCircus2Model,
"ironclust": IronClustModel,
}


def run_pipeline(
recording: si.BaseRecording,
scratch_folder: Path | str = Path("./scratch/"),
Expand Down Expand Up @@ -53,8 +70,13 @@ def run_pipeline(
if isinstance(preprocessing_params, dict):
preprocessing_params = PreprocessingParams(**preprocessing_params)
if isinstance(spikesorting_params, dict):
if spikesorting_params["sorter_name"] not in sorter_model_map:
raise ValueError(f"Sorter name {spikesorting_params['sorter_name']} not recognized")
if spikesorting_params["sorter_name"] == "mountainsort5":
spikesorting_params["sorter_kwargs"]["n_jobs_for_preprocessing"] = job_kwargs.n_jobs
spikesorting_params = SpikeSortingParams(
sorter_name=spikesorting_params["sorter_name"], sorter_kwargs=spikesorting_params["sorter_kwargs"]
sorter_name=spikesorting_params["sorter_name"],
sorter_kwargs=sorter_model_map[spikesorting_params["sorter_name"]](**spikesorting_params["sorter_kwargs"]),
)
if isinstance(postprocessing_params, dict):
postprocessing_params = PostprocessingParams(**postprocessing_params)
Expand Down
9 changes: 8 additions & 1 deletion src/spikeinterface_pipelines/spikesorting/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,9 @@
from .spikesorting import spikesort
from .params import SpikeSortingParams
from .params import (
SpikeSortingParams,
Kilosort25Model,
Kilosort3Model,
IronClustModel,
MountainSort5Model,
# SpykingCircus2Model
)
75 changes: 66 additions & 9 deletions src/spikeinterface_pipelines/spikesorting/params.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from pydantic import BaseModel, Field, ConfigDict
from typing import Union, List
from typing import Union, List, Optional
from enum import Enum


class SorterName(str, Enum):
ironclust = "ironclust"
kilosort25 = "kilosort2_5"
kilosort3 = "kilosort3"
mountainsort5 = "mountainsort5"
# spykingcircus2 = "spykingcircus2"
ironclust = "ironclust"


class Kilosort25Model(BaseModel):
Expand Down Expand Up @@ -36,7 +37,7 @@ class Kilosort25Model(BaseModel):
nPCs: int = Field(default=3, description="Number of PCA dimensions")
ntbuff: int = Field(default=64, description="Samples of symmetrical buffer for whitening and spike detection")
nfilt_factor: int = Field(default=4, description="Max number of clusters per good channel (even temporary ones) 4")
NT: int = Field(default=None, description="Batch size (if None it is automatically computed)")
NT: Optional[int] = Field(default=None, description="Batch size (if None it is automatically computed)")
AUCsplit: float = Field(
default=0.9,
description="Threshold on the area under the curve (AUC) criterion for performing a split in the final step",
Expand All @@ -57,11 +58,6 @@ class Kilosort3Model(BaseModel):
pass


class IronClustModel(BaseModel):
model_config = ConfigDict(extra="forbid")
pass


class MountainSort5Model(BaseModel):
model_config = ConfigDict(extra="forbid")
scheme: str = Field(default="2", description="Sorting scheme", json_schema_extra={"options": ["1", "2", "3"]})
Expand All @@ -88,11 +84,72 @@ class MountainSort5Model(BaseModel):
freq_max: int = Field(default=6000, description="Low-pass filter cutoff frequency")
filter: bool = Field(default=True, description="Enable or disable filter")
whiten: bool = Field(default=True, description="Enable or disable whiten")
n_jobs_for_preprocessing: float = Field(default=0.8, description="Number of jobs for preprocessing")


## SpykingCircus2 - WIP
# class SpykingCircus2GeneralModel(BaseModel):
# ms_before: int = Field(default=2, description="ms before")
# ms_after: int = Field(default=2, description="ms after")
# radius_um: int = Field(default=100, description="radius um")


# class SpykingCircus2WaveformsModel(BaseModel):
# max_spikes_per_unit: int = Field(default=200, description="Max spikes per unit")
# overwrite: bool = Field(default=True, description="Overwrite")
# sparse: bool = Field(default=True, description="Sparse")
# method: str = Field(default="energy", description="Method")
# threshold: float = Field(default=0.25, description="Threshold")


# class SpykingCircus2FilteringModel(BaseModel):
# freq_min: int = Field(default=150, description="High-pass filter cutoff frequency")
# dtype: str = Field(default="float32", description="Data type")


# class SpykingCircus2DetectionModel(BaseModel):
# peak_sign: str = Field(default="neg", description="Peak sign")
# detect_threshold: int = Field(default=4, description="Detect threshold")


# class SpykingCircus2SelectionModel(BaseModel):
# method: str = Field(default="smart_sampling_amplitudes", description="Method")
# n_peaks_per_channel: int = Field(default=5000, description="Number of peaks per channel")
# min_n_peaks: int = Field(default=20000, description="Minimum number of peaks")
# select_per_channel: bool = Field(default=False, description="Select per channel")


# class SpykingCircus2ClusteringModel(BaseModel):
# legacy: bool = Field(default=False, description="Legacy")


# class SpykingCircus2CacheModel(BaseModel):
# mode: str = Field(default="memory", description="Mode")
# memory_limit: float = Field(default=0.5, description="Memory limit")
# delete_cache: bool = Field(default=True, description="Delete cache")


# class SpykingCircus2Model(BaseModel):
# model_config = ConfigDict(extra="forbid")
# general: SpykingCircus2GeneralModel = Field(default=SpykingCircus2GeneralModel(), description="General parameters")
# waveforms: SpykingCircus2WaveformsModel = Field(default=SpykingCircus2WaveformsModel(), description="Waveforms parameters")
# filtering: SpykingCircus2FilteringModel = Field(default=SpykingCircus2FilteringModel(), description="Filtering parameters")
# detection: SpykingCircus2DetectionModel = Field(default=SpykingCircus2DetectionModel(), description="Detection parameters")
# selection: SpykingCircus2SelectionModel = Field(default=SpykingCircus2SelectionModel(), description="Selection parameters")
# clustering: SpykingCircus2ClusteringModel = Field(default=SpykingCircus2ClusteringModel(), description="Clustering parameters")
# apply_preprocessing: bool = Field(default=True, description="Apply preprocessing")
# shared_memory: bool = Field(default=True, description="Shared memory")
# cache_preprocessing: SpykingCircus2CacheModel = Field(default=SpykingCircus2CacheModel(), description="Cache preprocessing")


class IronClustModel(BaseModel):
model_config = ConfigDict(extra="forbid")
pass


class SpikeSortingParams(BaseModel):
sorter_name: SorterName = Field(description="Name of the sorter to use.")
sorter_kwargs: Union[Kilosort25Model, Kilosort3Model, IronClustModel, MountainSort5Model] = Field(
sorter_kwargs: Union[Kilosort25Model, Kilosort3Model, MountainSort5Model, IronClustModel] = Field(
description="Sorter specific kwargs.", union_mode="left_to_right"
)
spikesort_by_group: bool = Field(
Expand Down

0 comments on commit 8a6681f

Please sign in to comment.