From 538d369081c1dac3d7a9032d2444a0fc88a244f5 Mon Sep 17 00:00:00 2001 From: BRAUN REMI Date: Thu, 7 Mar 2024 16:23:26 +0100 Subject: [PATCH] FIX: Fix `geometry.nearest_neighbors` when k is bigger than the number of candidates --- CHANGES.md | 1 + CI/SCRIPTS/test_geometry.py | 5 +++++ sertit/geometry.py | 4 +++- 3 files changed, 9 insertions(+), 1 deletion(-) diff --git a/CHANGES.md b/CHANGES.md index 56df4f9..607f4f8 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -2,6 +2,7 @@ ## 1.36.1 (2024-xx-xx) +- FIX: Fix `geometry.nearest_neighbors` when k is bigger than the number of candidates - DOC: Update some examples in documentation ## 1.36.0 (2024-02-27) diff --git a/CI/SCRIPTS/test_geometry.py b/CI/SCRIPTS/test_geometry.py index d5b7cfb..cc15883 100644 --- a/CI/SCRIPTS/test_geometry.py +++ b/CI/SCRIPTS/test_geometry.py @@ -188,3 +188,8 @@ def test_nearest_neighbors(): assert ( curr_dist[0] < radius ), f"distance superior to wanted distance: {curr_dist[0]} > {radius}" + + # Ensure it works with k > nof candidates + closest, distances = nearest_neighbors( + src, candidates, method="k_neighbors", k_neighbors=len(candidates) + 10 + ) diff --git a/sertit/geometry.py b/sertit/geometry.py index 4426eac..dff4f5d 100644 --- a/sertit/geometry.py +++ b/sertit/geometry.py @@ -539,7 +539,9 @@ def _get_k_nearest(src_points: list, candidates: list, k_neighbors: int, **kwarg tree = BallTree(candidates, leaf_size=15) # Find the closest points and distances - closest_dist, closest = tree.query(src_points, k=k_neighbors, **kwargs) + closest_dist, closest = tree.query( + src_points, k=min(k_neighbors, len(candidates)), **kwargs + ) # Return indices and distances return closest, closest_dist