From da8b15a9e40605a678ba0ebfcc38e4924192e8a2 Mon Sep 17 00:00:00 2001
From: "yiming.xu" <100650920@qq.com>
Date: Thu, 21 Nov 2019 10:11:00 +0800
Subject: [PATCH] #73 support limit offset

---
 .../plans/logical/basicLogicalOperators.scala |  35 +++++
 .../scala/org/apache/spark/sql/Dataset.scala  |   4 +
 .../spark/sql/execution/SparkStrategies.scala |  21 +++
 .../apache/spark/sql/execution/limit.scala    | 123 ++++++++++++++++++
 4 files changed, 183 insertions(+)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 2eee94364d84e..9a8e3427c5d55 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -782,6 +782,41 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNo
   }
 }
 
+object LimitRange {
+  def apply(startExpr: Expression, endExpr: Expression, child: LogicalPlan): UnaryNode = {
+    LimitRange0(startExpr, endExpr, LocalLimit(endExpr, child))
+  }
+
+  def unapply(p: LimitRange0): Option[(Expression, Expression, LogicalPlan)] = {
+    p match {
+      case LimitRange0(le0, le1, LocalLimit(le2, child)) if le1 == le2 => Some((le0, le1, child))
+      case _ => None
+    }
+  }
+}
+/**
+ * A global (coordinated) limit over a row range. This operator skips the first `startExpr`
+ * rows and emits at most `endExpr - startExpr` rows in total.
+ * See [[Limit]] for more information.
+ */
+case class LimitRange0(startExpr: Expression, endExpr: Expression, child: LogicalPlan)
+  extends UnaryNode {
+  override def output: Seq[Attribute] = child.output
+  override def maxRows: Option[Long] = {
+    (startExpr, endExpr) match {
+      case (IntegerLiteral(start), IntegerLiteral(end)) => Some(end - start)
+      case _ => None
+    }
+  }
+}
+
+
+/**
+ * Aliased subquery.
+ *
+ * @param alias the alias identifier for this subquery.
+ * @param child the logical plan of this subquery.
+ */
 case class SubqueryAlias(
     alias: String,
     child: LogicalPlan)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 1acbad960f1bf..c7441db854f74 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1625,6 +1625,10 @@ class Dataset[T] private[sql](
     Limit(Literal(n), logicalPlan)
   }
 
+  def limitRange(start: Int, end: Int): Dataset[T] = withTypedPlan {
+    LimitRange(Literal(start), Literal(end), logicalPlan)
+  }
+
   /**
    * Returns a new Dataset containing union of rows in this Dataset and another Dataset.
    *
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 843ce63161220..6e7db5f009305 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -72,8 +72,18 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
             logical.Project(projectList, logical.Sort(order, true, child))) =>
           execution.TakeOrderedAndProjectExec(
             limit, order, projectList, planLater(child)) :: Nil
+        case logical.LimitRange(IntegerLiteral(start), IntegerLiteral(end),
+            logical.Project(projectList, logical.Sort(order, true, child))) =>
+          execution.TakeOrderedRangeAndProjectExec(start, end, order,
+            projectList, planLater(child)) :: Nil
+        case logical.LimitRange(IntegerLiteral(start), IntegerLiteral(end),
+            logical.Sort(order, true, child)) =>
+          execution.TakeOrderedRangeAndProjectExec(start, end, order,
+            child.output, planLater(child)) :: Nil
         case logical.Limit(IntegerLiteral(limit), child) =>
           execution.CollectLimitExec(limit, planLater(child)) :: Nil
+        case logical.LimitRange(IntegerLiteral(start), IntegerLiteral(end), child) =>
+          execution.CollectLimitRangeExec(start, end, planLater(child)) :: Nil
         case other => planLater(other) :: Nil
       }
       case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) =>
@@ -82,6 +92,14 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
           IntegerLiteral(limit),
           logical.Project(projectList, logical.Sort(order, true, child))) =>
         execution.TakeOrderedAndProjectExec(
           limit, order, projectList, planLater(child)) :: Nil
+      case logical.LimitRange(IntegerLiteral(start), IntegerLiteral(end),
+          logical.Project(projectList, logical.Sort(order, true, child))) =>
+        execution.TakeOrderedRangeAndProjectExec(start, end, order,
+          projectList, planLater(child)) :: Nil
+      case logical.LimitRange(IntegerLiteral(start), IntegerLiteral(end),
+          logical.Sort(order, true, child)) =>
+        execution.TakeOrderedRangeAndProjectExec(start, end, order,
+          child.output, planLater(child)) :: Nil
       case _ => Nil
     }
   }
@@ -417,6 +435,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         execution.LocalLimitExec(limit, planLater(child)) :: Nil
       case logical.GlobalLimit(IntegerLiteral(limit), child) =>
         execution.GlobalLimitExec(limit, planLater(child)) :: Nil
+      case logical.LimitRange(IntegerLiteral(start),
+          IntegerLiteral(limit), child) =>
+        execution.RangeLimitExec(start, limit, planLater(child)) :: Nil
       case logical.Union(unionChildren) =>
         execution.UnionExec(unionChildren.map(planLater)) :: Nil
      case g @ logical.Generate(generator, join, outer, _, _, child) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
index 757fe2185d302..9608781e3d07f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
@@ -46,6 +46,26 @@ case class CollectLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode
   }
 }
 
+/**
+ * Take the elements in the range [`start`, `end`) and collect them to a single partition.
+ *
+ * This operator will be used when a logical `LimitRange` operation is the final operator in a
+ * logical plan, which happens when the user is collecting results back to the driver.
+ */
+case class CollectLimitRangeExec(start: Int, end: Int, child: SparkPlan) extends UnaryExecNode {
+  override def output: Seq[Attribute] = child.output
+  override def outputPartitioning: Partitioning = SinglePartition
+  override def executeCollect(): Array[InternalRow] = child.executeTake(end).drop(start)
+  private val serializer: Serializer = new UnsafeRowSerializer(child.output.size)
+  protected override def doExecute(): RDD[InternalRow] = {
+    val locallyLimited = child.execute().mapPartitionsInternal(_.take(end))
+    val shuffled = new ShuffledRowRDD(
+      ShuffleExchange.prepareShuffleDependency(
+        locallyLimited, child.output, SinglePartition, serializer))
+    shuffled.mapPartitionsInternal(_.slice(start, end))
+  }
+}
+
 /**
  * Helper trait which defines methods that are shared by both
  * [[LocalLimitExec]] and [[GlobalLimitExec]].
@@ -111,6 +131,44 @@ case class GlobalLimitExec(limit: Int, child: SparkPlan) extends BaseLimitExec {
   override def outputOrdering: Seq[SortOrder] = child.outputOrdering
 }
 
+/**
+ * Take elements in the range [`start`, `limit`) of the child's single output partition.
+ */
+case class RangeLimitExec(start: Int, limit: Int, child: SparkPlan) extends BaseLimitExec {
+
+  override def requiredChildDistribution: List[Distribution] = AllTuples :: Nil
+
+  override def outputPartitioning: Partitioning = child.outputPartitioning
+
+  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
+
+  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
+    val stopEarly = ctx.freshName("stopEarly")
+    ctx.addMutableState("boolean", stopEarly, s"$stopEarly = false;")
+
+    ctx.addNewFunction("stopEarly", s"""
+      @Override
+      protected boolean stopEarly() {
+        return $stopEarly;
+      }
+    """)
+    val countTerm = ctx.freshName("count")
+    ctx.addMutableState("int", countTerm, s"$countTerm = 0;")
+    s"""
+       | $countTerm += 1;
+       | if ($countTerm > $start && $countTerm <= $limit) {
+       |   ${consume(ctx, input)}
+       | } else if ($countTerm > $limit) {
+       |   $stopEarly = true;
+       | }
+     """.stripMargin
+  }
+
+  protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter =>
+    iter.slice(start, limit)
+  }
+}
+
 /**
  * Take the first limit elements as defined by the sortOrder, and do projection if needed.
  * This is logically equivalent to having a Limit operator after a [[SortExec]] operator,
  * or having a [[ProjectExec]] operator between them.
@@ -173,3 +231,68 @@ case class TakeOrderedAndProjectExec(
     s"TakeOrderedAndProject(limit=$limit, orderBy=$orderByString, output=$outputString)"
   }
 }
+/**
+ * Take the elements in the range [start, end) as defined by the sortOrder, and do projection
+ * if needed. This is logically equivalent to having a LimitRange operator after a [[SortExec]]
+ * operator, or having a [[ProjectExec]] operator between them.
+ * This could have been named TopK, but Spark's top operator does the opposite in ordering
+ * so we name it TakeOrdered to avoid confusion.
+ */
+case class TakeOrderedRangeAndProjectExec(
+    start: Int,
+    end: Int,
+    sortOrder: Seq[SortOrder],
+    projectList: Seq[NamedExpression],
+    child: SparkPlan) extends UnaryExecNode {
+
+  override def output: Seq[Attribute] = {
+    projectList.map(_.toAttribute)
+  }
+
+  override def executeCollect(): Array[InternalRow] = {
+    val ord = new LazilyGeneratedOrdering(sortOrder, child.output)
+    val data = child.execute().map(_.copy()).takeOrdered(end)(ord).drop(start)
+    if (projectList != child.output) {
+      val proj = UnsafeProjection.create(projectList, child.output)
+      data.map(r => proj(r).copy())
+    } else {
+      data
+    }
+  }
+
+  private val serializer: Serializer = new UnsafeRowSerializer(child.output.size)
+
+  protected override def doExecute(): RDD[InternalRow] = {
+    val ord = new LazilyGeneratedOrdering(sortOrder, child.output)
+    val localTopK: RDD[InternalRow] = {
+      child.execute().map(_.copy()).mapPartitions { iter =>
+        org.apache.spark.util.collection.Utils.takeOrdered(iter, end)(ord)
+      }
+    }
+    val shuffled = new ShuffledRowRDD(
+      ShuffleExchange.prepareShuffleDependency(
+        localTopK, child.output, SinglePartition, serializer))
+    shuffled.mapPartitions { iter =>
+      val topK = org.apache.spark.util.collection.Utils.takeOrdered(iter.map(_.copy()), end)(ord)
+        .drop(start)
+      if (projectList != child.output) {
+        val proj = UnsafeProjection.create(projectList, child.output)
+        topK.map(r => proj(r))
+      } else {
+        topK
+      }
+    }
+  }
+
+  override def outputOrdering: Seq[SortOrder] = sortOrder
+
+  override def outputPartitioning: Partitioning = SinglePartition
+
+  override def simpleString: String = {
+    val orderByString = Utils.truncatedString(sortOrder, "[", ",", "]")
+    val outputString = Utils.truncatedString(output, "[", ",", "]")
+
+    s"TakeOrderedRangeAndProject" +
+      s"(start=$start, end=$end, orderBy=$orderByString, output=$outputString)"
+  }
+}
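
Usage note (not part of the patch itself): limitRange(start, end) keeps the rows with
indices [start, end) of its child, i.e. it skips the first `start` rows and returns at
most `end - start` rows. Below is a minimal sketch of how the new API could be exercised,
assuming this patch is applied; the session setup, the data, and the object name
LimitRangeExample are illustrative only:

    import org.apache.spark.sql.SparkSession

    object LimitRangeExample {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .master("local[*]")
          .appName("limitRange demo")
          .getOrCreate()
        import spark.implicits._

        // 100 rows with ids 0..99.
        val df = spark.range(0, 100).toDF("id")

        // Keep rows 10..19 (start inclusive, end exclusive), analogous to LIMIT 10 OFFSET 10.
        val page = df.orderBy($"id").limitRange(10, 20)
        page.show()

        spark.stop()
      }
    }

Per the SpecialLimits cases added above, a LimitRange over a Sort is planned as
TakeOrderedRangeAndProjectExec; without an ordering it becomes CollectLimitRangeExec when
results are collected to the driver, or RangeLimitExec elsewhere in the plan.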