Skip to content

Commit

Permalink
Fix #140 and add tests. (#141)
Browse files Browse the repository at this point in the history
  • Loading branch information
adriangb authored Feb 11, 2020
1 parent bc6ccc0 commit cd73cb9
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 1 deletion.
8 changes: 7 additions & 1 deletion rtree/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -846,6 +846,12 @@ def nearest(self, coordinates, num_results=1, objects=False):
return self._nearest_obj(coordinates, num_results, objects)
p_mins, p_maxs = self.get_coordinate_pointers(coordinates)

# p_num_results is an input and output for C++ lib
# as an input it says "get n closest neighbors"
# but if multiple neighbors are at the same distance, both will be returned
# so the number of returned neighbors may be > p_num_results
# thus p_num_results.contents.value gets set as an output by the C++ lib
# to indicate the actual number of results for _get_ids to use
p_num_results = ctypes.pointer(ctypes.c_uint64(num_results))

it = ctypes.pointer(ctypes.c_int64())
Expand All @@ -857,7 +863,7 @@ def nearest(self, coordinates, num_results=1, objects=False):
ctypes.byref(it),
p_num_results)

return self._get_ids(it, min(num_results, p_num_results.contents.value))
return self._get_ids(it, p_num_results.contents.value)

def _nearestTP(self, coordinates, velocities, times, num_results=1, objects=False):
p_mins, p_maxs = self.get_coordinate_pointers(coordinates)
Expand Down
32 changes: 32 additions & 0 deletions tests/test_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,38 @@ def test_nearest_basic(self):
idx.add(i, (start, 1, stop, 1))
hits = sorted(idx.nearest((13, 0, 20, 2), 3))
self.assertEqual(hits, [3, 4, 5])

def test_nearest_equidistant(self):
"""Test that if records are equidistant, both are returned."""
point = (0, 0)
small_box = (-10, -10, 10, 10)
large_box = (-50, -50, 50, 50)

idx = index.Index()
idx.insert(0, small_box)
idx.insert(1, large_box)
self.assertEqual(list(idx.nearest(point, 2)), [0, 1])
self.assertEqual(list(idx.nearest(point, 1)), [0, 1])

idx.insert(2, (0, 0))
self.assertEqual(list(idx.nearest(point, 2)), [0, 1, 2])
self.assertEqual(list(idx.nearest(point, 1)), [0, 1, 2])

idx = index.Index()
idx.insert(0, small_box)
idx.insert(1, large_box)
idx.insert(2, (50, 50)) # point on top right vertex of large_box
point = (51, 51) # right outside of large_box
self.assertEqual(list(idx.nearest(point, 2)), [1, 2])
self.assertEqual(list(idx.nearest(point, 1)), [1, 2])

idx = index.Index()
idx.insert(0, small_box)
idx.insert(1, large_box)
idx.insert(2, (51, 51)) # point right outside on top right vertex of large_box
point = (51, 52) # shifted 1 unit up from the point above
self.assertEqual(list(idx.nearest(point, 2)), [2, 1])
self.assertEqual(list(idx.nearest(point, 1)), [2])


def test_nearest_object(self):
Expand Down

0 comments on commit cd73cb9

Please sign in to comment.