Skip to content

Commit

Permalink
Merge pull request #2209 from samuelgarcia/count_spike_array
Browse files Browse the repository at this point in the history
Add an option for count_num_spikes_per_unit
  • Loading branch information
alejoe91 authored Nov 22, 2023
2 parents 1cdcd5b + a6bd539 commit 5d7b64e
Show file tree
Hide file tree
Showing 7 changed files with 53 additions and 31 deletions.
8 changes: 4 additions & 4 deletions src/spikeinterface/comparison/comparisontools.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def do_count_event(sorting):
"""
import pandas as pd

return pd.Series(sorting.count_num_spikes_per_unit())
return pd.Series(sorting.count_num_spikes_per_unit(outputs="dict"))


def count_match_spikes(times1, all_times2, delta_frames): # , event_counts1, event_counts2 unit2_ids,
Expand Down Expand Up @@ -310,7 +310,7 @@ def make_match_count_matrix(sorting1, sorting2, delta_frames, ensure_symmetry=Fa

# ensure the number of match do not exceed the number of spike in train 2
# this is a simple way to handle corner cases for bursting in sorting1
spike_count2 = np.array(list(sorting2.count_num_spikes_per_unit().values()))
spike_count2 = sorting2.count_num_spikes_per_unit(outputs="array")
spike_count2 = spike_count2[np.newaxis, :]
matching_matrix = np.clip(matching_matrix, None, spike_count2)

Expand Down Expand Up @@ -353,8 +353,8 @@ def make_agreement_scores(sorting1, sorting2, delta_frames, ensure_symmetry=True
unit1_ids = np.array(sorting1.get_unit_ids())
unit2_ids = np.array(sorting2.get_unit_ids())

ev_counts1 = np.array(list(sorting1.count_num_spikes_per_unit().values()))
ev_counts2 = np.array(list(sorting2.count_num_spikes_per_unit().values()))
ev_counts1 = sorting1.count_num_spikes_per_unit(outputs="array")
ev_counts2 = sorting2.count_num_spikes_per_unit(outputs="array")
event_counts1 = pd.Series(ev_counts1, index=unit1_ids)
event_counts2 = pd.Series(ev_counts2, index=unit2_ids)

Expand Down
58 changes: 40 additions & 18 deletions src/spikeinterface/core/basesorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,37 +267,59 @@ def get_total_num_spikes(self):
DeprecationWarning,
stacklevel=2,
)
return self.count_num_spikes_per_unit()
return self.count_num_spikes_per_unit(outputs="dict")

def count_num_spikes_per_unit(self) -> dict:
def count_num_spikes_per_unit(self, outputs="dict"):
"""
For each unit : get number of spikes across segments.
Parameters
----------
outputs: "dict" | "array", default: "dict"
Control the type of the returned object: a dict (keys are unit_ids) or an numpy array.
Returns
-------
dict
Dictionary with unit_ids as key and number of spikes as values
dict or numpy.array
Dict : Dictionary with unit_ids as key and number of spikes as values
Numpy array : array of size len(unit_ids) in the same order as unit_ids.
"""
num_spikes = {}
num_spikes = np.zeros(self.unit_ids.size, dtype="int64")

if self._cached_spike_trains is not None:
for unit_id in self.unit_ids:
n = 0
# speed strategy by order
# 1. if _cached_spike_trains have all units then use it
# 2. if _cached_spike_vector is not non use it
# 3. loop with get_unit_spike_train

# check if all spiketrains are cached
if len(self._cached_spike_trains) == self.get_num_segments():
all_spiketrain_are_cached = True
for segment_index in range(self.get_num_segments()):
if len(self._cached_spike_trains[segment_index]) != self.unit_ids.size:
all_spiketrain_are_cached = False
break
else:
all_spiketrain_are_cached = False

if all_spiketrain_are_cached or self._cached_spike_vector is None:
# case 1 or 3
for unit_index, unit_id in enumerate(self.unit_ids):
for segment_index in range(self.get_num_segments()):
st = self.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index)
n += st.size
num_spikes[unit_id] = n
else:
num_spikes[unit_index] += st.size
elif self._cached_spike_vector is not None:
# case 2
spike_vector = self.to_spike_vector()
unit_indices, counts = np.unique(spike_vector["unit_index"], return_counts=True)
for unit_index, unit_id in enumerate(self.unit_ids):
if unit_index in unit_indices:
idx = np.argmax(unit_indices == unit_index)
num_spikes[unit_id] = counts[idx]
else: # This unit has no spikes, hence it's not in the counts array.
num_spikes[unit_id] = 0
num_spikes[unit_indices] = counts

return num_spikes
if outputs == "array":
return num_spikes
elif outputs == "dict":
num_spikes = dict(zip(self.unit_ids, num_spikes))
return num_spikes
else:
raise ValueError("count_num_spikes_per_unit() output must be 'dict' or 'array'")

def count_total_num_spikes(self) -> int:
"""
Expand Down
3 changes: 2 additions & 1 deletion src/spikeinterface/core/tests/test_basesorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,8 @@ def test_BaseSorting():
spikes = sorting.to_spike_vector(extremum_channel_inds={0: 15, 1: 5, 2: 18})
# print(spikes)

