diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLDedupITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLDedupeITSuite.scala similarity index 70% rename from integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLDedupITSuite.scala rename to integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLDedupeITSuite.scala index 06c90527d..2f59b6fba 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLDedupITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLDedupeITSuite.scala @@ -11,7 +11,7 @@ import org.apache.spark.sql.catalyst.expressions.{And, IsNotNull, IsNull, Or} import org.apache.spark.sql.catalyst.plans.logical.{Deduplicate, Filter, LogicalPlan, Project, Union} import org.apache.spark.sql.streaming.StreamTest -class FlintSparkPPLDedupITSuite +class FlintSparkPPLDedupeITSuite extends QueryTest with LogicalPlanTestUtils with FlintPPLSuite @@ -187,7 +187,7 @@ class FlintSparkPPLDedupITSuite assert(ex.getMessage.contains("Consecutive deduplication is not supported")) } - ignore("test dedupe 2 name") { + test("test dedupe 2 name") { val frame = sql(s""" | source = $testTable| dedup 2 name | fields name | """.stripMargin) @@ -200,7 +200,7 @@ class FlintSparkPPLDedupITSuite assert(results.sorted.sameElements(expectedResults.sorted)) } - ignore("test dedupe 2 name, category") { + test("test dedupe 2 name, category") { val frame = sql(s""" | source = $testTable| dedup 2 name, category | fields name, category | """.stripMargin) @@ -225,7 +225,7 @@ class FlintSparkPPLDedupITSuite assert(results.sorted.sameElements(expectedResults.sorted)) } - ignore("test dedupe 2 name KEEPEMPTY=true") { + test("test dedupe 2 name KEEPEMPTY=true") { val frame = sql(s""" | source = $testTable| dedup 2 name KEEPEMPTY=true | fields name, category | """.stripMargin) @@ -259,7 +259,7 @@ class FlintSparkPPLDedupITSuite .sameElements(expectedResults.sorted.map(_.getAs[String](0)))) } - ignore("test dedupe 2 name, category KEEPEMPTY=true") { + test("test dedupe 2 name, category KEEPEMPTY=true") { val frame = sql(s""" | source = $testTable| dedup 2 name, category KEEPEMPTY=true | fields name, category | """.stripMargin) @@ -307,4 +307,140 @@ class FlintSparkPPLDedupITSuite | """.stripMargin)) assert(ex.getMessage.contains("Consecutive deduplication is not supported")) } + + test("test dedupe 1 category, name - reorder field list won't impact output order") { + val expectedResults: Array[Row] = Array( + Row("A", "X"), + Row("A", "Y"), + Row("B", "Z"), + Row("C", "X"), + Row("D", "Z"), + Row("B", "Y")) + implicit val twoColsRowOrdering: Ordering[Row] = + Ordering.by[Row, (String, String)](row => (row.getAs(0), row.getAs(1))) + + val frame1 = sql(s""" + | source = $testTable | dedup 1 name, category + | """.stripMargin) + val results1: Array[Row] = frame1.drop("id").collect() + + val frame2 = sql(s""" + | source = $testTable | dedup 1 category, name + | """.stripMargin) + val results2: Array[Row] = frame2.drop("id").collect() + + assert(results1.sorted.sameElements(results2.sorted)) + assert(results1.sorted.sameElements(expectedResults.sorted)) + } + + test( + "test dedupe 1 category, name KEEPEMPTY=true - reorder field list won't impact output order") { + val expectedResults: Array[Row] = Array( + Row("A", "X"), + Row("A", "Y"), + Row("B", "Z"), + Row("C", "X"), + Row("D", "Z"), + Row("B", "Y"), + Row(null, "Y"), + Row("E", null), + Row(null, "X"), + Row("B", null), + Row(null, "Z"), + Row(null, null)) + implicit val nullableTwoColsRowOrdering: Ordering[Row] = + Ordering.by[Row, (String, String)](row => { + val value0 = row.getAs[String](0) + val value1 = row.getAs[String](1) + ( + if (value0 == null) String.valueOf(Int.MaxValue) else value0, + if (value1 == null) String.valueOf(Int.MaxValue) else value1) + }) + + val frame1 = sql(s""" + | source = $testTable | dedup 1 name, category KEEPEMPTY=true + | """.stripMargin) + val results1: Array[Row] = frame1.drop("id").collect() + + val frame2 = sql(s""" + | source = $testTable | dedup 1 category, name KEEPEMPTY=true + | """.stripMargin) + val results2: Array[Row] = frame2.drop("id").collect() + + assert(results1.sorted.sameElements(results2.sorted)) + assert(results1.sorted.sameElements(expectedResults.sorted)) + } + + test("test dedupe 2 category, name - reorder field list won't impact output order") { + val expectedResults: Array[Row] = Array( + Row("A", "X"), + Row("A", "X"), + Row("A", "Y"), + Row("A", "Y"), + Row("B", "Y"), + Row("B", "Z"), + Row("B", "Z"), + Row("C", "X"), + Row("C", "X"), + Row("D", "Z")) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](row => { + val value = row.getAs[String](0) + if (value == null) String.valueOf(Int.MaxValue) else value + }) + + val frame1 = sql(s""" + | source = $testTable | dedup 2 name, category + | """.stripMargin) + val results1: Array[Row] = frame1.drop("id").collect() + + val frame2 = sql(s""" + | source = $testTable | dedup 2 category, name + | """.stripMargin) + val results2: Array[Row] = frame2.drop("id").collect() + + assert(results1.sorted.sameElements(results2.sorted)) + assert(results1.sorted.sameElements(expectedResults.sorted)) + } + + test( + "test dedupe 2 category, name KEEPEMPTY=true - reorder field list won't impact output order") { + val expectedResults: Array[Row] = Array( + Row("A", "X"), + Row("A", "X"), + Row("A", "Y"), + Row("A", "Y"), + Row("B", "Y"), + Row("B", "Z"), + Row("B", "Z"), + Row("C", "X"), + Row("C", "X"), + Row("D", "Z"), + Row(null, "Y"), + Row("E", null), + Row(null, "X"), + Row("B", null), + Row(null, "Z"), + Row(null, null)) + implicit val nullableTwoColsRowOrdering: Ordering[Row] = + Ordering.by[Row, (String, String)](row => { + val value0 = row.getAs[String](0) + val value1 = row.getAs[String](1) + ( + if (value0 == null) String.valueOf(Int.MaxValue) else value0, + if (value1 == null) String.valueOf(Int.MaxValue) else value1) + }) + + val frame1 = sql(s""" + | source = $testTable | dedup 2 name, category KEEPEMPTY=true + | """.stripMargin) + val results1: Array[Row] = frame1.drop("id").collect() + + val frame2 = sql(s""" + | source = $testTable | dedup 2 category, name KEEPEMPTY=true + | """.stripMargin) + val results2: Array[Row] = frame2.drop("id").collect() + + assert(results1.sorted.sameElements(results2.sorted)) + assert(results1.sorted.sameElements(expectedResults.sorted)) + } } diff --git a/ppl-spark-integration/README.md b/ppl-spark-integration/README.md index f78f75e8c..979cb712d 100644 --- a/ppl-spark-integration/README.md +++ b/ppl-spark-integration/README.md @@ -298,8 +298,11 @@ Limitation: Overriding existing field is unsupported, following queries throw ex - `source = table | dedup 1 a,b | fields a,b,c` - `source = table | dedup 1 a keepempty=true | fields a,b,c` - `source = table | dedup 1 a,b keepempty=true | fields a,b,c` -- `source = table | dedup 1 a consecutive=true| fields a,b,c` (Unsupported) -- `source = table | dedup 2 a | fields a,b,c` (Unsupported) +- `source = table | dedup 2 a | fields a,b,c` +- `source = table | dedup 2 a,b | fields a,b,c` +- `source = table | dedup 2 a keepempty=true | fields a,b,c` +- `source = table | dedup 2 a,b keepempty=true | fields a,b,c` +- `source = table | dedup 1 a consecutive=true| fields a,b,c` (Consecutive deduplication is unsupported) **Rare** - `source=accounts | rare gender` diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index 8759ddcf7..46453c8a6 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -17,11 +17,9 @@ import org.apache.spark.sql.catalyst.expressions.SortDirection; import org.apache.spark.sql.catalyst.expressions.SortOrder; import org.apache.spark.sql.catalyst.plans.logical.Aggregate; -import org.apache.spark.sql.catalyst.plans.logical.Deduplicate; import org.apache.spark.sql.catalyst.plans.logical.DescribeRelation$; import org.apache.spark.sql.catalyst.plans.logical.Limit; import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; -import org.apache.spark.sql.catalyst.plans.logical.Union; import org.apache.spark.sql.execution.command.DescribeTableCommand; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.util.CaseInsensitiveStringMap; @@ -88,6 +86,10 @@ import static org.opensearch.sql.ppl.CatalystPlanContext.findRelation; import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; import static org.opensearch.sql.ppl.utils.DataTypeTransformer.translate; +import static org.opensearch.sql.ppl.utils.DedupeTransformer.retainMultipleDuplicateEvents; +import static org.opensearch.sql.ppl.utils.DedupeTransformer.retainMultipleDuplicateEventsAndKeepEmpty; +import static org.opensearch.sql.ppl.utils.DedupeTransformer.retainOneDuplicateEvent; +import static org.opensearch.sql.ppl.utils.DedupeTransformer.retainOneDuplicateEventAndKeepEmpty; import static org.opensearch.sql.ppl.utils.JoinSpecTransformer.join; import static org.opensearch.sql.ppl.utils.RelationUtils.resolveField; import static org.opensearch.sql.ppl.utils.WindowSpecTransformer.window; @@ -350,97 +352,29 @@ public LogicalPlan visitDedupe(Dedupe node, CatalystPlanContext context) { } visitFieldList(node.getFields(), context); // Columns to deduplicate - Seq dedupFields + Seq dedupeFields = context.retainAllNamedParseExpressions(e -> (org.apache.spark.sql.catalyst.expressions.Attribute) e); // Although we can also use the Window operator to translate this as allowedDuplication > 1 did, // adding Aggregate operator could achieve better performance. if (allowedDuplication == 1) { if (keepEmpty) { - // Union - // :- Deduplicate ['a, 'b] - // : +- Filter (isnotnull('a) AND isnotnull('b) - // : +- Project - // : +- UnresolvedRelation - // +- Filter (isnull('a) OR isnull('a)) - // +- Project - // +- UnresolvedRelation - - context.apply(p -> { - Expression isNullExpr = buildIsNullFilterExpression(node, context); - LogicalPlan right = new org.apache.spark.sql.catalyst.plans.logical.Filter(isNullExpr, p); - - Expression isNotNullExpr = buildIsNotNullFilterExpression(node, context); - LogicalPlan left = - new Deduplicate(dedupFields, - new org.apache.spark.sql.catalyst.plans.logical.Filter(isNotNullExpr, p)); - return new Union(seq(left, right), false, false); - }); - return context.getPlan(); + return retainOneDuplicateEventAndKeepEmpty(node, dedupeFields, expressionAnalyzer, context); } else { - // Deduplicate ['a, 'b] - // +- Filter (isnotnull('a) AND isnotnull('b)) - // +- Project - // +- UnresolvedRelation - - Expression isNotNullExpr = buildIsNotNullFilterExpression(node, context); - context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Filter(isNotNullExpr, p)); - // Todo DeduplicateWithinWatermark in streaming dataset? - return context.apply(p -> new Deduplicate(dedupFields, p)); + return retainOneDuplicateEvent(node, dedupeFields, expressionAnalyzer, context); } } else { - // TODO - throw new UnsupportedOperationException("Number of duplicate events greater than 1 is not supported"); + if (keepEmpty) { + return retainMultipleDuplicateEventsAndKeepEmpty(node, allowedDuplication, expressionAnalyzer, context); + } else { + return retainMultipleDuplicateEvents(node, allowedDuplication, expressionAnalyzer, context); + } } } - private Expression buildIsNotNullFilterExpression(Dedupe node, CatalystPlanContext context) { - visitFieldList(node.getFields(), context); - Seq isNotNullExpressions = - context.retainAllNamedParseExpressions( - org.apache.spark.sql.catalyst.expressions.IsNotNull$.MODULE$::apply); - - Expression isNotNullExpr; - if (isNotNullExpressions.size() == 1) { - isNotNullExpr = isNotNullExpressions.apply(0); - } else { - isNotNullExpr = isNotNullExpressions.reduce( - new scala.Function2() { - @Override - public Expression apply(Expression e1, Expression e2) { - return new org.apache.spark.sql.catalyst.expressions.And(e1, e2); - } - } - ); - } - return isNotNullExpr; - } - - private Expression buildIsNullFilterExpression(Dedupe node, CatalystPlanContext context) { - visitFieldList(node.getFields(), context); - Seq isNullExpressions = - context.retainAllNamedParseExpressions( - org.apache.spark.sql.catalyst.expressions.IsNull$.MODULE$::apply); - - Expression isNullExpr; - if (isNullExpressions.size() == 1) { - isNullExpr = isNullExpressions.apply(0); - } else { - isNullExpr = isNullExpressions.reduce( - new scala.Function2() { - @Override - public Expression apply(Expression e1, Expression e2) { - return new org.apache.spark.sql.catalyst.expressions.Or(e1, e2); - } - } - ); - } - return isNullExpr; - } - /** * Expression Analyzer. */ - private static class ExpressionAnalyzer extends AbstractNodeVisitor { + public static class ExpressionAnalyzer extends AbstractNodeVisitor { public Expression analyze(UnresolvedExpression unresolved, CatalystPlanContext context) { return unresolved.accept(this, context); diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DedupeTransformer.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DedupeTransformer.java new file mode 100644 index 000000000..0866ca7e9 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DedupeTransformer.java @@ -0,0 +1,199 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ppl.utils; + +import org.apache.spark.sql.catalyst.expressions.Attribute; +import org.apache.spark.sql.catalyst.expressions.Expression; +import org.apache.spark.sql.catalyst.expressions.LessThanOrEqual; +import org.apache.spark.sql.catalyst.expressions.NamedExpression; +import org.apache.spark.sql.catalyst.expressions.SortOrder; +import org.apache.spark.sql.catalyst.plans.logical.DataFrameDropColumns; +import org.apache.spark.sql.catalyst.plans.logical.Deduplicate; +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; +import org.apache.spark.sql.catalyst.plans.logical.Union; +import org.apache.spark.sql.types.DataTypes; +import org.opensearch.sql.ast.tree.Dedupe; +import org.opensearch.sql.ppl.CatalystPlanContext; +import org.opensearch.sql.ppl.CatalystQueryPlanVisitor.ExpressionAnalyzer; +import scala.collection.Seq; + +import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; + +public interface DedupeTransformer { + + /** + * | dedup a, b keepempty=true + * Union + * :- Deduplicate ['a, 'b] + * : +- Filter (isnotnull('a) AND isnotnull('b)) + * : +- ... + * : +- UnresolvedRelation + * +- Filter (isnull('a) OR isnull('a)) + * +- ... + * +- UnresolvedRelation + */ + static LogicalPlan retainOneDuplicateEventAndKeepEmpty( + Dedupe node, + Seq dedupeFields, + ExpressionAnalyzer expressionAnalyzer, + CatalystPlanContext context) { + context.apply(p -> { + Expression isNullExpr = buildIsNullFilterExpression(node, expressionAnalyzer, context); + LogicalPlan right = new org.apache.spark.sql.catalyst.plans.logical.Filter(isNullExpr, p); + + Expression isNotNullExpr = buildIsNotNullFilterExpression(node, expressionAnalyzer, context); + LogicalPlan left = + new Deduplicate(dedupeFields, + new org.apache.spark.sql.catalyst.plans.logical.Filter(isNotNullExpr, p)); + return new Union(seq(left, right), false, false); + }); + return context.getPlan(); + } + + /** + * | dedup a, b keepempty=false + * Deduplicate ['a, 'b] + * +- Filter (isnotnull('a) AND isnotnull('b)) + * +- ... + * +- UnresolvedRelation + */ + static LogicalPlan retainOneDuplicateEvent( + Dedupe node, + Seq dedupeFields, + ExpressionAnalyzer expressionAnalyzer, + CatalystPlanContext context) { + Expression isNotNullExpr = buildIsNotNullFilterExpression(node, expressionAnalyzer, context); + context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Filter(isNotNullExpr, p)); + // Todo DeduplicateWithinWatermark in streaming dataset? + return context.apply(p -> new Deduplicate(dedupeFields, p)); + } + + /** + * | dedup 2 a, b keepempty=true + * Union + * :- DataFrameDropColumns('_row_number_) + * : +- Filter ('_row_number_ <= 2) + * : +- Window [row_number() windowspecdefinition('a, 'b, 'a ASC NULLS FIRST, 'b ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS _row_number_], ['a, 'b], ['a ASC NULLS FIRST, 'b ASC NULLS FIRST] + * : +- Filter (isnotnull('a) AND isnotnull('b)) + * : +- ... + * : +- UnresolvedRelation + * +- Filter (isnull('a) OR isnull('b)) + * +- ... + * +- UnresolvedRelation + */ + static LogicalPlan retainMultipleDuplicateEventsAndKeepEmpty( + Dedupe node, + Integer allowedDuplication, + ExpressionAnalyzer expressionAnalyzer, + CatalystPlanContext context) { + context.apply(p -> { + // Build isnull Filter for right + Expression isNullExpr = buildIsNullFilterExpression(node, expressionAnalyzer, context); + LogicalPlan right = new org.apache.spark.sql.catalyst.plans.logical.Filter(isNullExpr, p); + + // Build isnotnull Filter + Expression isNotNullExpr = buildIsNotNullFilterExpression(node, expressionAnalyzer, context); + LogicalPlan isNotNullFilter = new org.apache.spark.sql.catalyst.plans.logical.Filter(isNotNullExpr, p); + + // Build Window + node.getFields().forEach(field -> expressionAnalyzer.analyze(field, context)); + Seq partitionSpec = context.retainAllNamedParseExpressions(exp -> exp); + node.getFields().forEach(field -> expressionAnalyzer.analyze(field, context)); + Seq orderSpec = context.retainAllNamedParseExpressions(exp -> SortUtils.sortOrder(exp, true)); + NamedExpression rowNumber = WindowSpecTransformer.buildRowNumber(partitionSpec, orderSpec); + LogicalPlan window = new org.apache.spark.sql.catalyst.plans.logical.Window( + seq(rowNumber), + partitionSpec, + orderSpec, + isNotNullFilter); + + // Build deduplication Filter ('_row_number_ <= n) + Expression filterExpr = new LessThanOrEqual( + rowNumber.toAttribute(), + new org.apache.spark.sql.catalyst.expressions.Literal(allowedDuplication, DataTypes.IntegerType)); + LogicalPlan deduplicationFilter = new org.apache.spark.sql.catalyst.plans.logical.Filter(filterExpr, window); + + // Build DataFrameDropColumns('_row_number_) for left + LogicalPlan left = new DataFrameDropColumns(seq(rowNumber.toAttribute()), deduplicationFilter); + + // Build Union + return new Union(seq(left, right), false, false); + }); + return context.getPlan(); + } + + /** + * | dedup 2 a, b keepempty=false + * DataFrameDropColumns('_row_number_) + * +- Filter ('_row_number_ <= n) + * +- Window [row_number() windowspecdefinition('a, 'b, 'a ASC NULLS FIRST, 'b ASC NULLS FIRST, specifiedwindowoundedpreceding$(), currentrow$())) AS _row_number_], ['a, 'b], ['a ASC NULLS FIRST, 'b ASC NULLS FIRST] + * +- Filter (isnotnull('a) AND isnotnull('b)) + * +- ... + * +- UnresolvedRelation + */ + static LogicalPlan retainMultipleDuplicateEvents( + Dedupe node, + Integer allowedDuplication, + ExpressionAnalyzer expressionAnalyzer, + CatalystPlanContext context) { + // Build isnotnull Filter + Expression isNotNullExpr = buildIsNotNullFilterExpression(node, expressionAnalyzer, context); + context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Filter(isNotNullExpr, p)); + + // Build Window + node.getFields().forEach(field -> expressionAnalyzer.analyze(field, context)); + Seq partitionSpec = context.retainAllNamedParseExpressions(exp -> exp); + node.getFields().forEach(field -> expressionAnalyzer.analyze(field, context)); + Seq orderSpec = context.retainAllNamedParseExpressions(exp -> SortUtils.sortOrder(exp, true)); + NamedExpression rowNumber = WindowSpecTransformer.buildRowNumber(partitionSpec, orderSpec); + context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Window( + seq(rowNumber), + partitionSpec, + orderSpec, p)); + + // Build deduplication Filter ('_row_number_ <= n) + Expression filterExpr = new LessThanOrEqual( + rowNumber.toAttribute(), + new org.apache.spark.sql.catalyst.expressions.Literal(allowedDuplication, DataTypes.IntegerType)); + context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Filter(filterExpr, p)); + + return context.apply(p -> new DataFrameDropColumns(seq(rowNumber.toAttribute()), p)); + } + + private static Expression buildIsNotNullFilterExpression(Dedupe node, ExpressionAnalyzer expressionAnalyzer, CatalystPlanContext context) { + node.getFields().forEach(field -> expressionAnalyzer.analyze(field, context)); + Seq isNotNullExpressions = + context.retainAllNamedParseExpressions( + org.apache.spark.sql.catalyst.expressions.IsNotNull$.MODULE$::apply); + + Expression isNotNullExpr; + if (isNotNullExpressions.size() == 1) { + isNotNullExpr = isNotNullExpressions.apply(0); + } else { + isNotNullExpr = isNotNullExpressions.reduce( + (e1, e2) -> new org.apache.spark.sql.catalyst.expressions.And(e1, e2) + ); + } + return isNotNullExpr; + } + + private static Expression buildIsNullFilterExpression(Dedupe node, ExpressionAnalyzer expressionAnalyzer, CatalystPlanContext context) { + node.getFields().forEach(field -> expressionAnalyzer.analyze(field, context)); + Seq isNullExpressions = + context.retainAllNamedParseExpressions( + org.apache.spark.sql.catalyst.expressions.IsNull$.MODULE$::apply); + + Expression isNullExpr; + if (isNullExpressions.size() == 1) { + isNullExpr = isNullExpressions.apply(0); + } else { + isNullExpr = isNullExpressions.reduce( + (e1, e2) -> new org.apache.spark.sql.catalyst.expressions.Or(e1, e2) + ); + } + return isNullExpr; + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/WindowSpecTransformer.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/WindowSpecTransformer.java index c215caec5..0e6ba2a1d 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/WindowSpecTransformer.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/WindowSpecTransformer.java @@ -5,24 +5,37 @@ package org.opensearch.sql.ppl.utils; +import org.apache.spark.sql.catalyst.expressions.CurrentRow$; import org.apache.spark.sql.catalyst.expressions.Divide; import org.apache.spark.sql.catalyst.expressions.Expression; import org.apache.spark.sql.catalyst.expressions.Floor; -import org.apache.spark.sql.catalyst.expressions.Literal; import org.apache.spark.sql.catalyst.expressions.Multiply; +import org.apache.spark.sql.catalyst.expressions.NamedExpression; +import org.apache.spark.sql.catalyst.expressions.RowFrame$; +import org.apache.spark.sql.catalyst.expressions.RowNumber; +import org.apache.spark.sql.catalyst.expressions.SortOrder; +import org.apache.spark.sql.catalyst.expressions.SpecifiedWindowFrame; import org.apache.spark.sql.catalyst.expressions.TimeWindow; -import org.apache.spark.sql.types.DateType$; -import org.apache.spark.sql.types.StringType$; +import org.apache.spark.sql.catalyst.expressions.UnboundedPreceding$; +import org.apache.spark.sql.catalyst.expressions.WindowExpression; +import org.apache.spark.sql.catalyst.expressions.WindowSpecDefinition; import org.opensearch.sql.ast.expression.SpanUnit; +import scala.Option; +import scala.collection.Seq; + +import java.util.ArrayList; import static java.lang.String.format; import static org.opensearch.sql.ast.expression.DataType.STRING; import static org.opensearch.sql.ast.expression.SpanUnit.NONE; import static org.opensearch.sql.ast.expression.SpanUnit.UNKNOWN; +import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; import static org.opensearch.sql.ppl.utils.DataTypeTransformer.translate; public interface WindowSpecTransformer { + String ROW_NUMBER_COLUMN_NAME = "_row_number_"; + /** * create a static window buckets based on the given value * @@ -50,4 +63,20 @@ static org.apache.spark.sql.catalyst.expressions.Literal timeLiteral( Expression return new org.apache.spark.sql.catalyst.expressions.Literal( translate(format, STRING), translate(STRING)); } + + static NamedExpression buildRowNumber(Seq partitionSpec, Seq orderSpec) { + WindowExpression rowNumber = new WindowExpression( + new RowNumber(), + new WindowSpecDefinition( + partitionSpec, + orderSpec, + new SpecifiedWindowFrame(RowFrame$.MODULE$, UnboundedPreceding$.MODULE$, CurrentRow$.MODULE$))); + return org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply( + rowNumber, + ROW_NUMBER_COLUMN_NAME, + NamedExpression.newExprId(), + seq(new ArrayList()), + Option.empty(), + seq(new ArrayList())); + } } diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanDedupTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanDedupeTranslatorTestSuite.scala similarity index 60% rename from ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanDedupTranslatorTestSuite.scala rename to ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanDedupeTranslatorTestSuite.scala index 34cfcbd90..23222c2e3 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanDedupTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanDedupeTranslatorTestSuite.scala @@ -7,15 +7,16 @@ package org.opensearch.flint.spark.ppl import org.opensearch.flint.spark.ppl.PlaneUtils.plan import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.opensearch.sql.ppl.utils.{SortUtils, WindowSpecTransformer} import org.scalatest.matchers.should.Matchers import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions.{And, IsNotNull, IsNull, NamedExpression, Or} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation} +import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, CurrentRow, IsNotNull, IsNull, LessThanOrEqual, Literal, NamedExpression, Or, RowFrame, RowNumber, SortOrder, SpecifiedWindowFrame, UnboundedPreceding, WindowExpression, WindowSpecDefinition} import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical.{Deduplicate, Filter, Project, Union} +import org.apache.spark.sql.catalyst.plans.logical.{DataFrameDropColumns, Deduplicate, Filter, Project, Union, Window} -class PPLLogicalPlanDedupTranslatorTestSuite +class PPLLogicalPlanDedupeTranslatorTestSuite extends SparkFunSuite with PlanTest with LogicalPlanTestUtils @@ -229,40 +230,164 @@ class PPLLogicalPlanDedupTranslatorTestSuite assert(ex.getMessage === "Number of duplicate events must be greater than 0") } - // Todo - ignore("test dedup 2 a") { + test("test dedup 2 a") { val context = new CatalystPlanContext val logPlan = planTransformer.visit( plan(pplParser, "source=table | dedup 2 a | fields a", false), context) + val projectList: Seq[NamedExpression] = Seq(UnresolvedAttribute("a")) + val isNotNullFilter = + Filter(IsNotNull(UnresolvedAttribute("a")), UnresolvedRelation(Seq("table"))) + val windowExpression = WindowExpression( + RowNumber(), + WindowSpecDefinition( + UnresolvedAttribute("a") :: Nil, + SortOrder(UnresolvedAttribute("a"), Ascending) :: Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, CurrentRow))) + val rowNumberAlias = Alias(windowExpression, WindowSpecTransformer.ROW_NUMBER_COLUMN_NAME)() + val partitionSpec = projectList + val orderSpec = Seq(SortUtils.sortOrder(UnresolvedAttribute("a"), true)) + val window = Window(Seq(rowNumberAlias), partitionSpec, orderSpec, isNotNullFilter) + val deduplicateFilter = + Filter( + LessThanOrEqual( + UnresolvedAttribute(WindowSpecTransformer.ROW_NUMBER_COLUMN_NAME), + Literal(2)), + window) + val dropColumns = + DataFrameDropColumns( + Seq(UnresolvedAttribute(WindowSpecTransformer.ROW_NUMBER_COLUMN_NAME)), + deduplicateFilter) + val expectedPlan = Project(projectList, dropColumns) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) } - // Todo - ignore("test dedup 2 a, b, c") { + test("test dedup 2 a, b, c") { val context = new CatalystPlanContext val logPlan = planTransformer.visit( plan(pplParser, "source=table | dedup 2 a, b, c | fields a, b, c", false), context) + val projectList: Seq[NamedExpression] = + Seq(UnresolvedAttribute("a"), UnresolvedAttribute("b"), UnresolvedAttribute("c")) + val isNotNullFilter = Filter( + And( + And(IsNotNull(UnresolvedAttribute("a")), IsNotNull(UnresolvedAttribute("b"))), + IsNotNull(UnresolvedAttribute("c"))), + UnresolvedRelation(Seq("table"))) + val windowExpression = WindowExpression( + RowNumber(), + WindowSpecDefinition( + UnresolvedAttribute("a") :: UnresolvedAttribute("b") :: UnresolvedAttribute("c") :: Nil, + SortOrder(UnresolvedAttribute("a"), Ascending) :: SortOrder( + UnresolvedAttribute("b"), + Ascending) :: SortOrder(UnresolvedAttribute("c"), Ascending) :: Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, CurrentRow))) + val rowNumberAlias = Alias(windowExpression, WindowSpecTransformer.ROW_NUMBER_COLUMN_NAME)() + val partitionSpec = projectList + val orderSpec = Seq( + SortUtils.sortOrder(UnresolvedAttribute("a"), true), + SortUtils.sortOrder(UnresolvedAttribute("b"), true), + SortUtils.sortOrder(UnresolvedAttribute("c"), true)) + val window = Window(Seq(rowNumberAlias), partitionSpec, orderSpec, isNotNullFilter) + val deduplicateFilter = + Filter( + LessThanOrEqual( + UnresolvedAttribute(WindowSpecTransformer.ROW_NUMBER_COLUMN_NAME), + Literal(2)), + window) + val dropColumns = + DataFrameDropColumns( + Seq(UnresolvedAttribute(WindowSpecTransformer.ROW_NUMBER_COLUMN_NAME)), + deduplicateFilter) + val expectedPlan = Project(projectList, dropColumns) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) } - // Todo - ignore("test dedup 2 a keepempty=true") { + test("test dedup 2 a keepempty=true") { val context = new CatalystPlanContext val logPlan = planTransformer.visit( plan(pplParser, "source=table | dedup 2 a keepempty=true | fields a", false), context) + val projectList: Seq[NamedExpression] = Seq(UnresolvedAttribute("a")) + val isNotNullFilter = + Filter(IsNotNull(UnresolvedAttribute("a")), UnresolvedRelation(Seq("table"))) + val windowExpression = WindowExpression( + RowNumber(), + WindowSpecDefinition( + UnresolvedAttribute("a") :: Nil, + SortOrder(UnresolvedAttribute("a"), Ascending) :: Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, CurrentRow))) + val rowNumberAlias = Alias(windowExpression, WindowSpecTransformer.ROW_NUMBER_COLUMN_NAME)() + val partitionSpec = projectList + val orderSpec = Seq(SortUtils.sortOrder(UnresolvedAttribute("a"), true)) + val window = Window(Seq(rowNumberAlias), partitionSpec, orderSpec, isNotNullFilter) + val deduplicateFilter = + Filter( + LessThanOrEqual( + UnresolvedAttribute(WindowSpecTransformer.ROW_NUMBER_COLUMN_NAME), + Literal(2)), + window) + val dropColumns = + DataFrameDropColumns( + Seq(UnresolvedAttribute(WindowSpecTransformer.ROW_NUMBER_COLUMN_NAME)), + deduplicateFilter) + + val isNullFilter = Filter(IsNull(UnresolvedAttribute("a")), UnresolvedRelation(Seq("table"))) + val union = Union(dropColumns, isNullFilter) + val expectedPlan = Project(projectList, union) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) } - // Todo - ignore("test dedup 2 a, b, c keepempty=true") { + test("test dedup 2 a, b, c keepempty=true") { val context = new CatalystPlanContext val logPlan = planTransformer.visit( plan(pplParser, "source=table | dedup 2 a, b, c keepempty=true | fields a, b, c", false), context) + val projectList: Seq[NamedExpression] = + Seq(UnresolvedAttribute("a"), UnresolvedAttribute("b"), UnresolvedAttribute("c")) + val isNotNullFilter = Filter( + And( + And(IsNotNull(UnresolvedAttribute("a")), IsNotNull(UnresolvedAttribute("b"))), + IsNotNull(UnresolvedAttribute("c"))), + UnresolvedRelation(Seq("table"))) + val windowExpression = WindowExpression( + RowNumber(), + WindowSpecDefinition( + UnresolvedAttribute("a") :: UnresolvedAttribute("b") :: UnresolvedAttribute("c") :: Nil, + SortOrder(UnresolvedAttribute("a"), Ascending) :: SortOrder( + UnresolvedAttribute("b"), + Ascending) :: SortOrder(UnresolvedAttribute("c"), Ascending) :: Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, CurrentRow))) + val rowNumberAlias = Alias(windowExpression, WindowSpecTransformer.ROW_NUMBER_COLUMN_NAME)() + val partitionSpec = projectList + val orderSpec = Seq( + SortUtils.sortOrder(UnresolvedAttribute("a"), true), + SortUtils.sortOrder(UnresolvedAttribute("b"), true), + SortUtils.sortOrder(UnresolvedAttribute("c"), true)) + val window = Window(Seq(rowNumberAlias), partitionSpec, orderSpec, isNotNullFilter) + val deduplicateFilter = + Filter( + LessThanOrEqual( + UnresolvedAttribute(WindowSpecTransformer.ROW_NUMBER_COLUMN_NAME), + Literal(2)), + window) + val dropColumns = + DataFrameDropColumns( + Seq(UnresolvedAttribute(WindowSpecTransformer.ROW_NUMBER_COLUMN_NAME)), + deduplicateFilter) + + val isNullFilter = Filter( + Or( + Or(IsNull(UnresolvedAttribute("a")), IsNull(UnresolvedAttribute("b"))), + IsNull(UnresolvedAttribute("c"))), + UnresolvedRelation(Seq("table"))) + val union = Union(dropColumns, isNullFilter) + val expectedPlan = Project(projectList, union) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) } test("test dedup 2 a consecutive=true") {