diff --git a/src/main/scala/com/spark3d/spatialOperator/SpatialQuery.scala b/src/main/scala/com/spark3d/spatialOperator/SpatialQuery.scala
index e9f0c9a..936f6f5 100644
--- a/src/main/scala/com/spark3d/spatialOperator/SpatialQuery.scala
+++ b/src/main/scala/com/spark3d/spatialOperator/SpatialQuery.scala
@@ -18,13 +18,15 @@ package com.astrolabsoftware.spark3d.spatialOperator
 
 import com.astrolabsoftware.spark3d.geometryObjects.Shape3D.Shape3D
 import com.astrolabsoftware.spark3d.utils.GeometryObjectComparator
-import com.astrolabsoftware.spark3d.utils.Utils.takeOrdered
+import com.astrolabsoftware.spark3d.utils.BoundedUniquePriorityQueue
 
 import org.apache.spark.rdd.RDD
 import com.astrolabsoftware.spark3d.spatialPartitioning._
 
+import scala.collection.mutable
 import scala.collection.mutable.{HashSet, ListBuffer, PriorityQueue}
 import scala.reflect.ClassTag
 import scala.util.control.Breaks._
+import org.apache.spark.util.collection.{Utils => collectionUtils}
 
 object SpatialQuery {
@@ -38,9 +40,8 @@ object SpatialQuery {
    * @param k number of nearest neighbors are to be found
    * @return knn
    */
-  def KNN[T <: Shape3D: ClassTag](queryObject: T, rdd: RDD[T], k: Int, unique: Boolean = false): List[T] = {
-//    val knn = rdd.takeOrdered(k)(new GeometryObjectComparator[B](queryObject.center))
-    val knn = takeOrdered[T](rdd, k, queryObject, unique)(new GeometryObjectComparator[T](queryObject.center))
+  def KNN[A <: Shape3D: ClassTag, B <: Shape3D: ClassTag](queryObject: A, rdd: RDD[B], k: Int): List[B] = {
+    val knn = rdd.takeOrdered(k)(new GeometryObjectComparator[B](queryObject.center))
     knn.toList
   }
 
@@ -125,4 +126,29 @@ object SpatialQuery {
     knn_f.toList
   }
 
+  private def takeOrdered[A <: Shape3D: ClassTag](rdd: RDD[A], num: Int, queryObject: A, unique: Boolean = false)(implicit ord: Ordering[A]): Array[A] = {
+
+    if (unique) {
+      if (num == 0) {
+        return Array.empty
+      } else {
+        val mapRDDs = rdd.mapPartitions { items =>
+          val queue = new BoundedUniquePriorityQueue[A](num)(ord.reverse)
+          queue ++= collectionUtils.takeOrdered(items, num)(ord)
+          Iterator.single(queue)
+        }
+        if (mapRDDs.partitions.length == 0) {
+          return Array.empty
+        } else {
+          return mapRDDs.reduce { (queue1, queue2) =>
+            queue1 ++= queue2
+            queue1
+          }.toArray.sorted(ord)
+        }
+      }
+
+    }
+
+    return rdd.takeOrdered(num)(new GeometryObjectComparator[A](queryObject.center))
+  }
 }
diff --git a/src/main/scala/com/spark3d/utils/Utils.scala b/src/main/scala/com/spark3d/utils/Utils.scala
index 9263e60..3cd250c 100644
--- a/src/main/scala/com/spark3d/utils/Utils.scala
+++ b/src/main/scala/com/spark3d/utils/Utils.scala
@@ -15,14 +15,7 @@
  */
 package com.astrolabsoftware.spark3d.utils
 
-import com.astrolabsoftware.spark3d.geometryObjects.Shape3D.Shape3D
 import com.astrolabsoftware.spark3d.geometryObjects._
 
-import com.google.common.collect.{Ordering => GuavaOrdering}
-
-import org.apache.spark.rdd.RDD
-
-import scala.reflect.ClassTag
-import scala.collection.JavaConverters._
 
 object Utils {
@@ -109,38 +102,4 @@ object Utils {
       ra
     }
   }
-
-  def takeOrdered[T <: Shape3D: ClassTag](rdd: RDD[T], num: Int, queryObject: T, unique: Boolean = false)(ord: Ordering[T]): Array[T] = {
-
-    if (unique) {
-      if (num == 0) {
-        Array.empty
-      } else {
-        val mapRDDs = rdd.mapPartitions { items =>
-          val queue = new BoundedUniquePriorityQueue[T](num)(ord.reverse)
-          queue ++= takeOrdered(items, num)(ord)
-          Iterator.single(queue)
-        }
-        if (mapRDDs.partitions.length == 0) {
-          return Array.empty
-        } else {
-          return mapRDDs.reduce { (queue1, queue2) =>
-            queue1 ++= queue2
-            queue1
-          }.toArray.sorted(ord)
-        }
-      }
-
-    }
-
-    return rdd.takeOrdered(num)(new GeometryObjectComparator[T](queryObject.center))
-  }
-
-  private def takeOrdered[T](input: Iterator[T], num: Int)(implicit ord: Ordering[T]): Iterator[T] = {
-    val ordering = new GuavaOrdering[T] {
-      override def compare(l: T, r: T): Int = ord.compare(l, r)
-    }
-    ordering.leastOf(input.asJava, num).iterator.asScala
-  }
-
 }
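The reworked KNN entry point takes the query object and the RDD elements as two independent Shape3D subtypes and returns the k elements of the RDD closest to the query object's center. A minimal usage sketch, assuming a live SparkContext named sc and the spark3d Point3D(x, y, z, isSpherical) constructor (both outside this patch), could look like:

// Not part of the patch: illustrative use of the new KNN signature.
// Assumes a running SparkContext `sc` and spark3d's Point3D(x, y, z, isSpherical).
import com.astrolabsoftware.spark3d.geometryObjects.Point3D
import com.astrolabsoftware.spark3d.spatialOperator.SpatialQuery

val points = sc.parallelize(Seq(
  new Point3D(0.0, 0.0, 0.0, false),
  new Point3D(1.0, 1.0, 1.0, false),
  new Point3D(2.0, 2.0, 2.0, false)))

// The query object (type A) and the RDD elements (type B) may now differ;
// the result keeps the element type of the RDD.
val query = new Point3D(0.1, 0.1, 0.1, false)
val neighbours: List[Point3D] = SpatialQuery.KNN(query, points, 2)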
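The patch also imports com.astrolabsoftware.spark3d.utils.BoundedUniquePriorityQueue, which is not shown here. The private takeOrdered helper only relies on its contract: each partition keeps at most num distinct elements under the reversed distance ordering, the per-partition queues are merged with ++=, and the merged content is sorted to yield the k unique nearest neighbours. A rough sketch of such a structure, assuming it mirrors Spark's BoundedPriorityQueue with an added uniqueness check (the real spark3d class may differ), is:

// Illustrative sketch only; the actual spark3d implementation may differ.
// Keeps at most maxSize distinct elements, retaining the largest ones under
// the ordering it is constructed with (so, built with ord.reverse, it keeps
// the num nearest objects).
import scala.collection.mutable

class BoundedUniquePriorityQueue[A](maxSize: Int)(implicit ord: Ordering[A])
    extends Iterable[A] {

  // Head of this heap is the smallest retained element under `ord`,
  // i.e. the first candidate for eviction.
  private val underlying = mutable.PriorityQueue.empty[A](ord.reverse)
  private val seen = mutable.HashSet.empty[A]

  override def iterator: Iterator[A] = underlying.iterator

  def +=(elem: A): this.type = {
    if (!seen.contains(elem)) {
      if (underlying.size < maxSize) {
        underlying.enqueue(elem)
        seen += elem
      } else if (ord.gt(elem, underlying.head)) {
        // Replace the smallest retained element with the larger incoming one.
        seen -= underlying.dequeue()
        underlying.enqueue(elem)
        seen += elem
      }
    }
    this
  }

  def ++=(other: TraversableOnce[A]): this.type = {
    other.foreach(this += _)
    this
  }
}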