Merge branch 'main' into waveforms-percentiles
alejoe91 authored Nov 22, 2023
2 parents 68b2383 + 3abc55e commit 4eb2e33
Showing 21 changed files with 306 additions and 147 deletions.
8 changes: 4 additions & 4 deletions src/spikeinterface/comparison/comparisontools.py
@@ -78,7 +78,7 @@ def do_count_event(sorting):
"""
import pandas as pd

return pd.Series(sorting.count_num_spikes_per_unit())
return pd.Series(sorting.count_num_spikes_per_unit(outputs="dict"))


def count_match_spikes(times1, all_times2, delta_frames): # , event_counts1, event_counts2 unit2_ids,
@@ -310,7 +310,7 @@ def make_match_count_matrix(sorting1, sorting2, delta_frames, ensure_symmetry=Fa

# ensure the number of matches does not exceed the number of spikes in train 2
# this is a simple way to handle corner cases for bursting in sorting1
spike_count2 = np.array(list(sorting2.count_num_spikes_per_unit().values()))
spike_count2 = sorting2.count_num_spikes_per_unit(outputs="array")
spike_count2 = spike_count2[np.newaxis, :]
matching_matrix = np.clip(matching_matrix, None, spike_count2)

@@ -353,8 +353,8 @@ def make_agreement_scores(sorting1, sorting2, delta_frames, ensure_symmetry=True
unit1_ids = np.array(sorting1.get_unit_ids())
unit2_ids = np.array(sorting2.get_unit_ids())

ev_counts1 = np.array(list(sorting1.count_num_spikes_per_unit().values()))
ev_counts2 = np.array(list(sorting2.count_num_spikes_per_unit().values()))
ev_counts1 = sorting1.count_num_spikes_per_unit(outputs="array")
ev_counts2 = sorting2.count_num_spikes_per_unit(outputs="array")
event_counts1 = pd.Series(ev_counts1, index=unit1_ids)
event_counts2 = pd.Series(ev_counts2, index=unit2_ids)

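For context, a minimal sketch (not part of the commit) of what the new `outputs` argument changes at these call sites. It assumes a toy sorting built with `generate_sorting` from `spikeinterface.core` and the `outputs` parameter introduced in this commit:

```python
import numpy as np
from spikeinterface.core import generate_sorting

# toy sorting: 5 units, 2 segments
sorting = generate_sorting(num_units=5, durations=[10.0, 5.0], seed=0)

# old pattern: build an array by going through the dict values
counts_dict = sorting.count_num_spikes_per_unit(outputs="dict")
counts_from_dict = np.array(list(counts_dict.values()))

# new pattern: ask for the array directly, ordered like sorting.unit_ids
counts_array = sorting.count_num_spikes_per_unit(outputs="array")

assert np.array_equal(counts_from_dict, counts_array)
```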
2 changes: 2 additions & 0 deletions src/spikeinterface/core/__init__.py
@@ -94,6 +94,8 @@
get_chunk_with_margin,
order_channels_by_depth,
)
from .sorting_tools import spike_vector_to_spike_trains

from .waveform_tools import extract_waveforms_to_buffers
from .snippets_tools import snippets_from_sorting

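A hedged usage sketch (not part of the commit) of the newly exported helper. The return layout, a per-segment dict of per-unit spike-train arrays, is inferred from how `basesorting.py` assigns it to `_cached_spike_trains` below and may differ in detail:

```python
from spikeinterface.core import generate_sorting, spike_vector_to_spike_trains

sorting = generate_sorting(num_units=3, durations=[5.0], seed=42)

# one structured array per segment, with "sample_index" and "unit_index" fields
spike_vector = sorting.to_spike_vector(concatenated=False)

# inferred layout: spike_trains[segment_index][unit_id] -> array of sample indices
spike_trains = spike_vector_to_spike_trains(spike_vector, sorting.unit_ids)
print(spike_trains[0][sorting.unit_ids[0]][:10])
```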
93 changes: 72 additions & 21 deletions src/spikeinterface/core/basesorting.py
@@ -6,6 +6,7 @@
import numpy as np

from .base import BaseExtractor, BaseSegment
from .sorting_tools import spike_vector_to_spike_trains
from .waveform_tools import has_exceeding_spikes


@@ -130,9 +131,11 @@ def get_unit_spike_train(
else:
spike_frames = self._cached_spike_trains[segment_index][unit_id]
if start_frame is not None:
spike_frames = spike_frames[spike_frames >= start_frame]
start = np.searchsorted(spike_frames, start_frame)
spike_frames = spike_frames[start:]
if end_frame is not None:
spike_frames = spike_frames[spike_frames < end_frame]
end = np.searchsorted(spike_frames, end_frame)
spike_frames = spike_frames[:end]
else:
segment = self._sorting_segments[segment_index]
spike_frames = segment.get_unit_spike_train(
@@ -267,37 +270,59 @@ def get_total_num_spikes(self):
DeprecationWarning,
stacklevel=2,
)
return self.count_num_spikes_per_unit()
return self.count_num_spikes_per_unit(outputs="dict")

def count_num_spikes_per_unit(self) -> dict:
def count_num_spikes_per_unit(self, outputs="dict"):
"""
For each unit: get the number of spikes across segments.
Parameters
----------
outputs: "dict" | "array", default: "dict"
Control the type of the returned object: a dict (keys are unit_ids) or a numpy array.
Returns
-------
dict
Dictionary with unit_ids as key and number of spikes as values
dict or numpy.array
Dict : Dictionary with unit_ids as key and number of spikes as values
Numpy array : array of size len(unit_ids) in the same order as unit_ids.
"""
num_spikes = {}
num_spikes = np.zeros(self.unit_ids.size, dtype="int64")

# speed strategy by order
# 1. if _cached_spike_trains have all units then use it
# 2. if _cached_spike_vector is not None use it
# 3. loop with get_unit_spike_train

# check if all spiketrains are cached
if len(self._cached_spike_trains) == self.get_num_segments():
all_spiketrain_are_cached = True
for segment_index in range(self.get_num_segments()):
if len(self._cached_spike_trains[segment_index]) != self.unit_ids.size:
all_spiketrain_are_cached = False
break
else:
all_spiketrain_are_cached = False

if self._cached_spike_trains is not None:
for unit_id in self.unit_ids:
n = 0
if all_spiketrain_are_cached or self._cached_spike_vector is None:
# case 1 or 3
for unit_index, unit_id in enumerate(self.unit_ids):
for segment_index in range(self.get_num_segments()):
st = self.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index)
n += st.size
num_spikes[unit_id] = n
else:
num_spikes[unit_index] += st.size
elif self._cached_spike_vector is not None:
# case 2
spike_vector = self.to_spike_vector()
unit_indices, counts = np.unique(spike_vector["unit_index"], return_counts=True)
for unit_index, unit_id in enumerate(self.unit_ids):
if unit_index in unit_indices:
idx = np.argmax(unit_indices == unit_index)
num_spikes[unit_id] = counts[idx]
else: # This unit has no spikes, hence it's not in the counts array.
num_spikes[unit_id] = 0
num_spikes[unit_indices] = counts

return num_spikes
if outputs == "array":
return num_spikes
elif outputs == "dict":
num_spikes = dict(zip(self.unit_ids, num_spikes))
return num_spikes
else:
raise ValueError("count_num_spikes_per_unit() output must be 'dict' or 'array'")

def count_total_num_spikes(self) -> int:
"""
@@ -409,7 +434,6 @@ def frame_slice(self, start_frame, end_frame, check_spike_frames=True):
def get_all_spike_trains(self, outputs="unit_id"):
"""
Return all spike trains concatenated.
This is deprecated and will be removed in spikeinterface 0.102 use sorting.to_spike_vector() instead
"""

@@ -445,6 +469,33 @@ def get_all_spike_trains(self, outputs="unit_id"):
spikes.append((spike_times, spike_labels))
return spikes

def precompute_spike_trains(self, from_spike_vector=None):
"""
Pre-computes and caches all spike trains for this sorting
Parameters
----------
from_spike_vector: None | bool, default: None
If None, then it is automatic depending on whether the spike vector is cached.
If True, will compute it from the spike vector.
If False, will call `get_unit_spike_train` for each segment for each unit.
"""
unit_ids = self.unit_ids

if from_spike_vector is None:
# if spike vector is cached then use it
from_spike_vector = self._cached_spike_vector is not None

if from_spike_vector:
self._cached_spike_trains = spike_vector_to_spike_trains(self.to_spike_vector(concatenated=False), unit_ids)

else:
for segment_index in range(self.get_num_segments()):
for unit_id in unit_ids:
self.get_unit_spike_train(unit_id, segment_index=segment_index, use_cache=True)

def to_spike_vector(self, concatenated=True, extremum_channel_inds=None, use_cache=True):
"""
Construct a unique structured numpy vector concatenating all spikes
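Two illustrative sketches, not part of the commit. First, the vectorized counting used in case 2 above boils down to `np.unique(..., return_counts=True)` written into a zero-initialized array, so units without spikes keep a count of 0:

```python
import numpy as np

num_units = 5
# toy "unit_index" column of a spike vector; unit 3 has no spikes
unit_index = np.array([0, 0, 1, 4, 1, 0, 2, 4])

num_spikes = np.zeros(num_units, dtype="int64")
present_units, counts = np.unique(unit_index, return_counts=True)
num_spikes[present_units] = counts
print(num_spikes)  # [3 2 1 0 2]
```

Second, a usage sketch of the new `precompute_spike_trains` method as added in this hunk: filling the spike-train cache once means later per-unit queries hit the cached case:

```python
from spikeinterface.core import generate_sorting

sorting = generate_sorting(num_units=5, durations=[10.0], seed=1)
sorting.precompute_spike_trains()  # fill the spike-train cache once (from the spike vector when available)
counts = sorting.count_num_spikes_per_unit(outputs="dict")
print(counts)
```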
74 changes: 37 additions & 37 deletions src/spikeinterface/core/core_tools.py
@@ -204,43 +204,6 @@ def _init_binary_worker(recording, file_path_dict, dtype, byte_offest, cast_unsi
return worker_ctx


# used by write_binary_recording + ChunkRecordingExecutor
def _write_binary_chunk(segment_index, start_frame, end_frame, worker_ctx):
# recover variables of the worker
recording = worker_ctx["recording"]
dtype = worker_ctx["dtype"]
byte_offset = worker_ctx["byte_offset"]
cast_unsigned = worker_ctx["cast_unsigned"]
file = worker_ctx["file_dict"][segment_index]

# Open the memmap
# What we need is the file_path
num_channels = recording.get_num_channels()
num_frames = recording.get_num_frames(segment_index=segment_index)
shape = (num_frames, num_channels)
dtype_size_bytes = np.dtype(dtype).itemsize
data_size_bytes = dtype_size_bytes * num_frames * num_channels

# Offset (The offset needs to be multiple of the page size)
# The mmap offset is associated to be as big as possible but still a multiple of the page size
# The array offset takes care of the reminder
mmap_offset, array_offset = divmod(byte_offset, mmap.ALLOCATIONGRANULARITY)
mmmap_length = data_size_bytes + array_offset
memmap_obj = mmap.mmap(file.fileno(), length=mmmap_length, access=mmap.ACCESS_WRITE, offset=mmap_offset)

array = np.ndarray.__new__(np.ndarray, shape=shape, dtype=dtype, buffer=memmap_obj, order="C", offset=array_offset)
# apply function
traces = recording.get_traces(
start_frame=start_frame, end_frame=end_frame, segment_index=segment_index, cast_unsigned=cast_unsigned
)
if traces.dtype != dtype:
traces = traces.astype(dtype)
array[start_frame:end_frame, :] = traces

# Close the memmap
memmap_obj.flush()


def write_binary_recording(
recording,
file_paths,
@@ -312,6 +275,43 @@ def write_binary_recording(
executor.run()


# used by write_binary_recording + ChunkRecordingExecutor
def _write_binary_chunk(segment_index, start_frame, end_frame, worker_ctx):
# recover variables of the worker
recording = worker_ctx["recording"]
dtype = worker_ctx["dtype"]
byte_offset = worker_ctx["byte_offset"]
cast_unsigned = worker_ctx["cast_unsigned"]
file = worker_ctx["file_dict"][segment_index]

# Open the memmap
# What we need is the file_path
num_channels = recording.get_num_channels()
num_frames = recording.get_num_frames(segment_index=segment_index)
shape = (num_frames, num_channels)
dtype_size_bytes = np.dtype(dtype).itemsize
data_size_bytes = dtype_size_bytes * num_frames * num_channels

# Offset (The offset needs to be multiple of the page size)
# The mmap offset is associated to be as big as possible but still a multiple of the page size
# The array offset takes care of the reminder
mmap_offset, array_offset = divmod(byte_offset, mmap.ALLOCATIONGRANULARITY)
mmmap_length = data_size_bytes + array_offset
memmap_obj = mmap.mmap(file.fileno(), length=mmmap_length, access=mmap.ACCESS_WRITE, offset=mmap_offset)

array = np.ndarray.__new__(np.ndarray, shape=shape, dtype=dtype, buffer=memmap_obj, order="C", offset=array_offset)
# apply function
traces = recording.get_traces(
start_frame=start_frame, end_frame=end_frame, segment_index=segment_index, cast_unsigned=cast_unsigned
)
if traces.dtype != dtype:
traces = traces.astype(dtype)
array[start_frame:end_frame, :] = traces

# Close the memmap
memmap_obj.flush()


write_binary_recording.__doc__ = write_binary_recording.__doc__.format(_shared_job_kwargs_doc)


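Side note (not part of the commit, which only moves `_write_binary_chunk` below `write_binary_recording`): the worker's memmap handling relies on `mmap` offsets being page-aligned. A self-contained sketch of that page-alignment arithmetic, with a hypothetical header size:

```python
import mmap

granularity = mmap.ALLOCATIONGRANULARITY  # page size, e.g. 4096 (65536 on Windows)

byte_offset = 2 * granularity + 100  # hypothetical: traces start 100 bytes past the second page

# split into a page-aligned part for mmap and a remainder handled by the ndarray view
pages, array_offset = divmod(byte_offset, granularity)
mmap_offset = pages * granularity

assert mmap_offset % granularity == 0
assert mmap_offset + array_offset == byte_offset
```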
59 changes: 13 additions & 46 deletions src/spikeinterface/core/generate.py
@@ -32,11 +32,10 @@ def generate_recording(
set_probe: Optional[bool] = True,
ndim: Optional[int] = 2,
seed: Optional[int] = None,
mode: Literal["lazy", "legacy"] = "lazy",
) -> BaseRecording:
"""
Generate a recording object.
Useful for testing for testing API and algos.
Generate a lazy recording object.
Useful for testing API and algos.
Parameters
----------
@@ -49,13 +48,9 @@
Note that the number of segments is determined by the length of this list.
set_probe: bool, default: True
ndim : int, default: 2
The number of dimensions of the probe, default: 2. Set to 3 to make 3 dimensional probes.
The number of dimensions of the probe, default: 2. Set to 3 to make 3 dimensional probe.
seed : Optional[int]
A seed for the np.random.default_rng function
mode: str ["lazy", "legacy"], default: "lazy".
"legacy": generate a NumpyRecording with white noise.
This mode is kept for backward compatibility and will be deprecated version 0.100.0.
"lazy": return a NoiseGeneratorRecording instance.
Returns
-------
@@ -64,26 +59,16 @@
"""
seed = _ensure_seed(seed)

if mode == "legacy":
warnings.warn(
"generate_recording() : mode='legacy' will be deprecated in version 0.100.0. Use mode='lazy' instead.",
DeprecationWarning,
)
recording = _generate_recording_legacy(num_channels, sampling_frequency, durations, seed)
elif mode == "lazy":
recording = NoiseGeneratorRecording(
num_channels=num_channels,
sampling_frequency=sampling_frequency,
durations=durations,
dtype="float32",
seed=seed,
strategy="tile_pregenerated",
# block size is fixed to one second
noise_block_size=int(sampling_frequency),
)

else:
raise ValueError("generate_recording() : wrong mode")
recording = NoiseGeneratorRecording(
num_channels=num_channels,
sampling_frequency=sampling_frequency,
durations=durations,
dtype="float32",
seed=seed,
strategy="tile_pregenerated",
# block size is fixed to one second
noise_block_size=int(sampling_frequency),
)

recording.annotate(is_filtered=True)

@@ -97,24 +82,6 @@
return recording


def _generate_recording_legacy(num_channels, sampling_frequency, durations, seed):
# legacy code to generate a recording with random noise
rng = np.random.default_rng(seed=seed)

num_segments = len(durations)
num_timepoints = [int(sampling_frequency * d) for d in durations]

traces_list = []
for i in range(num_segments):
traces = rng.random(size=(num_timepoints[i], num_channels), dtype=np.float32)
times = np.arange(num_timepoints[i]) / sampling_frequency
traces += np.sin(2 * np.pi * 50 * times)[:, None]
traces_list.append(traces)
recording = NumpyRecording(traces_list, sampling_frequency)

return recording


def generate_sorting(
num_units=5,
sampling_frequency=30000.0, # in Hz
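With the legacy mode removed, `generate_recording` always returns a lazy, `NoiseGeneratorRecording`-backed object. A quick usage sketch, not part of the commit:

```python
from spikeinterface.core import generate_recording

# lazy noise recording: 4 channels, 2 segments, nothing materialized up front
recording = generate_recording(num_channels=4, durations=[10.0, 5.0], seed=2205)

traces = recording.get_traces(segment_index=0, start_frame=0, end_frame=100)
print(traces.shape)  # (100, 4)
```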
13 changes: 8 additions & 5 deletions src/spikeinterface/core/numpyextractors.py
@@ -347,13 +347,16 @@ def get_unit_spike_train(self, unit_id, start_frame, end_frame):
s0, s1 = np.searchsorted(self.spikes["segment_index"], [self.segment_index, self.segment_index + 1])
self.spikes_in_seg = self.spikes[s0:s1]

start = 0 if start_frame is None else np.searchsorted(self.spikes_in_seg["sample_index"], start_frame)
end = (
len(self.spikes_in_seg)
if end_frame is None
else np.searchsorted(self.spikes_in_seg["sample_index"], end_frame)
)

unit_index = self.unit_ids.index(unit_id)
times = self.spikes_in_seg[self.spikes_in_seg["unit_index"] == unit_index]["sample_index"]
times = self.spikes_in_seg[start:end][self.spikes_in_seg[start:end]["unit_index"] == unit_index]["sample_index"]

if start_frame is not None:
times = times[times >= start_frame]
if end_frame is not None:
times = times[times < end_frame]
return times


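Both this hunk and the `basesorting.py` change above replace boolean masks over the whole spike array with `np.searchsorted` on the sorted `sample_index` column. A standalone numpy sketch (not part of the commit) of why that yields the half-open `[start_frame, end_frame)` window:

```python
import numpy as np

# toy, already-sorted spike sample indices for one segment
sample_index = np.array([10, 25, 40, 55, 70, 85])
unit_index = np.array([0, 1, 0, 0, 1, 0])

start_frame, end_frame = 25, 70

# O(log n) window lookup instead of two boolean masks over the full arrays
start = np.searchsorted(sample_index, start_frame)  # first spike >= start_frame
end = np.searchsorted(sample_index, end_frame)      # first spike >= end_frame (excluded)

window = slice(start, end)
unit0_times = sample_index[window][unit_index[window] == 0]
print(unit0_times)  # [40 55]
```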
(Diffs for the remaining 15 changed files were not loaded on this page.)
