From 1be2ce144e6f7f574a278dcf7775f3a11090a2a3 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 30 Oct 2023 18:57:51 +0100 Subject: [PATCH 1/2] Add a minimum distance in generate_unit_locations. --- src/spikeinterface/core/generate.py | 46 ++++++++++++++++--- .../core/tests/test_generate.py | 36 ++++++++++++++- 2 files changed, 74 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 44ea02d32c..003b9cb5b5 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1333,15 +1333,49 @@ def generate_channel_locations(num_channels, num_columns, contact_spacing_um): return channel_locations -def generate_unit_locations(num_units, channel_locations, margin_um=20.0, minimum_z=5.0, maximum_z=40.0, seed=None): +def generate_unit_locations(num_units, channel_locations, margin_um=20.0, minimum_z=5.0, maximum_z=40.0, + minimum_distance=20., max_iteration=100, distance_strict=False, seed=None): rng = np.random.default_rng(seed=seed) units_locations = np.zeros((num_units, 3), dtype="float32") - for dim in (0, 1): - lim0 = np.min(channel_locations[:, dim]) - margin_um - lim1 = np.max(channel_locations[:, dim]) + margin_um - units_locations[:, dim] = rng.uniform(lim0, lim1, size=num_units) + + minimum_x, maximum_x = np.min(channel_locations[:, 0]) - margin_um, np.max(channel_locations[:, 0]) + margin_um + minimum_y, maximum_y = np.min(channel_locations[:, 1]) - margin_um, np.max(channel_locations[:, 1]) + margin_um + + units_locations[:, 0] = rng.uniform(minimum_x, maximum_x, size=num_units) + units_locations[:, 1] = rng.uniform(minimum_y, maximum_y, size=num_units) units_locations[:, 2] = rng.uniform(minimum_z, maximum_z, size=num_units) + if minimum_distance is not None: + solution_found = False + renew_inds = None + for i in range(max_iteration): + distances = np.linalg.norm(units_locations[:, np.newaxis] - units_locations[np.newaxis, :], axis=2) + inds0, inds1 = np.nonzero(distances < minimum_distance) + mask = inds0 != inds1 + inds0 = inds0[mask] + inds1 = inds1[mask] + + if inds0.size > 0: + if renew_inds is None: + renew_inds = np.unique(inds0) + else: + # random only bad ones in the previous set + renew_inds = renew_inds[np.isin(renew_inds, np.unique(inds0))] + + units_locations[:, 0][renew_inds] = rng.uniform(minimum_x, maximum_x, size=renew_inds.size) + units_locations[:, 1][renew_inds] = rng.uniform(minimum_y, maximum_y, size=renew_inds.size) + units_locations[:, 2][renew_inds] = rng.uniform(minimum_z, maximum_z, size=renew_inds.size) + else: + solution_found = True + break + + if not solution_found: + if distance_strict: + raise ValueError(f"generate_unit_locations(): no solution for {minimum_distance=} and {max_iteration=} " + "You can use distance_strict=False or reduce minimum distance") + else: + warnings.warn(f"generate_unit_locations(): no solution for {minimum_distance=} and {max_iteration=}") + return units_locations @@ -1366,7 +1400,7 @@ def generate_ground_truth_recording( upsample_vector=None, generate_sorting_kwargs=dict(firing_rates=15, refractory_period_ms=4.0), noise_kwargs=dict(noise_level=5.0, strategy="on_the_fly"), - generate_unit_locations_kwargs=dict(margin_um=10.0, minimum_z=5.0, maximum_z=50.0), + generate_unit_locations_kwargs=dict(margin_um=10.0, minimum_z=5.0, maximum_z=50.0, minimum_distance=20), generate_templates_kwargs=dict(), dtype="float32", seed=None, diff --git a/src/spikeinterface/core/tests/test_generate.py b/src/spikeinterface/core/tests/test_generate.py index 9a9c61766f..582120ac51 100644 --- a/src/spikeinterface/core/tests/test_generate.py +++ b/src/spikeinterface/core/tests/test_generate.py @@ -4,6 +4,8 @@ import numpy as np from spikeinterface.core import load_extractor, extract_waveforms + +from probeinterface import generate_multi_columns_probe from spikeinterface.core.generate import ( generate_recording, generate_sorting, @@ -289,6 +291,35 @@ def test_generate_single_fake_waveform(): # plt.show() +def test_generate_unit_locations(): + + seed = 0 + + probe = generate_multi_columns_probe(num_columns=2, num_contact_per_column=20, xpitch=20, ypitch=20) + channel_locations = probe.contact_positions + + num_units = 100 + minimum_distance = 20. + unit_locations = generate_unit_locations(num_units, channel_locations, + margin_um=20.0, minimum_z=5.0, maximum_z=40.0, + minimum_distance=minimum_distance, max_iteration=500, + distance_strict=False, seed=seed) + distances = np.linalg.norm(unit_locations[:, np.newaxis] - unit_locations[np.newaxis, :], axis=2) + dist_flat = np.triu(distances, k=1).flatten() + dist_flat = dist_flat[dist_flat>0] + assert np.all(dist_flat > minimum_distance) + + + # import matplotlib.pyplot as plt + # fig, ax = plt.subplots() + # ax.hist(dist_flat, bins = np.arange(0, 400, 10)) + # fig, ax = plt.subplots() + # from probeinterface.plotting import plot_probe + # plot_probe(probe, ax=ax) + # ax.scatter(unit_locations[:, 0], unit_locations[:,1], marker='*', s=20) + # plt.show() + + def test_generate_templates(): seed = 0 @@ -297,7 +328,7 @@ def test_generate_templates(): num_units = 10 margin_um = 15.0 channel_locations = generate_channel_locations(num_chans, num_columns, 20.0) - unit_locations = generate_unit_locations(num_units, channel_locations, margin_um, seed) + unit_locations = generate_unit_locations(num_units, channel_locations, margin_um=margin_um, seed=seed) sampling_frequency = 30000.0 ms_before = 1.0 @@ -436,7 +467,8 @@ def test_generate_ground_truth_recording(): # test_noise_generator_consistency_after_dump(strategy, None) # test_generate_recording() # test_generate_single_fake_waveform() + test_generate_unit_locations() # test_generate_templates() # test_inject_templates() # test_generate_ground_truth_recording() - test_generate_sorting_with_spikes_on_borders() + # test_generate_sorting_with_spikes_on_borders() From d692439159e08cc5de33dd6ccbbbe3e19513f550 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 15 Nov 2023 15:44:39 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/generate.py | 21 ++++++++++++++----- .../core/tests/test_generate.py | 21 ++++++++++++------- 2 files changed, 29 insertions(+), 13 deletions(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 49a5650622..1c8661d12d 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1336,8 +1336,17 @@ def generate_channel_locations(num_channels, num_columns, contact_spacing_um): return channel_locations -def generate_unit_locations(num_units, channel_locations, margin_um=20.0, minimum_z=5.0, maximum_z=40.0, - minimum_distance=20., max_iteration=100, distance_strict=False, seed=None): +def generate_unit_locations( + num_units, + channel_locations, + margin_um=20.0, + minimum_z=5.0, + maximum_z=40.0, + minimum_distance=20.0, + max_iteration=100, + distance_strict=False, + seed=None, +): rng = np.random.default_rng(seed=seed) units_locations = np.zeros((num_units, 3), dtype="float32") @@ -1364,7 +1373,7 @@ def generate_unit_locations(num_units, channel_locations, margin_um=20.0, minimu else: # random only bad ones in the previous set renew_inds = renew_inds[np.isin(renew_inds, np.unique(inds0))] - + units_locations[:, 0][renew_inds] = rng.uniform(minimum_x, maximum_x, size=renew_inds.size) units_locations[:, 1][renew_inds] = rng.uniform(minimum_y, maximum_y, size=renew_inds.size) units_locations[:, 2][renew_inds] = rng.uniform(minimum_z, maximum_z, size=renew_inds.size) @@ -1374,8 +1383,10 @@ def generate_unit_locations(num_units, channel_locations, margin_um=20.0, minimu if not solution_found: if distance_strict: - raise ValueError(f"generate_unit_locations(): no solution for {minimum_distance=} and {max_iteration=} " - "You can use distance_strict=False or reduce minimum distance") + raise ValueError( + f"generate_unit_locations(): no solution for {minimum_distance=} and {max_iteration=} " + "You can use distance_strict=False or reduce minimum distance" + ) else: warnings.warn(f"generate_unit_locations(): no solution for {minimum_distance=} and {max_iteration=}") diff --git a/src/spikeinterface/core/tests/test_generate.py b/src/spikeinterface/core/tests/test_generate.py index 582120ac51..7b51abcccb 100644 --- a/src/spikeinterface/core/tests/test_generate.py +++ b/src/spikeinterface/core/tests/test_generate.py @@ -292,24 +292,29 @@ def test_generate_single_fake_waveform(): def test_generate_unit_locations(): - seed = 0 probe = generate_multi_columns_probe(num_columns=2, num_contact_per_column=20, xpitch=20, ypitch=20) channel_locations = probe.contact_positions num_units = 100 - minimum_distance = 20. - unit_locations = generate_unit_locations(num_units, channel_locations, - margin_um=20.0, minimum_z=5.0, maximum_z=40.0, - minimum_distance=minimum_distance, max_iteration=500, - distance_strict=False, seed=seed) + minimum_distance = 20.0 + unit_locations = generate_unit_locations( + num_units, + channel_locations, + margin_um=20.0, + minimum_z=5.0, + maximum_z=40.0, + minimum_distance=minimum_distance, + max_iteration=500, + distance_strict=False, + seed=seed, + ) distances = np.linalg.norm(unit_locations[:, np.newaxis] - unit_locations[np.newaxis, :], axis=2) dist_flat = np.triu(distances, k=1).flatten() - dist_flat = dist_flat[dist_flat>0] + dist_flat = dist_flat[dist_flat > 0] assert np.all(dist_flat > minimum_distance) - # import matplotlib.pyplot as plt # fig, ax = plt.subplots() # ax.hist(dist_flat, bins = np.arange(0, 400, 10))