diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 49a5650622..1c8661d12d 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1336,8 +1336,17 @@ def generate_channel_locations(num_channels, num_columns, contact_spacing_um): return channel_locations -def generate_unit_locations(num_units, channel_locations, margin_um=20.0, minimum_z=5.0, maximum_z=40.0, - minimum_distance=20., max_iteration=100, distance_strict=False, seed=None): +def generate_unit_locations( + num_units, + channel_locations, + margin_um=20.0, + minimum_z=5.0, + maximum_z=40.0, + minimum_distance=20.0, + max_iteration=100, + distance_strict=False, + seed=None, +): rng = np.random.default_rng(seed=seed) units_locations = np.zeros((num_units, 3), dtype="float32") @@ -1364,7 +1373,7 @@ def generate_unit_locations(num_units, channel_locations, margin_um=20.0, minimu else: # random only bad ones in the previous set renew_inds = renew_inds[np.isin(renew_inds, np.unique(inds0))] - + units_locations[:, 0][renew_inds] = rng.uniform(minimum_x, maximum_x, size=renew_inds.size) units_locations[:, 1][renew_inds] = rng.uniform(minimum_y, maximum_y, size=renew_inds.size) units_locations[:, 2][renew_inds] = rng.uniform(minimum_z, maximum_z, size=renew_inds.size) @@ -1374,8 +1383,10 @@ def generate_unit_locations(num_units, channel_locations, margin_um=20.0, minimu if not solution_found: if distance_strict: - raise ValueError(f"generate_unit_locations(): no solution for {minimum_distance=} and {max_iteration=} " - "You can use distance_strict=False or reduce minimum distance") + raise ValueError( + f"generate_unit_locations(): no solution for {minimum_distance=} and {max_iteration=} " + "You can use distance_strict=False or reduce minimum distance" + ) else: warnings.warn(f"generate_unit_locations(): no solution for {minimum_distance=} and {max_iteration=}") diff --git a/src/spikeinterface/core/tests/test_generate.py b/src/spikeinterface/core/tests/test_generate.py index 582120ac51..7b51abcccb 100644 --- a/src/spikeinterface/core/tests/test_generate.py +++ b/src/spikeinterface/core/tests/test_generate.py @@ -292,24 +292,29 @@ def test_generate_single_fake_waveform(): def test_generate_unit_locations(): - seed = 0 probe = generate_multi_columns_probe(num_columns=2, num_contact_per_column=20, xpitch=20, ypitch=20) channel_locations = probe.contact_positions num_units = 100 - minimum_distance = 20. - unit_locations = generate_unit_locations(num_units, channel_locations, - margin_um=20.0, minimum_z=5.0, maximum_z=40.0, - minimum_distance=minimum_distance, max_iteration=500, - distance_strict=False, seed=seed) + minimum_distance = 20.0 + unit_locations = generate_unit_locations( + num_units, + channel_locations, + margin_um=20.0, + minimum_z=5.0, + maximum_z=40.0, + minimum_distance=minimum_distance, + max_iteration=500, + distance_strict=False, + seed=seed, + ) distances = np.linalg.norm(unit_locations[:, np.newaxis] - unit_locations[np.newaxis, :], axis=2) dist_flat = np.triu(distances, k=1).flatten() - dist_flat = dist_flat[dist_flat>0] + dist_flat = dist_flat[dist_flat > 0] assert np.all(dist_flat > minimum_distance) - # import matplotlib.pyplot as plt # fig, ax = plt.subplots() # ax.hist(dist_flat, bins = np.arange(0, 400, 10))