diff --git a/rstar/src/algorithm/nearest_neighbor.rs b/rstar/src/algorithm/nearest_neighbor.rs index e136661..cade86e 100644 --- a/rstar/src/algorithm/nearest_neighbor.rs +++ b/rstar/src/algorithm/nearest_neighbor.rs @@ -286,26 +286,25 @@ pub fn nearest_neighbors( where T: PointDistance, { - let mut nearest_neighbors = NearestNeighborIterator::new(node, query_point.clone()); + let mut nearest_neighbors = NearestNeighborDistance2Iterator::new(node, query_point.clone()); - let first_nearest_neighbor = match nearest_neighbors.next() { - None => return vec![], // If we have an empty tree, just return an empty vector. - Some(nn) => nn, + let (first, first_distance_2) = match nearest_neighbors.next() { + Some(item) => item, + // If we have an empty tree, just return an empty vector. + None => return Vec::new(), }; // The result will at least contain the first nearest neighbor. - let mut result = vec![first_nearest_neighbor]; - - // We compute the distance to the first nearest neighbor, and use - // that distance to filter out the rest of the nearest neighbors that are farther - // than this first neighbor. - let distance = first_nearest_neighbor.envelope().distance_2(&query_point); - nearest_neighbors - .take_while(|nearest_neighbor| { - let next_distance = nearest_neighbor.envelope().distance_2(&query_point); - next_distance == distance - }) - .for_each(|nearest_neighbor| result.push(nearest_neighbor)); + let mut result = vec![first]; + + // Use the distance to the first nearest neighbor + // to filter out the rest of the nearest neighbors + // that are farther than this first neighbor. + result.extend( + nearest_neighbors + .take_while(|(_, next_distance_2)| next_distance_2 == &first_distance_2) + .map(|(next, _)| next), + ); result }