Add an option for count_num_spikes_per_unit #2209

Merged · 9 commits · Nov 22, 2023
8 changes: 4 additions & 4 deletions src/spikeinterface/comparison/comparisontools.py
@@ -78,7 +78,7 @@ def do_count_event(sorting):
     """
     import pandas as pd
 
-    return pd.Series(sorting.count_num_spikes_per_unit())
+    return pd.Series(sorting.count_num_spikes_per_unit(outputs="dict"))
 
 
 def count_match_spikes(times1, all_times2, delta_frames):  # , event_counts1, event_counts2 unit2_ids,
@@ -310,7 +310,7 @@ def make_match_count_matrix(sorting1, sorting2, delta_frames, ensure_symmetry=False):
 
     # ensure the number of matches does not exceed the number of spikes in train 2
     # this is a simple way to handle corner cases for bursting in sorting1
-    spike_count2 = np.array(list(sorting2.count_num_spikes_per_unit().values()))
+    spike_count2 = sorting2.count_num_spikes_per_unit(outputs="array")
     spike_count2 = spike_count2[np.newaxis, :]
     matching_matrix = np.clip(matching_matrix, None, spike_count2)

@@ -353,8 +353,8 @@ def make_agreement_scores(sorting1, sorting2, delta_frames, ensure_symmetry=True):
     unit1_ids = np.array(sorting1.get_unit_ids())
     unit2_ids = np.array(sorting2.get_unit_ids())
 
-    ev_counts1 = np.array(list(sorting1.count_num_spikes_per_unit().values()))
-    ev_counts2 = np.array(list(sorting2.count_num_spikes_per_unit().values()))
+    ev_counts1 = sorting1.count_num_spikes_per_unit(outputs="array")
+    ev_counts2 = sorting2.count_num_spikes_per_unit(outputs="array")
     event_counts1 = pd.Series(ev_counts1, index=unit1_ids)
     event_counts2 = pd.Series(ev_counts2, index=unit2_ids)
59 changes: 41 additions & 18 deletions src/spikeinterface/core/basesorting.py
@@ -267,37 +267,60 @@ def get_total_num_spikes(self):
             DeprecationWarning,
             stacklevel=2,
         )
-        return self.count_num_spikes_per_unit()
+        return self.count_num_spikes_per_unit(outputs="dict")
 
-    def count_num_spikes_per_unit(self) -> dict:
+    def count_num_spikes_per_unit(self, outputs="dict"):
"""
For each unit : get number of spikes across segments.

Parameters
----------
outputs: "dict" | "array", dfault: "dict"
samuelgarcia marked this conversation as resolved.
Show resolved Hide resolved
Control the type of the returned object: a dict (keys are unit_ids) or an numpy array.

         Returns
         -------
-        dict
-            Dictionary with unit_ids as key and number of spikes as values
+        dict or numpy.array
+            Dict: dictionary with unit_ids as keys and number of spikes as values.
+            Numpy array: array of size len(unit_ids), in the same order as unit_ids.
         """
-        num_spikes = {}
+        num_spikes = np.zeros(self.unit_ids.size, dtype="int64")

+        # speed strategy by order
+        # 1. if _cached_spike_trains has all units, use it
+        # 2. if _cached_spike_vector is not None, use it
+        # 3. loop with get_unit_spike_train

+        # check if all spiketrains are cached
+        if len(self._cached_spike_trains) == self.get_num_segments():
+            all_spiketrain_are_cached = True
+            for segment_index in range(self.get_num_segments()):
+                if len(self._cached_spike_trains[segment_index]) != self.unit_ids.size:
+                    all_spiketrain_are_cached = False
+                    break
+        else:
+            all_spiketrain_are_cached = False

-        if self._cached_spike_trains is not None:
-            for unit_id in self.unit_ids:
-                n = 0
+        if all_spiketrain_are_cached or self._cached_spike_vector is None:
+            # cases 1 or 3
+            for unit_index, unit_id in enumerate(self.unit_ids):
                 for segment_index in range(self.get_num_segments()):
                     st = self.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index)
-                    n += st.size
-            num_spikes[unit_id] = n
-        else:
+                    num_spikes[unit_index] += st.size
+        elif self._cached_spike_vector is not None:
+            # case 2
             spike_vector = self.to_spike_vector()
             unit_indices, counts = np.unique(spike_vector["unit_index"], return_counts=True)
-            for unit_index, unit_id in enumerate(self.unit_ids):
-                if unit_index in unit_indices:
-                    idx = np.argmax(unit_indices == unit_index)
-                    num_spikes[unit_id] = counts[idx]
-                else:  # This unit has no spikes, hence it's not in the counts array.
-                    num_spikes[unit_id] = 0
+            num_spikes[unit_indices] = counts
 
-        return num_spikes
+        if outputs == "array":
+            return num_spikes
+        elif outputs == "dict":
+            num_spikes = dict(zip(self.unit_ids, num_spikes))
+            return num_spikes
+        else:
+            raise ValueError("count_num_spikes_per_unit() output must be dict or array")

     def count_total_num_spikes(self):
         """
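A short usage sketch of the reworked method (hypothetical unit ids, for illustration only). Note that because num_spikes starts from np.zeros, a unit absent from the spike vector keeps a count of 0 through the fancy-indexed assignment num_spikes[unit_indices] = counts, which is exactly the case the deleted np.argmax branch handled by hand:

# sorting.unit_ids == ["a", "b", "c"]  (hypothetical)
counts = sorting.count_num_spikes_per_unit()              # default outputs="dict"
# -> {"a": 120, "b": 0, "c": 57} : a unit with no spikes still appears, with 0

counts = sorting.count_num_spikes_per_unit(outputs="array")
# -> array([120, 0, 57]) : int64 array aligned with sorting.unit_ids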
3 changes: 2 additions & 1 deletion src/spikeinterface/core/tests/test_basesorting.py
@@ -104,7 +104,8 @@ def test_BaseSorting():
     spikes = sorting.to_spike_vector(extremum_channel_inds={0: 15, 1: 5, 2: 18})
     # print(spikes)
 
-    num_spikes_per_unit = sorting.count_num_spikes_per_unit()
+    num_spikes_per_unit = sorting.count_num_spikes_per_unit(outputs="dict")
+    num_spikes_per_unit = sorting.count_num_spikes_per_unit(outputs="array")
     total_spikes = sorting.count_total_num_spikes()
 
     # select units
7 changes: 3 additions & 4 deletions src/spikeinterface/curation/auto_merge.py
@@ -136,7 +136,7 @@ def get_potential_auto_merge(

     # STEP 1 :
     if "min_spikes" in steps:
-        num_spikes = np.array(list(sorting.count_num_spikes_per_unit().values()))
+        num_spikes = sorting.count_num_spikes_per_unit(outputs="array")
         to_remove = num_spikes < minimum_spikes
         pair_mask[to_remove, :] = False
         pair_mask[:, to_remove] = False
@@ -255,17 +255,16 @@ def compute_correlogram_diff(

     # Index of the middle of the correlograms.
     m = correlograms_smoothed.shape[2] // 2
-    num_spikes = sorting.count_num_spikes_per_unit()
+    num_spikes = sorting.count_num_spikes_per_unit(outputs="array")
 
     corr_diff = np.full((n, n), np.nan, dtype="float64")
     for unit_ind1 in range(n):
         for unit_ind2 in range(unit_ind1 + 1, n):
             if not pair_mask[unit_ind1, unit_ind2]:
                 continue
 
-            unit_id1, unit_id2 = unit_ids[unit_ind1], unit_ids[unit_ind2]
-            num1, num2 = num_spikes[unit_id1], num_spikes[unit_id2]
+            num1, num2 = num_spikes[unit_ind1], num_spikes[unit_ind2]
             # Weighted window (larger unit imposes its window).
             win_size = int(round((num1 * win_sizes[unit_ind1] + num2 * win_sizes[unit_ind2]) / (num1 + num2)))
             # Range of indices where correlograms are inside the window.
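Because the counts now arrive as an array ordered like unit_ids, positional indexing by unit_ind replaces the dict lookup by unit_id, and the intermediate unit_id1/unit_id2 variables become unnecessary. A small sketch of why the two lookups are interchangeable here, assuming unit_ids = sorting.unit_ids:

num_spikes_dict = sorting.count_num_spikes_per_unit(outputs="dict")
num_spikes_arr = sorting.count_num_spikes_per_unit(outputs="array")

for unit_ind, unit_id in enumerate(unit_ids):
    # same value whether looked up by position or by id
    assert num_spikes_arr[unit_ind] == num_spikes_dict[unit_id]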
2 changes: 1 addition & 1 deletion src/spikeinterface/curation/remove_redundant.py
@@ -116,7 +116,7 @@ def remove_redundant_units(
                 else:
                     remove_unit_ids.append(u2)
         elif remove_strategy == "max_spikes":
-            num_spikes = sorting.count_num_spikes_per_unit()
+            num_spikes = sorting.count_num_spikes_per_unit(outputs="dict")
             for u1, u2 in redundant_unit_pairs:
                 if num_spikes[u1] < num_spikes[u2]:
                     remove_unit_ids.append(u1)
4 changes: 2 additions & 2 deletions src/spikeinterface/qualitymetrics/misc_metrics.py
@@ -524,7 +524,7 @@ def compute_synchrony_metrics(waveform_extractor, synchrony_sizes=(2, 4, 8), uni
     This code was adapted from `Elephant - Electrophysiology Analysis Toolkit <https://github.com/NeuralEnsemble/elephant/blob/master/elephant/spike_train_synchrony.py#L245>`_
     """
     assert min(synchrony_sizes) > 1, "Synchrony sizes must be greater than 1"
-    spike_counts = waveform_extractor.sorting.count_num_spikes_per_unit()
+    spike_counts = waveform_extractor.sorting.count_num_spikes_per_unit(outputs="dict")
     sorting = waveform_extractor.sorting
     spikes = sorting.to_spike_vector(concatenated=False)

@@ -683,7 +683,7 @@ def compute_amplitude_cv_metrics(
     sorting = waveform_extractor.sorting
     total_duration = waveform_extractor.get_total_duration()
     spikes = sorting.to_spike_vector()
-    num_spikes = sorting.count_num_spikes_per_unit()
+    num_spikes = sorting.count_num_spikes_per_unit(outputs="dict")
     if unit_ids is None:
         unit_ids = sorting.unit_ids

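Call sites like these keep the dict output because they look counts up by unit_id, possibly for only a subset of units; a hedged sketch of the pattern (unit_ids as in the surrounding code):

num_spikes = sorting.count_num_spikes_per_unit(outputs="dict")
for unit_id in unit_ids:          # unit_ids may be a subset of sorting.unit_ids
    n = num_spikes[unit_id]       # keyed lookup, independent of unit order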
2 changes: 1 addition & 1 deletion src/spikeinterface/widgets/unit_depths.py
@@ -43,7 +43,7 @@ def __init__(
         unit_amplitudes = get_template_extremum_amplitude(we, peak_sign=peak_sign)
         unit_amplitudes = np.abs([unit_amplitudes[unit_id] for unit_id in unit_ids])
 
-        num_spikes = np.array(list(we.sorting.count_num_spikes_per_unit().values()))
+        num_spikes = we.sorting.count_num_spikes_per_unit(outputs="array")
 
         plot_data = dict(
             unit_depths=unit_depths,