num_spikes_per_unit = sorting.count_num_spikes_per_unit()
num_spikes_per_unit = sorting.count_num_spikes_per_unit(outputs="dict")
num_spikes_per_unit = sorting.count_num_spikes_per_unit(outputs="array")
total_spikes = sorting.count_total_num_spikes()

# select units
Expand Down
7 changes: 3 additions & 4 deletions src/spikeinterface/curation/auto_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def get_potential_auto_merge(

# STEP 1 :
if "min_spikes" in steps:
num_spikes = np.array(list(sorting.count_num_spikes_per_unit().values()))
num_spikes = sorting.count_num_spikes_per_unit(outputs="array")
to_remove = num_spikes < minimum_spikes
pair_mask[to_remove, :] = False
pair_mask[:, to_remove] = False
Expand Down Expand Up @@ -255,17 +255,16 @@ def compute_correlogram_diff(

# Index of the middle of the correlograms.
m = correlograms_smoothed.shape[2] // 2
num_spikes = sorting.count_num_spikes_per_unit()
num_spikes = sorting.count_num_spikes_per_unit(outputs="array")

corr_diff = np.full((n, n), np.nan, dtype="float64")
for unit_ind1 in range(n):
for unit_ind2 in range(unit_ind1 + 1, n):
if not pair_mask[unit_ind1, unit_ind2]:
continue

unit_id1, unit_id2 = unit_ids[unit_ind1], unit_ids[unit_ind2]
num1, num2 = num_spikes[unit_ind1], num_spikes[unit_ind2]

num1, num2 = num_spikes[unit_id1], num_spikes[unit_id2]
# Weighted window (larger unit imposes its window).
win_size = int(round((num1 * win_sizes[unit_ind1] + num2 * win_sizes[unit_ind2]) / (num1 + num2)))
# Plage of indices where correlograms are inside the window.
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/curation/remove_redundant.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def remove_redundant_units(
else:
remove_unit_ids.append(u2)
elif remove_strategy == "max_spikes":
num_spikes = sorting.count_num_spikes_per_unit()
num_spikes = sorting.count_num_spikes_per_unit(outputs="dict")
for u1, u2 in redundant_unit_pairs:
if num_spikes[u1] < num_spikes[u2]:
remove_unit_ids.append(u1)
Expand Down
4 changes: 2 additions & 2 deletions src/spikeinterface/qualitymetrics/misc_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,7 @@ def compute_synchrony_metrics(waveform_extractor, synchrony_sizes=(2, 4, 8), uni
This code was adapted from `Elephant - Electrophysiology Analysis Toolkit <https://github.com/NeuralEnsemble/elephant/blob/master/elephant/spike_train_synchrony.py#L245>`_
"""
assert min(synchrony_sizes) > 1, "Synchrony sizes must be greater than 1"
spike_counts = waveform_extractor.sorting.count_num_spikes_per_unit()
spike_counts = waveform_extractor.sorting.count_num_spikes_per_unit(outputs="dict")
sorting = waveform_extractor.sorting
spikes = sorting.to_spike_vector(concatenated=False)

Expand Down Expand Up @@ -683,7 +683,7 @@ def compute_amplitude_cv_metrics(
sorting = waveform_extractor.sorting
total_duration = waveform_extractor.get_total_duration()
spikes = sorting.to_spike_vector()
num_spikes = sorting.count_num_spikes_per_unit()
num_spikes = sorting.count_num_spikes_per_unit(outputs="dict")
if unit_ids is None:
unit_ids = sorting.unit_ids

Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/widgets/unit_depths.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __init__(
unit_amplitudes = get_template_extremum_amplitude(we, peak_sign=peak_sign)
unit_amplitudes = np.abs([unit_amplitudes[unit_id] for unit_id in unit_ids])

num_spikes = np.array(list(we.sorting.count_num_spikes_per_unit().values()))
num_spikes = we.sorting.count_num_spikes_per_unit(outputs="array")

plot_data = dict(
unit_depths=unit_depths,
Expand Down

0 comments on commit 5d7b64e

Please sign in to comment.