diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index 455edcfc80..8c5c62d568 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np from .recording_tools import get_channel_distances, get_noise_levels @@ -125,7 +127,7 @@ def unit_id_to_channel_indices(self): self._unit_id_to_channel_indices[unit_id] = channel_inds return self._unit_id_to_channel_indices - def sparsify_waveforms(self, waveforms: np.ndarray, unit_id: str) -> np.ndarray: + def sparsify_waveforms(self, waveforms: np.ndarray, unit_id: str | int) -> np.ndarray: """ Sparsify the waveforms according to a unit_id corresponding sparsity. @@ -159,7 +161,7 @@ def sparsify_waveforms(self, waveforms: np.ndarray, unit_id: str) -> np.ndarray: return sparsified_waveforms - def densify_waveforms(self, waveforms: np.ndarray, unit_id: str) -> np.ndarray: + def densify_waveforms(self, waveforms: np.ndarray, unit_id: str | int) -> np.ndarray: """ Densify sparse waveforms that were sparisified according to a unit's channel sparsity. @@ -199,7 +201,7 @@ def densify_waveforms(self, waveforms: np.ndarray, unit_id: str) -> np.ndarray: def are_waveforms_dense(self, waveforms: np.ndarray) -> bool: return waveforms.shape[-1] == self.num_channels - def are_waveforms_sparse(self, waveforms: np.ndarray, unit_id: str) -> bool: + def are_waveforms_sparse(self, waveforms: np.ndarray, unit_id: str | int) -> bool: non_zero_indices = self.unit_id_to_channel_indices[unit_id] num_active_channels = len(non_zero_indices) return waveforms.shape[-1] == num_active_channels