diff --git a/rstar/src/algorithm/nearest_neighbor.rs b/rstar/src/algorithm/nearest_neighbor.rs index 515ed5a..cade86e 100644 --- a/rstar/src/algorithm/nearest_neighbor.rs +++ b/rstar/src/algorithm/nearest_neighbor.rs @@ -3,6 +3,7 @@ use crate::point::{min_inline, Point}; use crate::{Envelope, PointDistance, RTreeObject}; use alloc::{collections::BinaryHeap, vec, vec::Vec}; +use core::mem::replace; use heapless::binary_heap as static_heap; use num_traits::Bounded; @@ -50,7 +51,7 @@ where { pub fn new(root: &'a ParentNode, query_point: ::Point) -> Self { let mut result = NearestNeighborDistance2Iterator { - nodes: BinaryHeap::with_capacity(20), + nodes: SmallHeap::new(), query_point, }; result.extend_heap(&root.children); @@ -107,7 +108,7 @@ pub struct NearestNeighborDistance2Iterator<'a, T> where T: PointDistance + 'a, { - nodes: BinaryHeap>, + nodes: SmallHeap>, query_point: ::Point, } @@ -161,20 +162,55 @@ impl SmallHeap { match self { SmallHeap::Stack(heap) => { if let Err(item) = heap.push(item) { - // FIXME: This could be done more efficiently if heapless' - // BinaryHeap had draining, owning into_iter, or would - // expose its data slice. - let mut new_heap = BinaryHeap::with_capacity(heap.len() + 1); - while let Some(old_item) = heap.pop() { - new_heap.push(old_item); - } + let capacity = heap.len() + 1; + let new_heap = self.spill(capacity); new_heap.push(item); - *self = SmallHeap::Heap(new_heap); } } SmallHeap::Heap(heap) => heap.push(item), } } + + pub fn extend(&mut self, iter: I) + where + I: ExactSizeIterator, + { + match self { + SmallHeap::Stack(heap) => { + if heap.capacity() >= heap.len() + iter.len() { + for item in iter { + if heap.push(item).is_err() { + unreachable!(); + } + } + } else { + let capacity = heap.len() + iter.len(); + let new_heap = self.spill(capacity); + new_heap.extend(iter); + } + } + SmallHeap::Heap(heap) => heap.extend(iter), + } + } + + #[cold] + fn spill(&mut self, capacity: usize) -> &mut BinaryHeap { + let new_heap = BinaryHeap::with_capacity(capacity); + let old_heap = replace(self, SmallHeap::Heap(new_heap)); + + let new_heap = match self { + SmallHeap::Heap(new_heap) => new_heap, + SmallHeap::Stack(_) => unreachable!(), + }; + let old_heap = match old_heap { + SmallHeap::Stack(old_heap) => old_heap, + SmallHeap::Heap(_) => unreachable!(), + }; + + new_heap.extend(old_heap.into_vec()); + + new_heap + } } pub fn nearest_neighbor( @@ -250,26 +286,25 @@ pub fn nearest_neighbors( where T: PointDistance, { - let mut nearest_neighbors = NearestNeighborIterator::new(node, query_point.clone()); + let mut nearest_neighbors = NearestNeighborDistance2Iterator::new(node, query_point.clone()); - let first_nearest_neighbor = match nearest_neighbors.next() { - None => return vec![], // If we have an empty tree, just return an empty vector. - Some(nn) => nn, + let (first, first_distance_2) = match nearest_neighbors.next() { + Some(item) => item, + // If we have an empty tree, just return an empty vector. + None => return Vec::new(), }; // The result will at least contain the first nearest neighbor. - let mut result = vec![first_nearest_neighbor]; - - // We compute the distance to the first nearest neighbor, and use - // that distance to filter out the rest of the nearest neighbors that are farther - // than this first neighbor. - let distance = first_nearest_neighbor.envelope().distance_2(&query_point); - nearest_neighbors - .take_while(|nearest_neighbor| { - let next_distance = nearest_neighbor.envelope().distance_2(&query_point); - next_distance == distance - }) - .for_each(|nearest_neighbor| result.push(nearest_neighbor)); + let mut result = vec![first]; + + // Use the distance to the first nearest neighbor + // to filter out the rest of the nearest neighbors + // that are farther than this first neighbor. + result.extend( + nearest_neighbors + .take_while(|(_, next_distance_2)| next_distance_2 == &first_distance_2) + .map(|(next, _)| next), + ); result }