Skip to content

Commit

Permalink
Use SmallHeap also for the iterator-based nearest neighbour search an…
Browse files Browse the repository at this point in the history
…d optimize its spill implementation. (#154)

* Use SmallHeap also for the iterator-based nearest neighbour search and optimize its spill implementation.

* Avoid recomputing the squared-distance in the implementation of nearest_neighbors.
  • Loading branch information
adamreichold authored Jan 23, 2024
1 parent f9973cf commit 1812101
Showing 1 changed file with 61 additions and 26 deletions.
87 changes: 61 additions & 26 deletions rstar/src/algorithm/nearest_neighbor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use crate::point::{min_inline, Point};
use crate::{Envelope, PointDistance, RTreeObject};

use alloc::{collections::BinaryHeap, vec, vec::Vec};
use core::mem::replace;
use heapless::binary_heap as static_heap;
use num_traits::Bounded;

Expand Down Expand Up @@ -50,7 +51,7 @@ where
{
pub fn new(root: &'a ParentNode<T>, query_point: <T::Envelope as Envelope>::Point) -> Self {
let mut result = NearestNeighborDistance2Iterator {
nodes: BinaryHeap::with_capacity(20),
nodes: SmallHeap::new(),
query_point,
};
result.extend_heap(&root.children);
Expand Down Expand Up @@ -107,7 +108,7 @@ pub struct NearestNeighborDistance2Iterator<'a, T>
where
T: PointDistance + 'a,
{
nodes: BinaryHeap<RTreeNodeDistanceWrapper<'a, T>>,
nodes: SmallHeap<RTreeNodeDistanceWrapper<'a, T>>,
query_point: <T::Envelope as Envelope>::Point,
}

Expand Down Expand Up @@ -161,20 +162,55 @@ impl<T: Ord> SmallHeap<T> {
match self {
SmallHeap::Stack(heap) => {
if let Err(item) = heap.push(item) {
// FIXME: This could be done more efficiently if heapless'
// BinaryHeap had draining, owning into_iter, or would
// expose its data slice.
let mut new_heap = BinaryHeap::with_capacity(heap.len() + 1);
while let Some(old_item) = heap.pop() {
new_heap.push(old_item);
}
let capacity = heap.len() + 1;
let new_heap = self.spill(capacity);
new_heap.push(item);
*self = SmallHeap::Heap(new_heap);
}
}
SmallHeap::Heap(heap) => heap.push(item),
}
}

pub fn extend<I>(&mut self, iter: I)
where
I: ExactSizeIterator<Item = T>,
{
match self {
SmallHeap::Stack(heap) => {
if heap.capacity() >= heap.len() + iter.len() {
for item in iter {
if heap.push(item).is_err() {
unreachable!();
}
}
} else {
let capacity = heap.len() + iter.len();
let new_heap = self.spill(capacity);
new_heap.extend(iter);
}
}
SmallHeap::Heap(heap) => heap.extend(iter),
}
}

#[cold]
fn spill(&mut self, capacity: usize) -> &mut BinaryHeap<T> {
let new_heap = BinaryHeap::with_capacity(capacity);
let old_heap = replace(self, SmallHeap::Heap(new_heap));

let new_heap = match self {
SmallHeap::Heap(new_heap) => new_heap,
SmallHeap::Stack(_) => unreachable!(),
};
let old_heap = match old_heap {
SmallHeap::Stack(old_heap) => old_heap,
SmallHeap::Heap(_) => unreachable!(),
};

new_heap.extend(old_heap.into_vec());

new_heap
}
}

pub fn nearest_neighbor<T>(
Expand Down Expand Up @@ -250,26 +286,25 @@ pub fn nearest_neighbors<T>(
where
T: PointDistance,
{
let mut nearest_neighbors = NearestNeighborIterator::new(node, query_point.clone());
let mut nearest_neighbors = NearestNeighborDistance2Iterator::new(node, query_point.clone());

let first_nearest_neighbor = match nearest_neighbors.next() {
None => return vec![], // If we have an empty tree, just return an empty vector.
Some(nn) => nn,
let (first, first_distance_2) = match nearest_neighbors.next() {
Some(item) => item,
// If we have an empty tree, just return an empty vector.
None => return Vec::new(),
};

// The result will at least contain the first nearest neighbor.
let mut result = vec![first_nearest_neighbor];

// We compute the distance to the first nearest neighbor, and use
// that distance to filter out the rest of the nearest neighbors that are farther
// than this first neighbor.
let distance = first_nearest_neighbor.envelope().distance_2(&query_point);
nearest_neighbors
.take_while(|nearest_neighbor| {
let next_distance = nearest_neighbor.envelope().distance_2(&query_point);
next_distance == distance
})
.for_each(|nearest_neighbor| result.push(nearest_neighbor));
let mut result = vec![first];

// Use the distance to the first nearest neighbor
// to filter out the rest of the nearest neighbors
// that are farther than this first neighbor.
result.extend(
nearest_neighbors
.take_while(|(_, next_distance_2)| next_distance_2 == &first_distance_2)
.map(|(next, _)| next),
);

result
}
Expand Down

0 comments on commit 1812101

Please sign in to comment.