Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use SmallHeap also for the iterator-based nearest neighbour search and optimize its spill implementation. #154

Merged
merged 2 commits into from
Jan 23, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 61 additions & 26 deletions rstar/src/algorithm/nearest_neighbor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use crate::point::{min_inline, Point};
use crate::{Envelope, PointDistance, RTreeObject};

use alloc::{collections::BinaryHeap, vec, vec::Vec};
use core::mem::replace;
use heapless::binary_heap as static_heap;
use num_traits::Bounded;

Expand Down Expand Up @@ -50,7 +51,7 @@ where
{
pub fn new(root: &'a ParentNode<T>, query_point: <T::Envelope as Envelope>::Point) -> Self {
let mut result = NearestNeighborDistance2Iterator {
nodes: BinaryHeap::with_capacity(20),
nodes: SmallHeap::new(),
query_point,
};
result.extend_heap(&root.children);
Expand Down Expand Up @@ -107,7 +108,7 @@ pub struct NearestNeighborDistance2Iterator<'a, T>
where
T: PointDistance + 'a,
{
nodes: BinaryHeap<RTreeNodeDistanceWrapper<'a, T>>,
nodes: SmallHeap<RTreeNodeDistanceWrapper<'a, T>>,
query_point: <T::Envelope as Envelope>::Point,
}

Expand Down Expand Up @@ -161,20 +162,55 @@ impl<T: Ord> SmallHeap<T> {
match self {
SmallHeap::Stack(heap) => {
if let Err(item) = heap.push(item) {
// FIXME: This could be done more efficiently if heapless'
// BinaryHeap had draining, owning into_iter, or would
// expose its data slice.
let mut new_heap = BinaryHeap::with_capacity(heap.len() + 1);
while let Some(old_item) = heap.pop() {
new_heap.push(old_item);
}
let capacity = heap.len() + 1;
let new_heap = self.spill(capacity);
new_heap.push(item);
*self = SmallHeap::Heap(new_heap);
}
}
SmallHeap::Heap(heap) => heap.push(item),
}
}

pub fn extend<I>(&mut self, iter: I)
where
I: ExactSizeIterator<Item = T>,
{
match self {
SmallHeap::Stack(heap) => {
if heap.capacity() >= heap.len() + iter.len() {
for item in iter {
if heap.push(item).is_err() {
unreachable!();
}
}
} else {
let capacity = heap.len() + iter.len();
let new_heap = self.spill(capacity);
new_heap.extend(iter);
}
}
SmallHeap::Heap(heap) => heap.extend(iter),
}
}

#[cold]
fn spill(&mut self, capacity: usize) -> &mut BinaryHeap<T> {
let new_heap = BinaryHeap::with_capacity(capacity);
let old_heap = replace(self, SmallHeap::Heap(new_heap));

let new_heap = match self {
SmallHeap::Heap(new_heap) => new_heap,
SmallHeap::Stack(_) => unreachable!(),
};
let old_heap = match old_heap {
SmallHeap::Stack(old_heap) => old_heap,
SmallHeap::Heap(_) => unreachable!(),
};

new_heap.extend(old_heap.into_vec());

new_heap
}
}

pub fn nearest_neighbor<T>(
Expand Down Expand Up @@ -250,26 +286,25 @@ pub fn nearest_neighbors<T>(
where
T: PointDistance,
{
let mut nearest_neighbors = NearestNeighborIterator::new(node, query_point.clone());
let mut nearest_neighbors = NearestNeighborDistance2Iterator::new(node, query_point.clone());

let first_nearest_neighbor = match nearest_neighbors.next() {
None => return vec![], // If we have an empty tree, just return an empty vector.
Some(nn) => nn,
let (first, first_distance_2) = match nearest_neighbors.next() {
Some(item) => item,
// If we have an empty tree, just return an empty vector.
None => return Vec::new(),
};

// The result will at least contain the first nearest neighbor.
let mut result = vec![first_nearest_neighbor];

// We compute the distance to the first nearest neighbor, and use
// that distance to filter out the rest of the nearest neighbors that are farther
// than this first neighbor.
let distance = first_nearest_neighbor.envelope().distance_2(&query_point);
nearest_neighbors
.take_while(|nearest_neighbor| {
let next_distance = nearest_neighbor.envelope().distance_2(&query_point);
next_distance == distance
})
.for_each(|nearest_neighbor| result.push(nearest_neighbor));
let mut result = vec![first];

// Use the distance to the first nearest neighbor
// to filter out the rest of the nearest neighbors
// that are farther than this first neighbor.
result.extend(
nearest_neighbors
.take_while(|(_, next_distance_2)| next_distance_2 == &first_distance_2)
.map(|(next, _)| next),
);

result
}
Expand Down