-
Notifications
You must be signed in to change notification settings - Fork 191
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Improve and benchmark auto-merging functions #2934
Changes from all commits
0087705
57f40d8
ec92c01
8e39954
f0d8378
f453437
d5a541d
08c5583
779b619
a811f66
51517ab
2fa8ace
fce51e8
1693e07
9377d2c
f67b05b
0c2b502
9cef54e
6cdddb7
8b05581
3f9f556
5130db6
50135da
75963a4
8f7e2a0
d8e7537
7e766b0
fa48c56
5d950a2
bfbea8c
9ecd241
4fe600b
3ccb5f9
96be6c8
48329eb
dd52f77
20e6ba9
60fc057
630273c
3252e26
bb2c9a6
043f2a0
c0d0333
fea71b4
4bd5fb0
982c065
c51e79d
31d8180
bc0898e
719a688
da862d1
9a69b14
2eee74e
e153c15
cd08b35
bea35ee
4408a66
d6a9c8d
df937cf
57749ba
82d021c
b956e26
3f9f84c
ba459c6
1a1a54b
283aae9
0bde7bb
a71bd10
86e73e1
88c1bc8
9195a75
4ab001d
b7f54d7
9d8a699
00f0a9f
7bc1c29
947a35d
30a7d36
3c0bb86
7468a64
5ff19a1
b0dab64
7f3d365
76cb82d
5f4fd0e
5087781
4643334
5d0277e
fb6d1ba
a48ac2e
f560406
8aca009
2c07c6e
65b01be
f0499e6
f7bd29b
d777fae
56c3666
2d8df38
852521c
c806d02
c9cbd9e
476fc31
9b30b5e
b2f2e8a
e7f6655
cc066f1
640446a
7f6ed93
f201f71
c287f25
6098de8
1a29125
80f914d
036a729
2d3fbf1
b9b0dbb
4658758
0056eaf
75a8804
8afab01
efc621c
35b71ec
2ad54dc
43fe534
4358040
0a1a400
4856aeb
03d9ac6
f8b8e6c
b5b95de
e64542f
07590ce
56bbe19
5c6edf8
e0165b3
2427c22
31ddfec
0c7458e
2df7ad3
4c2bbcd
30f3617
4db85c7
10fce9d
816f2fd
e437670
835ac4f
750eba0
3058a2f
c73e860
e5b02a5
55460a9
599755c
63da29e
c3e2115
f609b23
37ba955
a67b4d3
2ac4898
1f323f9
016d7cc
71c4876
0607450
1a18a05
478c173
25f740a
9e9c3fd
a0bb321
37c7fb7
ae51c3a
e1c3c31
bb127c9
fc36bdc
6c85de4
722ff6a
d6d673a
d6bd776
01e5cc1
84dea85
d697b8c
4d34a61
bd61e1b
6a6073f
a52dc64
50f6798
8449ee9
026bab4
ef9c25c
4150d31
b86f3a1
f906059
2678d54
4dccd51
2700fb0
b6310de
de21fe8
53f05e5
f3c2d7b
4df391e
87158b5
98edbdb
85623ce
164fa58
61cde8c
c6e1e00
d3bb0a2
c9b57d9
c891245
64d2780
878fe99
09eeef3
2b0c6bc
44b3e05
67526f5
09c5903
7ebabe1
0134d0d
5c327fe
5d88ec2
b88909d
15934aa
e15c676
d4fb25c
c3f11aa
787d3a1
711d0ff
3a72d1d
ba9d955
b97f60a
b0a6212
faebd26
e1ef2a0
10a1c78
970dd95
61fc29b
d6a7aee
c9d673d
5e99875
f18124a
fb27a6b
6ec9940
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -598,3 +598,104 @@ def get_traces( | |
|
||
def get_num_samples(self) -> int: | ||
return self.num_samples | ||
|
||
|
||
def split_sorting_by_times( | ||
sorting_analyzer, splitting_probability=0.5, partial_split_prob=0.95, unit_ids=None, min_snr=None, seed=None | ||
): | ||
sa = sorting_analyzer | ||
sorting = sa.sorting | ||
rng = np.random.RandomState(seed) | ||
|
||
sorting_split = sorting.select_units(sorting.unit_ids) | ||
split_units = [] | ||
original_units = [] | ||
nb_splits = int(splitting_probability * len(sorting.unit_ids)) | ||
if unit_ids is None: | ||
select_from = sorting.unit_ids | ||
if min_snr is not None: | ||
if sa.get_extension("noise_levels") is None: | ||
sa.compute("noise_levels") | ||
if sa.get_extension("quality_metrics") is None: | ||
sa.compute("quality_metrics", metric_names=["snr"]) | ||
|
||
snr = sa.get_extension("quality_metrics").get_data()["snr"].values | ||
select_from = select_from[snr > min_snr] | ||
|
||
to_split_ids = rng.choice(select_from, nb_splits, replace=False) | ||
else: | ||
to_split_ids = unit_ids | ||
|
||
import spikeinterface.curation as scur | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It would be sfaer to not use this and use directy numpy like the other function |
||
|
||
for unit in to_split_ids: | ||
num_spikes = len(sorting_split.get_unit_spike_train(unit)) | ||
indices = np.zeros(num_spikes, dtype=int) | ||
indices[: num_spikes // 2] = (rng.rand(num_spikes // 2) < partial_split_prob).astype(int) | ||
indices[num_spikes // 2 :] = (rng.rand(num_spikes - num_spikes // 2) < 1 - partial_split_prob).astype(int) | ||
sorting_split = scur.split_unit_sorting( | ||
sorting_split, split_unit_id=unit, indices_list=indices, properties_policy="remove" | ||
) | ||
split_units.append(sorting_split.unit_ids[-2:]) | ||
original_units.append(unit) | ||
return sorting_split, split_units | ||
|
||
|
||
def split_sorting_by_amplitudes(sorting_analyzer, splitting_probability=0.5, unit_ids=None, min_snr=None, seed=None): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would put this function in another file |
||
""" | ||
Fonction used to split a sorting based on the amplitudes of the units. This | ||
might be used for benchmarking meta merging step (see components) | ||
""" | ||
|
||
sa = sorting_analyzer | ||
if sa.get_extension("spike_amplitudes") is None: | ||
sa.compute("spike_amplitudes") | ||
|
||
rng = np.random.RandomState(seed) | ||
|
||
from spikeinterface.core.numpyextractors import NumpySorting | ||
from spikeinterface.core.template_tools import get_template_extremum_channel | ||
|
||
extremum_channel_inds = get_template_extremum_channel(sa, outputs="index") | ||
spikes = sa.sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds) | ||
new_spikes = spikes.copy() | ||
amplitudes = sa.get_extension("spike_amplitudes").get_data() | ||
nb_splits = int(splitting_probability * len(sa.sorting.unit_ids)) | ||
|
||
if unit_ids is None: | ||
select_from = sa.sorting.unit_ids | ||
if min_snr is not None: | ||
if sa.get_extension("noise_levels") is None: | ||
sa.compute("noise_levels") | ||
if sa.get_extension("quality_metrics") is None: | ||
sa.compute("quality_metrics", metric_names=["snr"]) | ||
|
||
snr = sa.get_extension("quality_metrics").get_data()["snr"].values | ||
select_from = select_from[snr > min_snr] | ||
to_split_ids = rng.choice(select_from, nb_splits, replace=False) | ||
else: | ||
to_split_ids = unit_ids | ||
|
||
max_index = np.max(spikes["unit_index"]) | ||
new_unit_ids = list(sa.sorting.unit_ids.copy()) | ||
splitted_pairs = [] | ||
for unit_id in to_split_ids: | ||
ind_mask = spikes["unit_index"] == sa.sorting.id_to_index(unit_id) | ||
|
||
m = amplitudes[ind_mask].mean() | ||
s = amplitudes[ind_mask].std() | ||
thresh = m + 0.2 * s | ||
|
||
amplitude_mask = amplitudes > thresh | ||
mask = ind_mask & amplitude_mask | ||
new_spikes["unit_index"][mask] = max_index + 1 | ||
|
||
amplitude_mask = (amplitudes > m) * (amplitudes < thresh) | ||
mask = ind_mask & amplitude_mask | ||
new_spikes["unit_index"][mask] = (max_index + 1) * rng.rand(np.sum(mask)) > 0.5 | ||
max_index += 1 | ||
new_unit_ids += [max(new_unit_ids) + 1] | ||
splitted_pairs += [(unit_id, new_unit_ids[-1])] | ||
|
||
new_sorting = NumpySorting(new_spikes, sampling_frequency=sa.sampling_frequency, unit_ids=new_unit_ids) | ||
return new_sorting, splitted_pairs |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would put this function in anotehr file