Torch support for the circus and OMP matching engines
* Fixes

* Patches

* Fixes for SC2 and for split clustering

* debugging clustering

* WIP

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* WIP

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Torch for convolutions

* Forcing data structures to be float32

* Device and wobble

* WIP

* Speeding up wobble

* WIP

* WIP

* Torch

* WIP torch

* WIP

* WIP

* Addition of a detection node for coherence

* Doc

* WIP

* Default params

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* WIP

* Handling context with torch on the fly

* Dealing with torch

* Adding support for torch in matching engines

* Automatic handling of torch

* Default back

* WIP

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Adding gather_func to find_spikes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Gathering mode more explicit for matching

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* WIP

* WIP

* Fixes for SC2

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* WIP

* Simplifications

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Naming for Sam

* Optimize circus matching engine

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Optimizations

* Remove the limit on chunk sizes in circus-omp-svd

* WIP

* Wobble also

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Wobble also

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* WIP

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Oops

* WIP

* WIP

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fixes

* Backward compatibility

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Naming

* Cleaning

* Bringing back context for peak detectors

* Update src/spikeinterface/benchmark/benchmark_matching.py

* Update src/spikeinterface/sortingcomponents/matching/circus.py

* WIP

* Patch imports

* WIP

* WIP

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fixing tests

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* WIP

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* KSPeeler

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Moving KS to a new PR

* Moving KS to a new PR

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Allow spawn and cuda for circus

* Add push_to_torch to allow pickling of objects

* Default

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Cleaning docs

* WIP

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Garcia Samuel <[email protected]>
3 people authored Oct 15, 2024
1 parent 9fa21c9 commit 3e608c6
Showing 7 changed files with 235 additions and 89 deletions.
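For orientation: the new `engine` and `torch_device` options added by this commit are plumbed through the existing matching entry point. A minimal sketch, assuming a `recording` and a `Templates` object are already in hand — their construction, and the exact `method_kwargs` spelling, are illustrative assumptions, not part of this diff:

# Hedged usage sketch; only `engine` and `torch_device` come from this commit.
from spikeinterface.sortingcomponents.matching import find_spikes_from_templates

spikes = find_spikes_from_templates(
    recording,
    method="circus-omp-svd",
    method_kwargs=dict(
        templates=templates,
        engine="auto",        # "numpy", "torch", or "auto" (torch when installed)
        torch_device="cuda",  # only used by the torch engine
    ),
)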
1 change: 1 addition & 0 deletions src/spikeinterface/benchmark/benchmark_matching.py
@@ -80,6 +80,7 @@ def plot_performances_comparison(self, **kwargs):
     def plot_collisions(self, case_keys=None, figsize=None):
         if case_keys is None:
             case_keys = list(self.cases.keys())
+        import matplotlib.pyplot as plt

         fig, axs = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize, squeeze=False)
8 changes: 7 additions & 1 deletion src/spikeinterface/sorters/internal/spyking_circus2.py
@@ -27,7 +27,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
         "general": {"ms_before": 2, "ms_after": 2, "radius_um": 75},
         "sparsity": {"method": "snr", "amplitude_mode": "peak_to_peak", "threshold": 0.25},
         "filtering": {"freq_min": 150, "freq_max": 7000, "ftype": "bessel", "filter_order": 2},
-        "whitening": {"mode": "local", "regularize": True},
+        "whitening": {"mode": "local", "regularize": False},
         "detection": {"peak_sign": "neg", "detect_threshold": 4},
         "selection": {
             "method": "uniform",
@@ -100,6 +100,12 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
         except:
             HAVE_HDBSCAN = False

+        try:
+            import torch
+        except ImportError:
+            HAVE_TORCH = False
+            print("spykingcircus2 could benefit from using torch. Consider installing it")
+
         assert HAVE_HDBSCAN, "spykingcircus2 needs hdbscan to be installed"

         # this is imported only on demand because numba imports are too heavy
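The hunk above only warns when torch is absent; nothing in the sorter's call signature changes. A hedged sketch of invoking the sorter after this change — the generated recording is purely a runnable stand-in, and SC2 still requires hdbscan:

from spikeinterface.core import generate_ground_truth_recording
from spikeinterface.sorters import run_sorter

# Stand-in recording for illustration only
recording, _ = generate_ground_truth_recording(durations=[10.0], seed=0)

# torch, if installed, is picked up automatically by the matching step
sorting = run_sorter("spykingcircus2", recording, folder="sc2_output")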
6 changes: 3 additions & 3 deletions src/spikeinterface/sortingcomponents/clustering/circus.py
@@ -92,7 +92,7 @@ def main_function(cls, recording, peaks, params):
         # SVD for time compression
         few_peaks = select_peaks(peaks, recording=recording, method="uniform", n_peaks=10000, margin=(nbefore, nafter))
         few_wfs = extract_waveform_at_max_channel(
-            recording, few_peaks, ms_before=ms_before, ms_after=ms_after, **params["job_kwargs"]
+            recording, few_peaks, ms_before=ms_before, ms_after=ms_after, **job_kwargs
         )

         wfs = few_wfs[:, :, 0]
@@ -141,7 +141,7 @@ def main_function(cls, recording, peaks, params):
         all_pc_data = run_node_pipeline(
             recording,
             pipeline_nodes,
-            params["job_kwargs"],
+            job_kwargs,
             job_name="extracting features",
         )

@@ -176,7 +176,7 @@ def main_function(cls, recording, peaks, params):
         _ = run_node_pipeline(
             recording,
             pipeline_nodes,
-            params["job_kwargs"],
+            job_kwargs,
             job_name="extracting features",
             gather_mode="npy",
             gather_kwargs=dict(exist_ok=True),
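The three hunks above make the same fix: `main_function` now reuses the validated `job_kwargs` local instead of re-reading the raw `params["job_kwargs"]` dict. A small sketch of the normalization assumed to produce that local — `fix_job_kwargs` is spikeinterface's existing helper, the input values here are examples:

from spikeinterface.core.job_tools import fix_job_kwargs

# Normalizes user-supplied parallelism options into a canonical dict
job_kwargs = fix_job_kwargs(dict(n_jobs=4, chunk_duration="1s", progress_bar=True))
print(job_kwargs)  # safe to forward to run_node_pipeline(...)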
134 changes: 90 additions & 44 deletions src/spikeinterface/sortingcomponents/matching/circus.py
@@ -18,6 +18,15 @@
     ("segment_index", "int64"),
 ]

+try:
+    import torch
+    import torch.nn.functional as F
+
+    HAVE_TORCH = True
+    from torch.nn.functional import conv1d
+except ImportError:
+    HAVE_TORCH = False
+
 from .base import BaseTemplateMatching


@@ -43,9 +52,9 @@ def compress_templates(

     temporal, singular, spatial = np.linalg.svd(templates_array, full_matrices=False)
     # Keep only the strongest components
-    temporal = temporal[:, :, :approx_rank]
-    singular = singular[:, :approx_rank]
-    spatial = spatial[:, :approx_rank, :]
+    temporal = temporal[:, :, :approx_rank].astype(np.float32)
+    singular = singular[:, :approx_rank].astype(np.float32)
+    spatial = spatial[:, :approx_rank, :].astype(np.float32)

     if return_new_templates:
         templates_array = np.matmul(temporal * singular[:, np.newaxis, :], spatial)
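The only change above is the float32 cast on the truncated SVD factors, which keeps them consistent with the torch tensors built later. A self-contained sketch of the same compression on a random template bank (shapes are illustrative):

import numpy as np

num_templates, num_samples, num_channels, approx_rank = 10, 60, 16, 5
rng = np.random.default_rng(0)
templates_array = rng.standard_normal((num_templates, num_samples, num_channels))

# Batched SVD, rank truncation, float32 cast -- as in compress_templates
temporal, singular, spatial = np.linalg.svd(templates_array, full_matrices=False)
temporal = temporal[:, :, :approx_rank].astype(np.float32)  # (10, 60, 5)
singular = singular[:, :approx_rank].astype(np.float32)     # (10, 5)
spatial = spatial[:, :approx_rank, :].astype(np.float32)    # (10, 5, 16)

# Low-rank reconstruction, as in the return_new_templates branch
reconstructed = np.matmul(temporal * singular[:, np.newaxis, :], spatial)
print(reconstructed.shape)  # (10, 60, 16)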
@@ -107,18 +116,22 @@ class CircusOMPSVDPeeler(BaseTemplateMatching):

     Parameters
     ----------
-    amplitude: tuple
+    amplitude : tuple
         (Minimal, Maximal) amplitudes allowed for every template
-    max_failures: int
+    max_failures : int
         Stopping criteria of the OMP algorithm, as the number of retries while updating amplitudes
-    sparse_kwargs: dict
+    sparse_kwargs : dict
         Parameters to extract a sparsity mask from the waveform_extractor, if not
         already sparse.
-    rank: int, default: 5
+    rank : int, default: 5
         Number of components used internally by the SVD
-    vicinity: int
+    vicinity : int
         Size of the area surrounding a spike to perform modification (expressed in terms
         of template temporal width)
+    engine : str in ["numpy", "torch", "auto"], default: "auto"
+        The engine to use for the convolutions
+    torch_device : str in ["cpu", "cuda", None], default: "cpu"
+        Controls the torch device if the torch engine is selected
     -----
     """

@@ -148,6 +161,8 @@ def __init__(
         ignore_inds=[],
         vicinity=2,
         precomputed=None,
+        engine="numpy",
+        torch_device="cpu",
     ):

         BaseTemplateMatching.__init__(self, recording, templates, return_output=True, parents=None)
@@ -158,6 +173,19 @@ def __init__(
         self.nafter = templates.nafter
         self.sampling_frequency = recording.get_sampling_frequency()
         self.vicinity = vicinity * self.num_samples
+        assert engine in ["numpy", "torch", "auto"], "engine should be numpy, torch or auto"
+        if engine == "auto":
+            if HAVE_TORCH:
+                self.engine = "torch"
+            else:
+                self.engine = "numpy"
+        else:
+            if engine == "torch":
+                assert HAVE_TORCH, "please install torch to use the torch engine"
+            self.engine = engine
+
+        assert torch_device in ["cuda", "cpu", None]
+        self.torch_device = torch_device

         self.amplitudes = amplitudes
         self.stop_criteria = stop_criteria
@@ -183,6 +211,7 @@ def __init__(
             self.unit_overlaps_tables[i][self.unit_overlaps_indices[i]] = np.arange(len(self.unit_overlaps_indices[i]))

         self.margin = 2 * self.num_samples
+        self.is_pushed = False

     def _prepare_templates(self):

@@ -254,6 +283,14 @@ def _prepare_templates(self):
         self.temporal = np.moveaxis(self.temporal, [0, 1, 2], [1, 2, 0])
         self.singular = self.singular.T[:, :, np.newaxis]

+    def _push_to_torch(self):
+        if self.engine == "torch":
+            self.spatial = torch.as_tensor(self.spatial, device=self.torch_device)
+            self.singular = torch.as_tensor(self.singular, device=self.torch_device)
+            self.temporal = torch.as_tensor(self.temporal.copy(), device=self.torch_device).swapaxes(0, 1)
+            self.temporal = torch.flip(self.temporal, (2,))
+        self.is_pushed = True
+
     def get_extra_outputs(self):
         output = {}
         for key in self._more_output_keys:
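`_push_to_torch` exists so the peeler can cross process boundaries (the "Allow spawn and cuda" and "push_to_torch to allow pickling" commits above): the object is pickled while its state is still numpy, and each worker materializes tensors on its own device at first use. A minimal sketch of that pattern with a hypothetical class, not the actual peeler:

import numpy as np
import torch

class LazyTorchHolder:
    def __init__(self, weights, device="cpu"):
        self.weights = np.asarray(weights, dtype=np.float32)  # picklable numpy state
        self.device = device
        self.is_pushed = False

    def _push_to_torch(self):
        # Materialize tensors on the target device, once, in the consuming process
        self.weights = torch.as_tensor(self.weights, device=self.device)
        self.is_pushed = True

    def apply(self, x):
        if not self.is_pushed:
            self._push_to_torch()
        return self.weights @ torch.as_tensor(x, device=self.device)

holder = LazyTorchHolder(np.eye(3))  # still picklable at this point
print(holder.apply(np.ones((3, 1), dtype=np.float32)))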
@@ -268,43 +305,52 @@ def compute_matching(self, traces, start_frame, end_frame, segment_index):
         import scipy
         from scipy import ndimage

-        (potrs,) = scipy.linalg.get_lapack_funcs(("potrs",), dtype=np.float32)
+        if not self.is_pushed:
+            self._push_to_torch()

+        (potrs,) = scipy.linalg.get_lapack_funcs(("potrs",), dtype=np.float32)
         (nrm2,) = scipy.linalg.get_blas_funcs(("nrm2",), dtype=np.float32)

-        overlaps_array = self.overlaps
-
         omp_tol = np.finfo(np.float32).eps
-        num_samples = self.nafter + self.nbefore
-        neighbor_window = num_samples - 1
+        neighbor_window = self.num_samples - 1

         if isinstance(self.amplitudes, list):
             min_amplitude, max_amplitude = self.amplitudes
         else:
             min_amplitude, max_amplitude = self.amplitudes[:, 0], self.amplitudes[:, 1]
             min_amplitude = min_amplitude[:, np.newaxis]
             max_amplitude = max_amplitude[:, np.newaxis]

-        num_timesteps = len(traces)
+        if self.engine == "torch":
+            blank = np.zeros((neighbor_window, self.num_channels), dtype=np.float32)
+            traces = np.vstack((blank, traces, blank))
+            num_timesteps = traces.shape[0]
+            torch_traces = torch.as_tensor(traces.T[np.newaxis, :, :], device=self.torch_device)
+            num_templates, num_channels = self.temporal.shape[0], self.temporal.shape[1]
+            spatially_filtered_data = torch.matmul(self.spatial, torch_traces)
+            scaled_filtered_data = (spatially_filtered_data * self.singular).swapaxes(0, 1)
+            scaled_filtered_data_ = scaled_filtered_data.reshape(1, num_templates * num_channels, num_timesteps)
+            scalar_products = conv1d(scaled_filtered_data_, self.temporal, groups=num_templates, padding="valid")
+            scalar_products = scalar_products.cpu().numpy()[0, :, self.num_samples - 1 : -neighbor_window]
+        else:
+            num_timesteps = traces.shape[0]
+            num_peaks = num_timesteps - neighbor_window
+            conv_shape = (self.num_templates, num_peaks)
+            scalar_products = np.zeros(conv_shape, dtype=np.float32)
+            # Filter using overlap-and-add convolution
+            spatially_filtered_data = np.matmul(self.spatial, traces.T[np.newaxis, :, :])
+            scaled_filtered_data = spatially_filtered_data * self.singular
+            from scipy import signal

-        num_peaks = num_timesteps - num_samples + 1
-        conv_shape = (self.num_templates, num_peaks)
-        scalar_products = np.zeros(conv_shape, dtype=np.float32)
+            objective_by_rank = signal.oaconvolve(scaled_filtered_data, self.temporal, axes=2, mode="valid")
+            scalar_products += np.sum(objective_by_rank, axis=0)
+
+        num_peaks = scalar_products.shape[1]

-        # Filter using overlap-and-add convolution
         if len(self.ignore_inds) > 0:
-            not_ignored = ~np.isin(np.arange(self.num_templates), self.ignore_inds)
-            spatially_filtered_data = np.matmul(self.spatial[:, not_ignored, :], traces.T[np.newaxis, :, :])
-            scaled_filtered_data = spatially_filtered_data * self.singular[:, not_ignored, :]
-            objective_by_rank = scipy.signal.oaconvolve(
-                scaled_filtered_data, self.temporal[:, not_ignored, :], axes=2, mode="valid"
-            )
-            scalar_products[not_ignored] += np.sum(objective_by_rank, axis=0)
             scalar_products[self.ignore_inds] = -np.inf
-        else:
-            spatially_filtered_data = np.matmul(self.spatial, traces.T[np.newaxis, :, :])
-            scaled_filtered_data = spatially_filtered_data * self.singular
-            objective_by_rank = scipy.signal.oaconvolve(scaled_filtered_data, self.temporal, axes=2, mode="valid")
-            scalar_products += np.sum(objective_by_rank, axis=0)
+            not_ignored = ~np.isin(np.arange(self.num_templates), self.ignore_inds)

         num_spikes = 0
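The torch branch above collapses the per-rank overlap-add convolutions into a single grouped `conv1d`: each template owns `rank` input channels, and `groups=num_templates` restricts every template's filter bank to its own block. Because `conv1d` computes cross-correlation, the temporal filters were time-flipped in `_push_to_torch`. A self-contained sketch (illustrative sizes) checking that the two engines agree:

import numpy as np
import torch
import torch.nn.functional as F
from scipy import signal

rng = np.random.default_rng(0)
num_templates, rank, num_samples, num_timesteps = 4, 3, 21, 1000
data = rng.standard_normal((rank, num_templates, num_timesteps)).astype(np.float32)
kernels = rng.standard_normal((rank, num_templates, num_samples)).astype(np.float32)

# numpy engine: overlap-add convolution along time, summed over the rank axis
expected = np.sum(signal.oaconvolve(data, kernels, axes=2, mode="valid"), axis=0)

# torch engine: one grouped cross-correlation; flipping the kernels makes it a convolution
x = torch.from_numpy(data).swapaxes(0, 1).reshape(1, num_templates * rank, num_timesteps)
w = torch.flip(torch.from_numpy(kernels).swapaxes(0, 1), (2,))
got = F.conv1d(x, w, groups=num_templates)[0].numpy()

np.testing.assert_allclose(expected, got, atol=1e-3)  # matches up to float32 rounding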

@@ -322,7 +368,7 @@ def compute_matching(self, traces, start_frame, end_frame, segment_index):
         is_in_vicinity = np.zeros(0, dtype=np.int32)

         if self.stop_criteria == "omp_min_sps":
-            stop_criteria = self.omp_min_sps * np.maximum(self.norms, np.sqrt(self.num_channels * num_samples))
+            stop_criteria = self.omp_min_sps * np.maximum(self.norms, np.sqrt(self.num_channels * self.num_samples))
         elif self.stop_criteria == "max_failures":
             num_valids = 0
             nb_failures = self.max_failures
@@ -354,11 +400,11 @@ def compute_matching(self, traces, start_frame, end_frame, segment_index):

             if num_selection > 0:
                 delta_t = selection[1] - peak_index
-                idx = np.flatnonzero((delta_t < num_samples) & (delta_t > -num_samples))
+                idx = np.flatnonzero((delta_t < self.num_samples) & (delta_t > -self.num_samples))
                 myline = neighbor_window + delta_t[idx]
                 myindices = selection[0, idx]

-                local_overlaps = overlaps_array[best_cluster_ind]
+                local_overlaps = self.overlaps[best_cluster_ind]
                 overlapping_templates = self.unit_overlaps_indices[best_cluster_ind]
                 table = self.unit_overlaps_tables[best_cluster_ind]

@@ -436,10 +482,10 @@ def compute_matching(self, traces, start_frame, end_frame, segment_index):
             for i in modified:
                 tmp_best, tmp_peak = sub_selection[:, i]
                 diff_amp = diff_amplitudes[i] * self.norms[tmp_best]
-                local_overlaps = overlaps_array[tmp_best]
+                local_overlaps = self.overlaps[tmp_best]
                 overlapping_templates = self.units_overlaps[tmp_best]
                 tmp = tmp_peak - neighbor_window
-                idx = [max(0, tmp), min(num_peaks, tmp_peak + num_samples)]
+                idx = [max(0, tmp), min(num_peaks, tmp_peak + self.num_samples)]
                 tdx = [idx[0] - tmp, idx[1] - tmp]
                 to_add = diff_amp * local_overlaps[:, tdx[0] : tdx[1]]
                 scalar_products[overlapping_templates, idx[0] : idx[1]] -= to_add
@@ -500,27 +546,27 @@ class CircusPeeler(BaseTemplateMatching):

     Parameters
     ----------
-    peak_sign: str
+    peak_sign : str
         Sign of the peak (neg, pos, or both)
-    exclude_sweep_ms: float
+    exclude_sweep_ms : float
         The number of samples before/after to classify a peak (should be low)
-    jitter: int
+    jitter : int
         The number of samples considered before/after every peak to search for
         matches
-    detect_threshold: int
+    detect_threshold : int
         The detection threshold
-    noise_levels: array
+    noise_levels : array
         The noise levels, for every channel
-    random_chunk_kwargs: dict
+    random_chunk_kwargs : dict
         Parameters for computing noise levels, if not provided (suboptimal)
-    max_amplitude: float
+    max_amplitude : float
         Maximal amplitude allowed for every template
-    min_amplitude: float
+    min_amplitude : float
         Minimal amplitude allowed for every template
-    use_sparse_matrix_threshold: float
+    use_sparse_matrix_threshold : float
         If the density of the templates is below a given threshold, sparse matrices
         are used (memory efficient)
-    sparse_kwargs: dict
+    sparse_kwargs : dict
         Parameters to extract a sparsity mask from the waveform_extractor, if not
         already sparse.
     -----