diff --git a/.github/workflows/test-and-build-workflow.yml b/.github/workflows/test-and-build-workflow.yml
index 216f8292d..17cbb923c 100644
--- a/.github/workflows/test-and-build-workflow.yml
+++ b/.github/workflows/test-and-build-workflow.yml
@@ -37,7 +37,7 @@ jobs:
       - name: Upload test report
         if: always() # Ensures the artifact is saved even if tests fail
-        uses: actions/upload-artifact@v3
+        uses: actions/upload-artifact@v4
         with:
           name: test-reports
           path: target/test-reports # Adjust this path if necessary
\ No newline at end of file
diff --git a/build.sbt b/build.sbt
index 365b88aa3..154f3370c 100644
--- a/build.sbt
+++ b/build.sbt
@@ -68,8 +68,7 @@ val packagesToShade = Seq(
   "org.glassfish.json.**",
   "org.joda.time.**",
   "org.reactivestreams.**",
-  "org.yaml.**",
-  "software.amazon.**"
+  "org.yaml.**"
 )
 
 ThisBuild / assemblyShadeRules := Seq(
diff --git a/docs/index.md b/docs/index.md
index abc801bde..684ba7da6 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -394,6 +394,7 @@ User can provide the following options in `WITH` clause of create statement:
 + `watermark_delay`: a string as time expression for how late data can come and still be processed, e.g. 1 minute, 10 seconds. This is required by auto and incremental refresh on materialized view if it has aggregation in the query.
 + `output_mode`: a mode string that describes how data will be written to streaming sink. If unspecified, default append mode will be applied.
 + `index_settings`: a JSON string as index settings for OpenSearch index that will be created. Please follow the format in OpenSearch documentation. If unspecified, default OpenSearch index settings will be applied.
++ `id_expression`: an expression string that generates an ID column to guarantee idempotency when an index refresh job restarts or retries during an index refresh. If an empty string is provided, no ID column will be generated.
 + `extra_options`: a JSON string as extra options that can be passed to Spark streaming source and sink API directly. Use qualified source table name (because there could be multiple) and "sink", e.g. '{"sink": "{key: val}", "table1": {key: val}}'
 
 Note that the index option name is case-sensitive.
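+
+As a sketch of the semantics (assuming a source table `alb_logs` with `startTime` and `status` columns; illustrative only, the option accepts any Spark SQL expression string), a hash-based ID expression maps rows with the same key to the same document ID, so a restarted or retried refresh upserts instead of appending duplicates:
+
+```sql
+-- identical (startTime, status) pairs produce identical IDs
+SELECT sha1(concat_ws('\0', startTime, status)) AS id, startTime, status
+FROM alb_logs;
+```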
 Here is an example:
@@ -406,6 +407,7 @@ WITH (
   watermark_delay = '1 Second',
   output_mode = 'complete',
   index_settings = '{"number_of_shards": 2, "number_of_replicas": 3}',
+  id_expression = "sha1(concat_ws('\0',startTime,status))",
   extra_options = '{"spark_catalog.default.alb_logs": {"maxFilesPerTrigger": "1"}}'
 )
 ```
diff --git a/docs/ppl-lang/PPL-Example-Commands.md b/docs/ppl-lang/PPL-Example-Commands.md
index 13f382afb..eb5174baf 100644
--- a/docs/ppl-lang/PPL-Example-Commands.md
+++ b/docs/ppl-lang/PPL-Example-Commands.md
@@ -1,5 +1,10 @@
 ## Example PPL Queries
 
+#### **AppendCol**
+[See additional command details](ppl-appendcol-command.md)
+- `source=employees | stats avg(age) as avg_age1 by dept | fields dept, avg_age1 | APPENDCOL [ stats avg(age) as avg_age2 by dept | fields avg_age2 ];` (To display multiple table statistics side by side)
+- `source=employees | FIELDS name, dept, age | APPENDCOL OVERRIDE=true [ stats avg(age) as age ];` (When the override option is enabled, fields from the sub-query take precedence over fields in the main query in cases of field name conflicts)
+
 #### **Comment**
 [See additional command details](ppl-comment.md)
 - `source=accounts | top gender // finds most common gender of all the accounts` (line comment)
@@ -274,7 +279,8 @@ source = table | where ispresent(a) |
 - `source=accounts | parse email '.+@(?<host>.+)' | stats count() by host`
 - `source=accounts | parse email '.+@(?<host>.+)' | eval eval_result=1 | fields host, eval_result`
 - `source=accounts | parse email '.+@(?<host>.+)' | where age > 45 | sort - age | fields age, email, host`
-- `source=accounts | parse address '(?<streetNumber>\d+) (?<street>.+)' | where streetNumber > 500 | sort num(streetNumber) | fields streetNumber, street`
+- `source=accounts | parse address '(?<streetNumber>\d+) (?<street>.+)' | eval streetNumberInt = cast(streetNumber as integer) | where streetNumberInt > 500 | sort streetNumberInt | fields streetNumber, street`
+- Limitation: [see limitations](ppl-parse-command.md#limitations)
 
 #### **view**
 [See additional command details](ppl-view-command.md)
diff --git a/docs/ppl-lang/README.md b/docs/ppl-lang/README.md
index 852f924b8..85534d58e 100644
--- a/docs/ppl-lang/README.md
+++ b/docs/ppl-lang/README.md
@@ -76,6 +76,8 @@ For additional examples see the next [documentation](PPL-Example-Commands.md).
 
     - [`expand commands`](ppl-expand-command.md)
 
+    - [`appendcol command`](ppl-appendcol-command.md)
+
 * **Functions**
 
     - [`Expressions`](functions/ppl-expressions.md)
diff --git a/docs/ppl-lang/ppl-appendcol-command.md b/docs/ppl-lang/ppl-appendcol-command.md
new file mode 100644
index 000000000..668bec8ea
--- /dev/null
+++ b/docs/ppl-lang/ppl-appendcol-command.md
@@ -0,0 +1,120 @@
+## PPL `appendcol` command
+
+### Description
+The `appendcol` command appends the result of a sub-search alongside the input search results (the main search).
+
+### Syntax - APPENDCOL
+`APPENDCOL <override=?> [sub-search]...`
+
+* override=?: optional boolean field to specify whether columns from the main search result should be overwritten in the case of a column name conflict.
+* sub-search: executes PPL commands as a secondary search. The sub-search uses the same data specified in the source clause of the main search as its input.
+
+
+#### Example 1: To append the result of `stats avg(age) as AVG_AGE` into the existing search result
+
+This example appends the result of the sub-search `stats avg(age) as AVG_AGE` alongside the main search.
+
+PPL query:
+
+    os> source=employees | FIELDS name, dept, age | APPENDCOL [ stats avg(age) as AVG_AGE ];
+    fetched rows / total rows = 9/9
+    +------+-------------+-----+------------------+
+    | name | dept        | age | AVG_AGE          |
+    +------+-------------+-----+------------------+
+    | Lisa | Sales       | 35  | 31.2222222222222 |
+    | Fred | Engineering | 28  | NULL             |
+    | Paul | Engineering | 23  | NULL             |
+    | Evan | Sales       | 38  | NULL             |
+    | Chloe| Engineering | 25  | NULL             |
+    | Tom  | Engineering | 33  | NULL             |
+    | Alex | Sales       | 33  | NULL             |
+    | Jane | Marketing   | 28  | NULL             |
+    | Jeff | Marketing   | 38  | NULL             |
+    +------+-------------+-----+------------------+
+
+
+#### Example 2: To compare multiple stats commands side by side with APPENDCOL
+
+This example demonstrates a common use case: performing multiple statistical calculations and displaying the results side by side in a horizontal layout.
+
+PPL query:
+
+    os> source=employees | stats avg(age) as avg_age1 by dept | fields dept, avg_age1 | APPENDCOL [ stats avg(age) as avg_age2 by dept | fields avg_age2 ];
+    fetched rows / total rows = 3/3
+    +-------------+----------+----------+
+    | dept        | avg_age1 | avg_age2 |
+    +-------------+----------+----------+
+    | Engineering | 27.25    | 27.25    |
+    | Sales       | 35.33    | 35.33    |
+    | Marketing   | 33.00    | 33.00    |
+    +-------------+----------+----------+
+
+
+#### Example 3: Append multiple sub-search results
+
+This example demonstrates how multiple APPENDCOL commands can be chained to provide one comprehensive view for the user.
+
+PPL query:
+
+    os> source=employees | FIELDS name, dept, age | APPENDCOL [ stats avg(age) as AVG_AGE ] | APPENDCOL [ stats max(age) as MAX_AGE ];
+    fetched rows / total rows = 9/9
+    +------+-------------+-----+------------------+---------+
+    | name | dept        | age | AVG_AGE          | MAX_AGE |
+    +------+-------------+-----+------------------+---------+
+    | Lisa | Sales       | 35  | 31.22222222222222| 38      |
+    | Fred | Engineering | 28  | NULL             | NULL    |
+    | Paul | Engineering | 23  | NULL             | NULL    |
+    | Evan | Sales       | 38  | NULL             | NULL    |
+    | Chloe| Engineering | 25  | NULL             | NULL    |
+    | Tom  | Engineering | 33  | NULL             | NULL    |
+    | Alex | Sales       | 33  | NULL             | NULL    |
+    | Jane | Marketing   | 28  | NULL             | NULL    |
+    | Jeff | Marketing   | 38  | NULL             | NULL    |
+    +------+-------------+-----+------------------+---------+
+
+#### Example 4: Override main-search in the case of column name conflict
+
+This example demonstrates the usage of the `OVERRIDE` option to overwrite the `age` column from the main search,
+when the option is set to true and a column with the same name `age` is present in the sub-search.
+
+PPL query:
+
+    os> source=employees | FIELDS name, dept, age | APPENDCOL OVERRIDE=true [ stats avg(age) as age ];
+    fetched rows / total rows = 9/9
+    +------+-------------+------------------+
+    | name | dept        | age              |
+    +------+-------------+------------------+
+    | Lisa | Sales       | 31.22222222222222|
+    | Fred | Engineering | NULL             |
+    | Paul | Engineering | NULL             |
+    | Evan | Sales       | NULL             |
+    | Chloe| Engineering | NULL             |
+    | Tom  | Engineering | NULL             |
+    | Alex | Sales       | NULL             |
+    | Jane | Marketing   | NULL             |
+    | Jeff | Marketing   | NULL             |
+    +------+-------------+------------------+
+
+#### Example 5: AppendCol command with duplicated columns
+
+This example demonstrates what happens when conflicting columns exist and `override` is set to false or absent.
+In this case, an average aggregation over column `age`, grouped by `dept`, is performed on the main query and the sub-query respectively.
+As a result, `dept` and `avg_age1` are returned by the main query, and `avg_age2` and `dept` by the sub-query.
+Since `override` is absent, duplicated columns are not dropped, hence all four columns are displayed in the final result.
+
+PPL query:
+
+    os> source=employees | stats avg(age) as avg_age1 by dept | APPENDCOL [ stats avg(age) as avg_age2 by dept ];
+    fetched rows / total rows = 3/3
+    +------------+--------------+------------+--------------+
+    | avg_age1   | dept         | avg_age2   | dept         |
+    +------------+--------------+------------+--------------+
+    | 35.33      | Sales        | 35.33      | Sales        |
+    | 27.25      | Engineering  | 27.25      | Engineering  |
+    | 33.00      | Marketing    | 33.00      | Marketing    |
+    +------------+--------------+------------+--------------+
+
+
+### Limitation:
+When `override` is set to true, only `FIELDS` and `STATS` commands are allowed as the final clause in a sub-search.
+Otherwise, an IllegalStateException with the message `Not Supported operation: APPENDCOL should specify the output fields` will be thrown.
diff --git a/docs/ppl-lang/ppl-parse-command.md b/docs/ppl-lang/ppl-parse-command.md
index 0e000756e..dbf92ad62 100644
--- a/docs/ppl-lang/ppl-parse-command.md
+++ b/docs/ppl-lang/ppl-parse-command.md
@@ -58,7 +58,7 @@ The example shows how to sort street numbers that are higher than 500 in ``address`` field.
 
 PPL query:
 
-    os> source=accounts | parse address '(?<streetNumber>\d+) (?<street>.+)' | where cast(streetNumber as int) > 500 | sort num(streetNumber) | fields streetNumber, street ;
+    os> source=accounts | parse address '(?<streetNumber>\d+) (?<street>.+)' | eval streetNumberInt = cast(streetNumber as integer) | where streetNumberInt > 500 | sort streetNumberInt | fields streetNumber, street ;
     fetched rows / total rows = 3/3
     +----------------+----------------+
     | streetNumber   | street         |
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala
index 44ea5188f..300233777 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala
@@ -11,8 +11,10 @@ import org.opensearch.flint.common.metadata.FlintMetadata
 import org.opensearch.flint.common.metadata.log.FlintMetadataLogEntry
 import org.opensearch.flint.core.metadata.FlintJsonHelper._
 
+import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{DataFrame, SparkSession}
 import org.apache.spark.sql.flint.datatype.FlintDataType
+import org.apache.spark.sql.functions.expr
 import org.apache.spark.sql.types.StructType
 
 /**
@@ -62,7 +64,7 @@ trait FlintSparkIndex {
   def build(spark: SparkSession, df: Option[DataFrame]): DataFrame
 }
 
-object FlintSparkIndex {
+object FlintSparkIndex extends Logging {
 
   /**
    * Interface indicates a Flint index has custom streaming refresh capability other than foreach
@@ -117,6 +119,25 @@ object FlintSparkIndex {
     s"${parts(0)}.${parts(1)}.`${parts.drop(2).mkString(".")}`"
   }
 
+  /**
+   * Generate an ID column using the ID expression provided in the index options.
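+   * For example, with `id_expression = "sha1(concat_ws('\0', col1, col2))"` (hypothetical
+   * columns), identical rows produce identical IDs, so index writes stay idempotent across
+   * refresh restarts and retries.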
+   *
+   * @param df
+   *   the DataFrame to generate the ID column for
+   * @param options
+   *   Flint index options
+   * @return
+   *   DataFrame with the ID column added, or the original DataFrame if no ID expression is provided
+   */
+  def addIdColumn(df: DataFrame, options: FlintSparkIndexOptions): DataFrame = {
+    options.idExpression() match {
+      case Some(idExpr) if idExpr.nonEmpty =>
+        logInfo(s"Using user-provided ID expression: $idExpr")
+        df.withColumn(ID_COLUMN, expr(idExpr))
+      case _ => df
+    }
+  }
+
   /**
    * Populate environment variables to persist in Flint metadata.
    *
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexOptions.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexOptions.scala
index 9b58a696c..1ad88de6d 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexOptions.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexOptions.scala
@@ -10,7 +10,7 @@ import java.util.{Collections, UUID}
 import org.json4s.{Formats, NoTypeHints}
 import org.json4s.native.JsonMethods._
 import org.json4s.native.Serialization
-import org.opensearch.flint.spark.FlintSparkIndexOptions.OptionName.{AUTO_REFRESH, CHECKPOINT_LOCATION, EXTRA_OPTIONS, INCREMENTAL_REFRESH, INDEX_SETTINGS, OptionName, OUTPUT_MODE, REFRESH_INTERVAL, SCHEDULER_MODE, WATERMARK_DELAY}
+import org.opensearch.flint.spark.FlintSparkIndexOptions.OptionName.{AUTO_REFRESH, CHECKPOINT_LOCATION, EXTRA_OPTIONS, ID_EXPRESSION, INCREMENTAL_REFRESH, INDEX_SETTINGS, OptionName, OUTPUT_MODE, REFRESH_INTERVAL, SCHEDULER_MODE, WATERMARK_DELAY}
 import org.opensearch.flint.spark.FlintSparkIndexOptions.validateOptionNames
 import org.opensearch.flint.spark.refresh.FlintSparkIndexRefresh.SchedulerMode
 import org.opensearch.flint.spark.scheduler.util.IntervalSchedulerParser
@@ -96,6 +96,14 @@ case class FlintSparkIndexOptions(options: Map[String, String]) {
    */
   def indexSettings(): Option[String] = getOptionValue(INDEX_SETTINGS)
 
+  /**
+   * An expression that generates a unique value as the index data row ID.
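+   * For example, `sha1(concat_ws('\0', timestamp, status))` (hypothetical columns).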
+ * + * @return + * ID expression + */ + def idExpression(): Option[String] = getOptionValue(ID_EXPRESSION) + /** * Extra streaming source options that can be simply passed to DataStreamReader or * Relation.options @@ -187,6 +195,7 @@ object FlintSparkIndexOptions { val WATERMARK_DELAY: OptionName.Value = Value("watermark_delay") val OUTPUT_MODE: OptionName.Value = Value("output_mode") val INDEX_SETTINGS: OptionName.Value = Value("index_settings") + val ID_EXPRESSION: OptionName.Value = Value("id_expression") val EXTRA_OPTIONS: OptionName.Value = Value("extra_options") } diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala index 8748bf874..901c3006c 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala @@ -10,7 +10,7 @@ import scala.collection.JavaConverters.mapAsJavaMapConverter import org.opensearch.flint.common.metadata.FlintMetadata import org.opensearch.flint.common.metadata.log.FlintMetadataLogEntry import org.opensearch.flint.spark._ -import org.opensearch.flint.spark.FlintSparkIndex.{flintIndexNamePrefix, generateSchema, metadataBuilder, quotedTableName} +import org.opensearch.flint.spark.FlintSparkIndex.{addIdColumn, flintIndexNamePrefix, generateSchema, metadataBuilder, quotedTableName} import org.opensearch.flint.spark.FlintSparkIndexOptions.empty import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex.{getFlintIndexName, COVERING_INDEX_TYPE} @@ -71,10 +71,13 @@ case class FlintSparkCoveringIndex( val job = df.getOrElse(spark.read.table(quotedTableName(tableName))) // Add optional filtering condition - filterCondition - .map(job.where) - .getOrElse(job) - .select(colNames.head, colNames.tail: _*) + val batchDf = + filterCondition + .map(job.where) + .getOrElse(job) + .select(colNames.head, colNames.tail: _*) + + addIdColumn(batchDf, options) } } diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedView.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedView.scala index d5c450e7e..e3b09661a 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedView.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedView.scala @@ -13,7 +13,7 @@ import scala.collection.convert.ImplicitConversions.`map AsScala` import org.opensearch.flint.common.metadata.FlintMetadata import org.opensearch.flint.common.metadata.log.FlintMetadataLogEntry import org.opensearch.flint.spark.{FlintSpark, FlintSparkIndex, FlintSparkIndexBuilder, FlintSparkIndexOptions} -import org.opensearch.flint.spark.FlintSparkIndex.{flintIndexNamePrefix, generateSchema, metadataBuilder, StreamingRefresh} +import org.opensearch.flint.spark.FlintSparkIndex.{addIdColumn, flintIndexNamePrefix, generateSchema, metadataBuilder, ID_COLUMN, StreamingRefresh} import org.opensearch.flint.spark.FlintSparkIndexOptions.empty import org.opensearch.flint.spark.function.TumbleFunction import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.{getFlintIndexName, MV_INDEX_TYPE} @@ -81,7 +81,8 @@ case class FlintSparkMaterializedView( override def build(spark: SparkSession, df: Option[DataFrame]): 
DataFrame = { require(df.isEmpty, "materialized view doesn't support reading from other data frame") - spark.sql(query) + val batchDf = spark.sql(query) + addIdColumn(batchDf, options) } override def buildStream(spark: SparkSession): DataFrame = { @@ -99,7 +100,9 @@ case class FlintSparkMaterializedView( case relation: UnresolvedRelation if !relation.isStreaming => relation.copy(isStreaming = true, options = optionsWithExtra(spark, relation)) } - logicalPlanToDataFrame(spark, streamingPlan) + + val streamingDf = logicalPlanToDataFrame(spark, streamingPlan) + addIdColumn(streamingDf, options) } private def watermark(timeCol: Attribute, child: LogicalPlan) = { diff --git a/flint-spark-integration/src/test/scala/org/apache/spark/FlintSuite.scala b/flint-spark-integration/src/test/scala/org/apache/spark/FlintSuite.scala index 1d301087f..78debda35 100644 --- a/flint-spark-integration/src/test/scala/org/apache/spark/FlintSuite.scala +++ b/flint-spark-integration/src/test/scala/org/apache/spark/FlintSuite.scala @@ -6,9 +6,12 @@ package org.apache.spark import org.opensearch.flint.spark.FlintSparkExtensions +import org.opensearch.flint.spark.FlintSparkIndex.ID_COLUMN -import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.catalyst.expressions.{Alias, CodegenObjectFactoryMode, Expression} import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation +import org.apache.spark.sql.catalyst.plans.logical.Project import org.apache.spark.sql.flint.config.{FlintConfigEntry, FlintSparkConf} import org.apache.spark.sql.flint.config.FlintSparkConf.{EXTERNAL_SCHEDULER_ENABLED, HYBRID_SCAN_ENABLED, METADATA_CACHE_WRITE} import org.apache.spark.sql.internal.SQLConf @@ -68,4 +71,27 @@ trait FlintSuite extends SharedSparkSession { setFlintSparkConf(METADATA_CACHE_WRITE, "false") } } + + /** + * Implicit class to extend DataFrame functionality with additional utilities. + * + * @param df + * the DataFrame to which the additional methods are added + */ + protected implicit class DataFrameExtensions(val df: DataFrame) { + + /** + * Retrieves the ID column expression from the logical plan of the DataFrame, if it exists. 
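+     * This inspects the first `Project` node found in the logical plan, which is the plan
+     * shape produced by the `withColumn`-based `addIdColumn`.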
+ * + * @return + * an `Option` containing the `Expression` for the ID column if present, or `None` otherwise + */ + def idColumn(): Option[Expression] = { + df.queryExecution.logical.collectFirst { case Project(projectList, _) => + projectList.collectFirst { case Alias(child, ID_COLUMN) => + child + } + }.flatten + } + } } diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexOptionsSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexOptionsSuite.scala index d7de6d29b..f752ae68a 100644 --- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexOptionsSuite.scala +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexOptionsSuite.scala @@ -6,7 +6,6 @@ package org.opensearch.flint.spark import org.opensearch.flint.spark.FlintSparkIndexOptions.OptionName._ -import org.opensearch.flint.spark.refresh.FlintSparkIndexRefresh.SchedulerMode import org.scalatest.matchers.should.Matchers import org.apache.spark.FlintSuite @@ -22,6 +21,7 @@ class FlintSparkIndexOptionsSuite extends FlintSuite with Matchers { WATERMARK_DELAY.toString shouldBe "watermark_delay" OUTPUT_MODE.toString shouldBe "output_mode" INDEX_SETTINGS.toString shouldBe "index_settings" + ID_EXPRESSION.toString shouldBe "id_expression" EXTRA_OPTIONS.toString shouldBe "extra_options" } @@ -36,6 +36,7 @@ class FlintSparkIndexOptionsSuite extends FlintSuite with Matchers { "watermark_delay" -> "30 Seconds", "output_mode" -> "complete", "index_settings" -> """{"number_of_shards": 3}""", + "id_expression" -> """sha1(col("timestamp"))""", "extra_options" -> """ { | "alb_logs": { @@ -55,6 +56,7 @@ class FlintSparkIndexOptionsSuite extends FlintSuite with Matchers { options.watermarkDelay() shouldBe Some("30 Seconds") options.outputMode() shouldBe Some("complete") options.indexSettings() shouldBe Some("""{"number_of_shards": 3}""") + options.idExpression() shouldBe Some("""sha1(col("timestamp"))""") options.extraSourceOptions("alb_logs") shouldBe Map("opt1" -> "val1") options.extraSinkOptions() shouldBe Map("opt2" -> "val2", "opt3" -> "val3") } @@ -83,6 +85,7 @@ class FlintSparkIndexOptionsSuite extends FlintSuite with Matchers { options.watermarkDelay() shouldBe empty options.outputMode() shouldBe empty options.indexSettings() shouldBe empty + options.idExpression() shouldBe empty options.extraSourceOptions("alb_logs") shouldBe empty options.extraSinkOptions() shouldBe empty options.optionsWithDefault should contain("auto_refresh" -> "false") diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexSuite.scala new file mode 100644 index 000000000..8ec4bec40 --- /dev/null +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexSuite.scala @@ -0,0 +1,111 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark + +import org.opensearch.flint.spark.FlintSparkIndex.{addIdColumn, ID_COLUMN} +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.FlintSuite +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction} +import org.apache.spark.sql.catalyst.expressions.{Add, ConcatWs, Literal, Sha1, StructsToJson} +import org.apache.spark.sql.types.{StringType, StructType} +import 
org.apache.spark.unsafe.types.UTF8String
+
+class FlintSparkIndexSuite extends QueryTest with FlintSuite with Matchers {
+
+  test("should add ID column if ID expression is provided") {
+    val df = spark.createDataFrame(Seq((1, "Alice"), (2, "Bob"))).toDF("id", "name")
+    val options = new FlintSparkIndexOptions(Map("id_expression" -> "id + 10"))
+
+    val resultDf = addIdColumn(df, options)
+    resultDf.idColumn() shouldBe Some(Add(UnresolvedAttribute("id"), Literal(10)))
+    checkAnswer(resultDf.select(ID_COLUMN), Seq(Row(11), Row(12)))
+  }
+
+  test("should not add ID column if ID expression is not provided") {
+    val df = spark.createDataFrame(Seq((1, "Alice"), (2, "Bob"))).toDF("id", "name")
+    val options = FlintSparkIndexOptions.empty
+
+    val resultDf = addIdColumn(df, options)
+    resultDf.columns should not contain ID_COLUMN
+  }
+
+  test("should not add ID column if ID expression is empty") {
+    val df = spark.createDataFrame(Seq((1, "Alice"), (2, "Bob"))).toDF("id", "name")
+    val options = FlintSparkIndexOptions(Map("id_expression" -> ""))
+
+    val resultDf = addIdColumn(df, options)
+    resultDf.columns should not contain ID_COLUMN
+  }
+
+  test("should generate ID column for various column types") {
+    val schema = StructType.fromDDL("""
+      boolean_col BOOLEAN,
+      string_col STRING,
+      long_col LONG,
+      int_col INT,
+      double_col DOUBLE,
+      float_col FLOAT,
+      timestamp_col TIMESTAMP,
+      date_col DATE,
+      struct_col STRUCT<subfield1: STRING, subfield2: INT>
+      """)
+    val data = Seq(
+      Row(
+        true,
+        "Alice",
+        100L,
+        10,
+        10.5,
+        3.14f,
+        java.sql.Timestamp.valueOf("2024-01-01 10:00:00"),
+        java.sql.Date.valueOf("2024-01-01"),
+        Row("sub1", 1)))
+
+    val aggregatedDf = spark
+      .createDataFrame(sparkContext.parallelize(data), schema)
+      .groupBy(
+        "boolean_col",
+        "string_col",
+        "long_col",
+        "int_col",
+        "double_col",
+        "float_col",
+        "timestamp_col",
+        "date_col",
+        "struct_col",
+        "struct_col.subfield2")
+      .count()
+    val options = FlintSparkIndexOptions(Map("id_expression" ->
+      "sha1(concat_ws('\0',boolean_col,string_col,long_col,int_col,double_col,float_col,timestamp_col,date_col,to_json(struct_col),struct_col.subfield2))"))
+
+    val resultDf = addIdColumn(aggregatedDf, options)
+    resultDf.idColumn() shouldBe Some(
+      UnresolvedFunction(
+        "sha1",
+        Seq(UnresolvedFunction(
+          "concat_ws",
+          Seq(
+            Literal(UTF8String.fromString("\0"), StringType),
+            UnresolvedAttribute(Seq("boolean_col")),
+            UnresolvedAttribute(Seq("string_col")),
+            UnresolvedAttribute(Seq("long_col")),
+            UnresolvedAttribute(Seq("int_col")),
+            UnresolvedAttribute(Seq("double_col")),
+            UnresolvedAttribute(Seq("float_col")),
+            UnresolvedAttribute(Seq("timestamp_col")),
+            UnresolvedAttribute(Seq("date_col")),
+            UnresolvedFunction(
+              "to_json",
+              Seq(UnresolvedAttribute(Seq("struct_col"))),
+              isDistinct = false),
+            UnresolvedAttribute(Seq("struct_col", "subfield2"))),
+          isDistinct = false)),
+        isDistinct = false))
+    resultDf.select(ID_COLUMN).distinct().count() shouldBe 1
+  }
+}
diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndexSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndexSuite.scala
index 1cce47d1a..d23ad875e 100644
--- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndexSuite.scala
+++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndexSuite.scala
@@ -5,13 +5,20 @@
 
 package org.opensearch.flint.spark.covering
 
+import org.opensearch.flint.spark.FlintSparkIndex.ID_COLUMN
+import
org.opensearch.flint.spark.FlintSparkIndexOptions import org.scalatest.matchers.must.Matchers.contain import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper import org.apache.spark.FlintSuite +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.expressions.Alias +import org.apache.spark.sql.catalyst.plans.logical.Project class FlintSparkCoveringIndexSuite extends FlintSuite { + private val testTable = "spark_catalog.default.ci_test" + test("get covering index name") { val index = new FlintSparkCoveringIndex("ci", "spark_catalog.default.test", Map("name" -> "string")) @@ -54,4 +61,34 @@ class FlintSparkCoveringIndexSuite extends FlintSuite { new FlintSparkCoveringIndex("ci", "default.test", Map.empty) } } + + test("build batch with ID expression option") { + withTable(testTable) { + sql(s"CREATE TABLE $testTable (timestamp TIMESTAMP, name STRING) USING JSON") + val index = + FlintSparkCoveringIndex( + "name_idx", + testTable, + Map("name" -> "string"), + options = FlintSparkIndexOptions(Map("id_expression" -> "name"))) + + val batchDf = index.build(spark, None) + batchDf.idColumn() shouldBe Some(UnresolvedAttribute(Seq("name"))) + } + } + + test("build stream with ID expression option") { + withTable(testTable) { + sql(s"CREATE TABLE $testTable (name STRING, age INTEGER) USING JSON") + val index = FlintSparkCoveringIndex( + "name_idx", + testTable, + Map("name" -> "string"), + options = + FlintSparkIndexOptions(Map("auto_refresh" -> "true", "id_expression" -> "name"))) + + val streamDf = index.build(spark, Some(spark.table(testTable))) + streamDf.idColumn() shouldBe Some(UnresolvedAttribute(Seq("name"))) + } + } } diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala index 78d2eb09e..838eddf21 100644 --- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala @@ -15,7 +15,7 @@ import org.scalatestplus.mockito.MockitoSugar.mock import org.apache.spark.FlintSuite import org.apache.spark.sql.{DataFrame, Row} -import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation} import org.apache.spark.sql.catalyst.dsl.expressions.{intToLiteral, stringToLiteral, DslAttr, DslExpression, StringToAttributeConversionHelper} import org.apache.spark.sql.catalyst.dsl.plans.DslLogicalPlan import org.apache.spark.sql.catalyst.expressions.Attribute @@ -36,6 +36,16 @@ class FlintSparkMaterializedViewSuite extends FlintSuite { val testMvName = "spark_catalog.default.mv" val testQuery = "SELECT 1" + override def beforeAll(): Unit = { + super.beforeAll() + sql(s"CREATE TABLE $testTable (time TIMESTAMP, name STRING, age INT) USING CSV") + } + + override def afterAll(): Unit = { + sql(s"DROP TABLE $testTable") + super.afterAll() + } + test("get mv name") { val mv = FlintSparkMaterializedView(testMvName, testQuery, Array.empty, Map.empty) mv.name() shouldBe "flint_spark_catalog_default_mv" @@ -174,19 +184,59 @@ class FlintSparkMaterializedViewSuite extends FlintSuite { } test("build stream should fail if there is aggregation but no windowing function") { - val testTable = "mv_build_test" - withTable(testTable) { - sql(s"CREATE 
TABLE $testTable (time TIMESTAMP, name STRING, age INT) USING CSV") + val mv = FlintSparkMaterializedView( + testMvName, + s"SELECT name, COUNT(*) AS count FROM $testTable GROUP BY name", + Array(testTable), + Map.empty) - val mv = FlintSparkMaterializedView( - testMvName, - s"SELECT name, COUNT(*) AS count FROM $testTable GROUP BY name", - Array(testTable), - Map.empty) + the[IllegalStateException] thrownBy + mv.buildStream(spark) + } - the[IllegalStateException] thrownBy - mv.buildStream(spark) - } + test("build batch with ID expression option") { + val testMvQuery = s"SELECT time, name FROM $testTable" + val mv = FlintSparkMaterializedView( + testMvName, + testMvQuery, + Array.empty, + Map.empty, + FlintSparkIndexOptions(Map("id_expression" -> "time"))) + + val batchDf = mv.build(spark, None) + batchDf.idColumn() shouldBe Some(UnresolvedAttribute(Seq("time"))) + } + + test("build batch should not have ID column if not provided") { + val testMvQuery = s"SELECT time, name FROM $testTable" + val mv = FlintSparkMaterializedView(testMvName, testMvQuery, Array.empty, Map.empty) + + val batchDf = mv.build(spark, None) + batchDf.idColumn() shouldBe None + } + + test("build stream with ID expression option") { + val mv = FlintSparkMaterializedView( + testMvName, + s"SELECT time, name FROM $testTable", + Array.empty, + Map.empty, + FlintSparkIndexOptions(Map("auto_refresh" -> "true", "id_expression" -> "time"))) + + val streamDf = mv.buildStream(spark) + streamDf.idColumn() shouldBe Some(UnresolvedAttribute(Seq("time"))) + } + + test("build stream should not have ID column if not provided") { + val mv = FlintSparkMaterializedView( + testMvName, + s"SELECT time, name FROM $testTable", + Array.empty, + Map.empty, + FlintSparkIndexOptions(Map("auto_refresh" -> "true"))) + + val streamDf = mv.buildStream(spark) + streamDf.idColumn() shouldBe None } private def withAggregateMaterializedView( @@ -194,19 +244,16 @@ class FlintSparkMaterializedViewSuite extends FlintSuite { sourceTables: Array[String], options: Map[String, String])(codeBlock: LogicalPlan => Unit): Unit = { - withTable(testTable) { - sql(s"CREATE TABLE $testTable (time TIMESTAMP, name STRING, age INT) USING CSV") - val mv = - FlintSparkMaterializedView( - testMvName, - query, - sourceTables, - Map.empty, - FlintSparkIndexOptions(options)) - - val actualPlan = mv.buildStream(spark).queryExecution.logical - codeBlock(actualPlan) - } + val mv = + FlintSparkMaterializedView( + testMvName, + query, + sourceTables, + Map.empty, + FlintSparkIndexOptions(options)) + + val actualPlan = mv.buildStream(spark).queryExecution.logical + codeBlock(actualPlan) } } diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexSqlITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexSqlITSuite.scala index aac06a2c1..0791f9b7a 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexSqlITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexSqlITSuite.scala @@ -28,19 +28,23 @@ class FlintSparkCoveringIndexSqlITSuite extends FlintSparkSuite { private val testTable = "spark_catalog.default.covering_sql_test" private val testIndex = "name_and_age" private val testFlintIndex = getFlintIndexName(testIndex, testTable) + private val testTimeSeriesTable = "spark_catalog.default.covering_sql_ts_test" + private val testFlintTimeSeriesIndex = getFlintIndexName(testIndex, testTimeSeriesTable) override def 
beforeEach(): Unit = { super.beforeEach() createPartitionedAddressTable(testTable) + createTimeSeriesTable(testTimeSeriesTable) } override def afterEach(): Unit = { super.afterEach() // Delete all test indices - deleteTestIndex(testFlintIndex) + deleteTestIndex(testFlintIndex, testFlintTimeSeriesIndex) sql(s"DROP TABLE $testTable") + sql(s"DROP TABLE $testTimeSeriesTable") } test("create covering index with auto refresh") { @@ -86,6 +90,41 @@ class FlintSparkCoveringIndexSqlITSuite extends FlintSparkSuite { } } + test("create covering index with auto refresh and ID expression") { + sql(s""" + | CREATE INDEX $testIndex ON $testTimeSeriesTable + | (time, age, address) + | WITH ( + | auto_refresh = true, + | id_expression = 'address' + | ) + |""".stripMargin) + + val job = spark.streams.active.find(_.name == testFlintTimeSeriesIndex) + awaitStreamingComplete(job.get.id.toString) + + val indexData = flint.queryIndex(testFlintTimeSeriesIndex) + indexData.count() shouldBe 3 // only 3 rows left due to same ID + } + + test("create covering index with full refresh and ID expression") { + sql(s""" + | CREATE INDEX $testIndex ON $testTimeSeriesTable + | (time, age, address) + | WITH ( + | id_expression = 'address' + | ) + |""".stripMargin) + sql(s"REFRESH INDEX $testIndex ON $testTimeSeriesTable") + + val indexData = flint.queryIndex(testFlintTimeSeriesIndex) + indexData.count() shouldBe 3 // only 3 rows left due to same ID + + // Rerun should not generate duplicate data + sql(s"REFRESH INDEX $testIndex ON $testTimeSeriesTable") + indexData.count() shouldBe 3 + } + test("create covering index with index settings") { sql(s""" | CREATE INDEX $testIndex ON $testTable ( name ) diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewSqlITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewSqlITSuite.scala index bf5e6309e..7dcd83897 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewSqlITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewSqlITSuite.scala @@ -154,6 +154,50 @@ class FlintSparkMaterializedViewSqlITSuite extends FlintSparkSuite { } } + test("create materialized view with auto refresh and ID expression") { + withTempDir { checkpointDir => + sql(s""" + | CREATE MATERIALIZED VIEW $testMvName + | AS $testQuery + | WITH ( + | auto_refresh = true, + | checkpoint_location = '${checkpointDir.getAbsolutePath}', + | watermark_delay = '1 Second', + | id_expression = "sha1(concat_ws('\0',startTime))" + | ) + |""".stripMargin) + + // Wait for streaming job complete current micro batch + val job = spark.streams.active.find(_.name == testFlintIndex) + job shouldBe defined + failAfter(streamingTimeout) { + job.get.processAllAvailable() + } + + flint.queryIndex(testFlintIndex).count() shouldBe 3 + } + } + + test("create materialized view with full refresh and ID expression") { + sql(s""" + | CREATE MATERIALIZED VIEW $testMvName + | AS $testQuery + | WITH ( + | id_expression = 'count' + | ) + |""".stripMargin) + + sql(s"REFRESH MATERIALIZED VIEW $testMvName") + + // 2 rows missing due to ID conflict intentionally + val indexData = flint.queryIndex(testFlintIndex) + indexData.count() shouldBe 2 + + // Rerun should not generate duplicate data + sql(s"REFRESH MATERIALIZED VIEW $testMvName") + indexData.count() shouldBe 2 + } + test("create materialized view with index settings") { sql(s""" | CREATE MATERIALIZED VIEW 
$testMvName
diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAppendColITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAppendColITSuite.scala
new file mode 100644
index 000000000..1e0e20f8b
--- /dev/null
+++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAppendColITSuite.scala
@@ -0,0 +1,696 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.flint.spark.ppl
+
+import org.opensearch.sql.ppl.utils.SortUtils
+
+import org.apache.spark.sql.{QueryTest, Row}
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Ascending$, CurrentRow, EqualTo, Literal, RowFrame, RowNumber, SortOrder, SpecifiedWindowFrame, UnboundedPreceding, WindowExpression, WindowSpecDefinition}
+import org.apache.spark.sql.catalyst.plans.FullOuter
+import org.apache.spark.sql.catalyst.plans.logical.{Project, _}
+import org.apache.spark.sql.streaming.StreamTest
+import org.apache.spark.sql.types.DataTypes
+
+class FlintSparkPPLAppendColITSuite
+    extends QueryTest
+    with LogicalPlanTestUtils
+    with FlintPPLSuite
+    with StreamTest {
+
+  /** Test table and index name */
+  private val testTable = "spark_catalog.default.flint_ppl_test"
+
+  private val ROW_NUMBER_AGGREGATION = Alias(
+    WindowExpression(
+      RowNumber(),
+      WindowSpecDefinition(
+        Nil,
+        SortUtils.sortOrder(Literal("1"), false) :: Nil,
+        SpecifiedWindowFrame(RowFrame, UnboundedPreceding, CurrentRow))),
+    "_row_number_")()
+
+  private val COUNT_STAR = Alias(
+    UnresolvedFunction(Seq("COUNT"), Seq(UnresolvedStar(None)), isDistinct = false),
+    "count()")()
+
+  private val AGE_ALIAS = Alias(UnresolvedAttribute("age"), "age")()
+
+  private val COUNTRY_ALIAS = Alias(UnresolvedAttribute("country"), "country")()
+
+  private val RELATION_TEST_TABLE = UnresolvedRelation(
+    Seq("spark_catalog", "default", "flint_ppl_test"))
+
+  private val T12_JOIN_CONDITION =
+    EqualTo(
+      UnresolvedAttribute("APPENDCOL_T1._row_number_"),
+      UnresolvedAttribute("APPENDCOL_T2._row_number_"))
+
+  private val T12_COLUMNS_SEQ =
+    Seq(
+      UnresolvedAttribute("APPENDCOL_T1._row_number_"),
+      UnresolvedAttribute("APPENDCOL_T2._row_number_"))
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+
+    // Create test table
+    createPartitionedStateCountryTable(testTable)
+  }
+
+  protected override def afterEach(): Unit = {
+    super.afterEach()
+    // Stop all streaming jobs if any
+    spark.streams.active.foreach { job =>
+      job.stop()
+      job.awaitTermination()
+    }
+  }
+
+  /**
+   * The baseline test case to make sure the APPENDCOL function works when no transformation is
+   * present on the main search, after the search command.
+   */
+  test("test AppendCol with NO transformation on main") {
+    val frame = sql(s"""
+                       | source = $testTable | APPENDCOL [stats count() by age]
+                       | """.stripMargin)
+
+    assert(
+      frame.columns.sameElements(
+        Array("name", "age", "state", "country", "year", "month", "count()", "age")))
+    // Retrieve the results
+    val results: Array[Row] = frame.collect()
+    val expectedResults: Array[Row] =
+      Array(
+        Row("Jake", 70, "California", "USA", 2023, 4, 1, 70),
+        Row("Hello", 30, "New York", "USA", 2023, 4, 1, 30),
+        Row("John", 25, "Ontario", "Canada", 2023, 4, 1, 25),
+        Row("Jane", 20, "Quebec", "Canada", 2023, 4, 1, 20))
+    // Compare the results
+    implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
+    assert(results.sorted.sameElements(expectedResults.sorted))
+
+    // Retrieve the logical plan
+    val logicalPlan: LogicalPlan = frame.queryExecution.logical
+
+    /*
+    :- 'SubqueryAlias T1
+    :  +- 'Project [row_number() windowspecdefinition(1 DESC NULLS LAST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS _row_number_#7, *]
+    :     +- 'UnresolvedRelation [relation], [], false
+     */
+    val t1 = SubqueryAlias(
+      "APPENDCOL_T1",
+      Project(Seq(ROW_NUMBER_AGGREGATION, UnresolvedStar(None)), RELATION_TEST_TABLE))
+
+    /*
+    +- 'SubqueryAlias T2
+       +- 'Project [row_number() windowspecdefinition(1 DESC NULLS LAST,
+          specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS _row_number_#11, *]
+          +- 'Aggregate ['age AS age#9], ['COUNT(*) AS count()#8, 'age AS age#10]
+             +- 'UnresolvedRelation [relation], [], false
+     */
+    val t2 = SubqueryAlias(
+      "APPENDCOL_T2",
+      Project(
+        Seq(ROW_NUMBER_AGGREGATION, UnresolvedStar(None)),
+        Aggregate(AGE_ALIAS :: Nil, Seq(COUNT_STAR, AGE_ALIAS), RELATION_TEST_TABLE)))
+
+    val expectedPlan = Project(
+      Seq(UnresolvedStar(None)),
+      DataFrameDropColumns(
+        T12_COLUMNS_SEQ,
+        Join(t1, t2, FullOuter, Some(T12_JOIN_CONDITION), JoinHint.NONE)))
+
+    comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
+  }
+
+  /**
+   * To simulate the use case where a user attempts to attach an APPENDCOL command to a
+   * well-established main search.
+   */
+  test("test AppendCol with transformation on main-search") {
+    val frame = sql(s"""
+                       | source = $testTable | FIELDS name, age, state | APPENDCOL [stats count() by age]
+                       | """.stripMargin)
+
+    assert(frame.columns.sameElements(Array("name", "age", "state", "count()", "age")))
+    // Retrieve the results
+    val results: Array[Row] = frame.collect()
+    val expectedResults: Array[Row] =
+      Array(
+        Row("Jake", 70, "California", 1, 70),
+        Row("Hello", 30, "New York", 1, 30),
+        Row("John", 25, "Ontario", 1, 25),
+        Row("Jane", 20, "Quebec", 1, 20))
+    // Compare the results
+    implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
+    assert(results.sorted.sameElements(expectedResults.sorted))
+
+    // Retrieve the logical plan
+    val logicalPlan: LogicalPlan = frame.queryExecution.logical
+    /*
+    :- 'SubqueryAlias T1
+    :  +- 'Project [row_number() windowspecdefinition(1 DESC NULLS LAST,
+          specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS _row_number_#11, *]
+    :     +- 'Project ['name, 'age, 'state]
+    :        +- 'UnresolvedRelation [relation], [], false
+     */
+    val t1 = SubqueryAlias(
+      "APPENDCOL_T1",
+      Project(
+        Seq(ROW_NUMBER_AGGREGATION, UnresolvedStar(None)),
+        Project(
+          Seq(
+            UnresolvedAttribute("name"),
+            UnresolvedAttribute("age"),
+            UnresolvedAttribute("state")),
+          RELATION_TEST_TABLE)))
+
+    /*
+    +- 'SubqueryAlias T2
+       +- 'Project [row_number() windowspecdefinition(1 DESC NULLS LAST,
+          specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS _row_number_#11, *]
+          +- 'Aggregate ['age AS age#9], ['COUNT(*) AS count()#8, 'age AS age#10]
+             +- 'UnresolvedRelation [relation], [], false
+     */
+    val t2 = SubqueryAlias(
+      "APPENDCOL_T2",
+      Project(
+        Seq(ROW_NUMBER_AGGREGATION, UnresolvedStar(None)),
+        Aggregate(AGE_ALIAS :: Nil, Seq(COUNT_STAR, AGE_ALIAS), RELATION_TEST_TABLE)))
+
+    val expectedPlan = Project(
+      Seq(UnresolvedStar(None)),
+      DataFrameDropColumns(
+        T12_COLUMNS_SEQ,
+        Join(t1, t2, FullOuter, Some(T12_JOIN_CONDITION), JoinHint.NONE)))
+    comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
+  }
+
+  /**
+   * To simulate the situation where multiple PPL commands are applied on the sub-search.
+   */
+  test("test AppendCol with chained sub-search") {
+    val frame = sql(s"""
+                       | source = $testTable | FIELDS name, age, state | APPENDCOL [ stats count() by age | eval m = 1 | FIELDS -m ]
+                       | """.stripMargin)
+
+    assert(frame.columns.sameElements(Array("name", "age", "state", "count()", "age")))
+    // Retrieve the results
+    val results: Array[Row] = frame.collect()
+    val expectedResults: Array[Row] =
+      Array(
+        Row("Jake", 70, "California", 1, 70),
+        Row("Hello", 30, "New York", 1, 30),
+        Row("John", 25, "Ontario", 1, 25),
+        Row("Jane", 20, "Quebec", 1, 20))
+    // Compare the results
+    implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
+    assert(results.sorted.sameElements(expectedResults.sorted))
+
+    // Retrieve the logical plan
+    val logicalPlan: LogicalPlan = frame.queryExecution.logical
+    /*
+    :- 'SubqueryAlias T1
+    :  +- 'Project [row_number() windowspecdefinition(1 DESC NULLS LAST,
+          specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS _row_number_#11, *]
+    :     +- 'Project ['name, 'age, 'state]
+    :        +- 'UnresolvedRelation [relation], [], false
+     */
+    val t1 = SubqueryAlias(
+      "APPENDCOL_T1",
+      Project(
+        Seq(ROW_NUMBER_AGGREGATION, UnresolvedStar(None)),
+        Project(
+          Seq(
+            UnresolvedAttribute("name"),
+            UnresolvedAttribute("age"),
+            UnresolvedAttribute("state")),
+          RELATION_TEST_TABLE)))
+
+    /*
+    +- 'SubqueryAlias T2
+       +- 'Project [row_number() windowspecdefinition(1 DESC NULLS LAST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS _row_number_#432, *]
+          +- 'DataFrameDropColumns ['m]
+             +- 'Project [*, 1 AS m#430]
+                +- 'Aggregate ['age AS age#429], ['COUNT(*) AS count()#428, 'age AS age#429]
+                   +- 'UnresolvedRelation [flint_ppl_test], [], false
+     */
+    val t2 = SubqueryAlias(
+      "APPENDCOL_T2",
+      Project(
+        Seq(ROW_NUMBER_AGGREGATION, UnresolvedStar(None)),
+        DataFrameDropColumns(
+          Seq(UnresolvedAttribute("m")),
+          Project(
+            Seq(UnresolvedStar(None), Alias(Literal(1), "m")()),
+            Aggregate(AGE_ALIAS :: Nil, Seq(COUNT_STAR, AGE_ALIAS), RELATION_TEST_TABLE)))))
+
+    val expectedPlan = Project(
+      Seq(UnresolvedStar(None)),
+      DataFrameDropColumns(
+        T12_COLUMNS_SEQ,
+        Join(t1, t2, FullOuter, Some(T12_JOIN_CONDITION), JoinHint.NONE)))
+    comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
+  }
+
+  /**
+   * The use case where a user chains multiple APPENDCOL commands in a PPL query; this is common
+   * when the user prefers to show statistic reports alongside the dataset.
+   */
+  test("test multiple AppendCol clauses") {
+    val frame = sql(s"""
+                       | source = $testTable | FIELDS name, age | APPENDCOL [ stats count() by age | eval m = 1 | FIELDS -m ] | APPENDCOL [FIELDS state]
+                       | """.stripMargin)
+
+    assert(frame.columns.sameElements(Array("name", "age", "count()", "age", "state")))
+    // Retrieve the results
+    val results: Array[Row] = frame.collect()
+    val expectedResults: Array[Row] =
+      Array(
+        Row("Jake", 70, 1, 70, "California"),
+        Row("Hello", 30, 1, 30, "New York"),
+        Row("John", 25, 1, 25, "Ontario"),
+        Row("Jane", 20, 1, 20, "Quebec"))
+    // Compare the results
+    implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
+    assert(results.sorted.sameElements(expectedResults.sorted))
+
+    // Retrieve the logical plan
+    val logicalPlan: LogicalPlan = frame.queryExecution.logical
+    /*
+    :- 'SubqueryAlias T1
+    :  +- 'Project [row_number() windowspecdefinition(1 DESC NULLS LAST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS _row_number_#544, *]
+    :     +- 'Project ['name, 'age]
+    :        +- 'UnresolvedRelation [flint_ppl_test], [], false
+     */
+    val mainSearch = SubqueryAlias(
+      "APPENDCOL_T1",
+      Project(
+        Seq(ROW_NUMBER_AGGREGATION, UnresolvedStar(None)),
+        Project(
+          Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")),
+          RELATION_TEST_TABLE)))
+
+    /*
+    +- 'SubqueryAlias T2
+       +- 'Project [row_number() windowspecdefinition(1 DESC NULLS LAST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS _row_number_#432, *]
+          +- 'DataFrameDropColumns ['m]
+             +- 'Project [*, 1 AS m#430]
+                +- 'Aggregate ['age AS age#429], ['COUNT(*) AS count()#428, 'age AS age#429]
+                   +- 'UnresolvedRelation [flint_ppl_test], [], false
+     */
+    val firstAppendCol = SubqueryAlias(
+      "APPENDCOL_T2",
+      Project(
+        Seq(ROW_NUMBER_AGGREGATION, UnresolvedStar(None)),
+        DataFrameDropColumns(
+          Seq(UnresolvedAttribute("m")),
+          Project(
+            Seq(UnresolvedStar(None), Alias(Literal(1), "m")()),
+            Aggregate(AGE_ALIAS :: Nil, Seq(COUNT_STAR, AGE_ALIAS), RELATION_TEST_TABLE)))))
+
+    val joinWithFirstAppendCol = SubqueryAlias(
+      "APPENDCOL_T1",
+      Project(
+        Seq(ROW_NUMBER_AGGREGATION, UnresolvedStar(None)),
+        DataFrameDropColumns(
+          T12_COLUMNS_SEQ,
+          Join(mainSearch, firstAppendCol, FullOuter, Some(T12_JOIN_CONDITION), JoinHint.NONE))))
+
+    /*
+    +- 'SubqueryAlias T2
+       +- 'Project [row_number() windowspecdefinition(1 DESC NULLS LAST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS _row_number_#553, *]
+          +- 'Project ['state]
+             +- 'UnresolvedRelation [flint_ppl_test], [], false
+     */
+    val secondAppendCol = SubqueryAlias(
+      "APPENDCOL_T2",
+      Project(
+        Seq(ROW_NUMBER_AGGREGATION, UnresolvedStar(None)),
+        Project(Seq(UnresolvedAttribute("state")), RELATION_TEST_TABLE)))
+
+    val expectedPlan = Project(
+      Seq(UnresolvedStar(None)),
+      DataFrameDropColumns(
+        T12_COLUMNS_SEQ,
+        Join(
+          joinWithFirstAppendCol,
+          secondAppendCol,
+          FullOuter,
+          Some(T12_JOIN_CONDITION),
+          JoinHint.NONE)))
+    comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
+  }
+
+  /**
+   * To simulate the use case where column `age` is present on both main and sub-search, with
+   * option OVERRIDE=true.
+   */
+  test("test AppendCol with OVERRIDE option") {
+    val frame = sql(s"""
+                       | source = $testTable | FIELDS name, age, state | APPENDCOL OVERRIDE=true [stats count() as age]
+                       | """.stripMargin)
+
+    assert(frame.columns.sameElements(Array("name", "state", "age")))
+    // Retrieve the results
+    val results: Array[Row] = frame.collect()
+
+    /*
+    The sub-search result `APPENDCOL OVERRIDE=true [stats count() as age]` will be attached alongside the first row of the main search;
+    however, given the non-deterministic nature of natural ordering, we cannot guarantee which specific data row will be returned from the primary search query.
+    Hence, we only assert the sub-search position and skip the table content comparison.
+     */
+    assert(results(0).get(2) == 4)
+
+    // Retrieve the logical plan
+    val logicalPlan: LogicalPlan = frame.queryExecution.logical
+    /*
+    :- 'SubqueryAlias T1
+    :  +- 'Project [row_number() windowspecdefinition(1 DESC NULLS LAST,
+          specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS _row_number_#11, *]
+    :     +- 'Project ['name, 'age, 'state]
+    :        +- 'UnresolvedRelation [relation], [], false
+     */
+    val t1 = SubqueryAlias(
+      "APPENDCOL_T1",
+      Project(
+        Seq(ROW_NUMBER_AGGREGATION, UnresolvedStar(None)),
+        Project(
+          Seq(
+            UnresolvedAttribute("name"),
+            UnresolvedAttribute("age"),
+            UnresolvedAttribute("state")),
+          RELATION_TEST_TABLE)))
+
+    /*
+    +- 'SubqueryAlias T2
+       +- 'Project [row_number() windowspecdefinition(1 DESC NULLS LAST,
+          specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS _row_number_#216, *]
+          +- 'Aggregate ['COUNT(*) AS age#240]
+             +- 'UnresolvedRelation [flint_ppl_test], [], false
+     */
+    val t2 = SubqueryAlias(
+      "APPENDCOL_T2",
+      Project(
+        Seq(ROW_NUMBER_AGGREGATION, UnresolvedStar(None)),
+        Aggregate(
+          Nil,
+          Seq(
+            Alias(
+              UnresolvedFunction(Seq("COUNT"), Seq(UnresolvedStar(None)), isDistinct = false),
+              "age")()),
+          RELATION_TEST_TABLE)))
+
+    val expectedPlan = Project(
+      Seq(UnresolvedStar(None)),
+      DataFrameDropColumns(
+        T12_COLUMNS_SEQ :+ UnresolvedAttribute("APPENDCOL_T1.age"),
+        Join(t1, t2, FullOuter, Some(T12_JOIN_CONDITION), JoinHint.NONE)))
+    comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
+  }
+
+  // @formatter:off
+  /**
+   * In the case that the sub-query returns more result rows than the main query, null will be used for padding;
+   * expected logical plan:
+   * 'Project [*]
+   * +- 'DataFrameDropColumns ['APPENDCOL_T1._row_number_, 'APPENDCOL_T2._row_number_]
+   *    +- 'Join FullOuter, ('APPENDCOL_T1._row_number_ = 'APPENDCOL_T2._row_number_)
+   *       :- 'SubqueryAlias APPENDCOL_T1
+   *       :  +- 'Project [row_number() windowspecdefinition(1 DESC NULLS LAST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS _row_number_#225, *]
+   *       :     +- 'GlobalLimit 1
+   *       :        +- 'LocalLimit 1
+   *       :           +- 'Project ['name, 'age]
+   *       :              +- 'Sort ['age ASC NULLS FIRST], true
+   *       :                 +- 'UnresolvedRelation [flint_ppl_test], [], false
+   *       +- 'SubqueryAlias APPENDCOL_T2
+   *          +- 'Project [row_number() windowspecdefinition(1 DESC NULLS LAST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS _row_number_#227, *]
+   *             +- 'Project ['state]
+   *                +- 'Sort ['age ASC NULLS FIRST], true
+   *                   +- 'UnresolvedRelation [flint_ppl_test], [], false
+   *
+   */
+  // @formatter:on
+  test("test AppendCol with Null on main-query") {
+    val frame = sql(s"""
+                       | source = $testTable | sort age | FIELDS name, age | head 1 | APPENDCOL [sort age | FIELDS state ];
+                       | """.stripMargin)
+
+    assert(frame.columns.sameElements(Array("name", "age", "state")))
+    // Retrieve the results
+    val results: Array[Row] = frame.collect()
+    val expectedResults: Array[Row] =
+      Array(
+        Row("Jane", 20, "Quebec"),
+        Row(null, null, "Ontario"),
+        Row(null, null, "New York"),
+        Row(null, null, "California"))
+    // Compare the results
+    assert(results.sameElements(expectedResults))
+
+    // Retrieve the logical plan
+    val logicalPlan: LogicalPlan = frame.queryExecution.logical
+
+    /*
+    :- 'SubqueryAlias APPENDCOL_T1
+    :  +- 'Project [row_number() windowspecdefinition(1 DESC NULLS LAST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS _row_number_#225, *]
+    :     +- 'GlobalLimit 1
+    :        +- 'LocalLimit 1
+    :           +- 'Project ['name, 'age]
+    :              +- 'Sort ['age ASC NULLS FIRST], true
+    :                 +- 'UnresolvedRelation [flint_ppl_test], [], false
+     */
+    val t1 = SubqueryAlias(
+      "APPENDCOL_T1",
+      Project(
+        Seq(ROW_NUMBER_AGGREGATION, UnresolvedStar(None)),
+        Limit(
+          Literal(1, DataTypes.IntegerType),
+          Project(
+            Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")),
+            Sort(
+              SortUtils.sortOrder(UnresolvedAttribute("age"), true) :: Nil,
+              true,
+              RELATION_TEST_TABLE)))))
+
+    /*
+    +- 'SubqueryAlias APPENDCOL_T2
+       +- 'Project [row_number() windowspecdefinition(1 DESC NULLS LAST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS _row_number_#244, *]
+          +- 'Project ['state]
+             +- 'Sort ['age ASC NULLS FIRST], true
+                +- 'UnresolvedRelation [flint_ppl_test], [], false
+     */
+    val t2 = SubqueryAlias(
+      "APPENDCOL_T2",
+      Project(
+        Seq(ROW_NUMBER_AGGREGATION, UnresolvedStar(None)),
+        Project(
+          Seq(UnresolvedAttribute("state")),
+          Sort(
+            SortUtils.sortOrder(UnresolvedAttribute("age"), true) :: Nil,
+            true,
+            RELATION_TEST_TABLE))))
+
+    val expectedPlan = Project(
+      Seq(UnresolvedStar(None)),
+      DataFrameDropColumns(
+        T12_COLUMNS_SEQ,
+        Join(t1, t2, FullOuter, Some(T12_JOIN_CONDITION), JoinHint.NONE)))
+
+    comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
+  }
+
+  // @formatter:off
+  /**
+   * In the case that the main query returns more result rows than the sub-query, null will be used for padding;
+   * expected logical plan:
+   * 'Project [*]
+   * +- 'DataFrameDropColumns ['APPENDCOL_T1._row_number_, 'APPENDCOL_T2._row_number_]
+   *    +- 'Join FullOuter, ('APPENDCOL_T1._row_number_ = 'APPENDCOL_T2._row_number_)
+   *       :- 'SubqueryAlias APPENDCOL_T1
+   *       :  +- 'Project [row_number() windowspecdefinition(1 DESC NULLS LAST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS _row_number_#289, *]
+   *       :     +- 'Project ['name, 'age]
+   *       :        +- 'Sort ['age ASC NULLS FIRST], true
+   *       :           +- 'UnresolvedRelation [flint_ppl_test], [], false
+   *       +- 'SubqueryAlias APPENDCOL_T2
+   *          +- 'Project [row_number() windowspecdefinition(1 DESC NULLS LAST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS _row_number_#291, *]
+   *             +- 'GlobalLimit 1
+   *                +- 'LocalLimit 1
+   *                   +- 'Project ['state]
+   *                      +- 'Sort ['age ASC NULLS FIRST], true
+   *                         +- 'UnresolvedRelation [flint_ppl_test], [], false
+   *
+   */
+  // @formatter:on
+  test("test AppendCol with Null on sub-query") {
+    val frame = sql(s"""
+                       | source = $testTable | sort age | FIELDS name, age | APPENDCOL [sort age | FIELDS state | head 1 ];
+                       | """.stripMargin)
+
+    assert(frame.columns.sameElements(Array("name", "age", "state")))
+    // Retrieve the results
+    val results: Array[Row] = frame.collect()
+    val expectedResults: Array[Row] =
+      Array(
+        Row("Jane", 20, "Quebec"),
+        Row("John", 25, null),
+        Row("Hello", 30, null),
+        Row("Jake", 70, null))
+    // Compare the results
+
assert(results.sameElements(expectedResults)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + /* + * 'Project [*] + * +- 'DataFrameDropColumns ['APPENDCOL_T1._row_number_, 'APPENDCOL_T2._row_number_] + * +- 'Join FullOuter, ('APPENDCOL_T1._row_number_ = 'APPENDCOL_T2._row_number_) + * :- 'SubqueryAlias APPENDCOL_T1 + * : +- 'Project [row_number() windowspecdefinition(1 DESC NULLS LAST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS _row_number_#289, *] + * : +- 'Project ['name, 'age] + * : +- 'Sort ['age ASC NULLS FIRST], true + * : +- 'UnresolvedRelation [flint_ppl_test], [], false + */ + val t1 = SubqueryAlias( + "APPENDCOL_T1", + Project( + Seq(ROW_NUMBER_AGGREGATION, UnresolvedStar(None)), + Project( + Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")), + Sort( + SortUtils.sortOrder(UnresolvedAttribute("age"), true) :: Nil, + true, + RELATION_TEST_TABLE)))) + + /* + * +- 'SubqueryAlias APPENDCOL_T2 + * +- 'Project [row_number() windowspecdefinition(1 DESC NULLS LAST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS _row_number_#291, *] + * +- 'GlobalLimit 1 + * +- 'LocalLimit 1 + * +- 'Project ['state] + * +- 'Sort ['age ASC NULLS FIRST], true + * +- 'UnresolvedRelation [flint_ppl_test], [], false + */ + val t2 = SubqueryAlias( + "APPENDCOL_T2", + Project( + Seq(ROW_NUMBER_AGGREGATION, UnresolvedStar(None)), + Limit( + Literal(1, DataTypes.IntegerType), + Project( + Seq(UnresolvedAttribute("state")), + Sort( + SortUtils.sortOrder(UnresolvedAttribute("age"), true) :: Nil, + true, + RELATION_TEST_TABLE))))) + + val expectedPlan = Project( + Seq(UnresolvedStar(None)), + DataFrameDropColumns( + T12_COLUMNS_SEQ, + Join(t1, t2, FullOuter, Some(T12_JOIN_CONDITION), JoinHint.NONE))) + + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + // @formatter:off + /** + * 'Project [*] + * +- 'DataFrameDropColumns ['APPENDCOL_T1._row_number_, 'APPENDCOL_T2._row_number_] + * +- 'Join FullOuter, ('APPENDCOL_T1._row_number_ = 'APPENDCOL_T2._row_number_) + * :- 'SubqueryAlias APPENDCOL_T1 + * : +- 'Project [row_number() windowspecdefinition(1 DESC NULLS LAST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS _row_number_#977, *] + * : +- 'Project ['country, 'avg_age1] + * : +- 'Aggregate ['country AS country#975], ['AVG('age) AS avg_age1#974, 'country AS country#975] + * : +- 'UnresolvedRelation [testTable], [], false + * +- 'SubqueryAlias APPENDCOL_T2 + * +- 'Project [row_number() windowspecdefinition(1 DESC NULLS LAST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS _row_number_#981, *] + * +- 'Project ['avg_age2] + * +- 'Aggregate ['country AS country#979], ['AVG('age) AS avg_age2#978, 'country AS country#979] + * +- 'UnresolvedRelation [testTable], [], false + */ + // @formatter:on + test("test AppendCol with multiple stats commands") { + val frame = sql(s""" + | source = $testTable | stats avg(age) as avg_age1 by country | fields country, avg_age1 | appendcol [stats avg(age) as avg_age2 by country | fields avg_age2]; + | """.stripMargin) + + assert(frame.columns.sameElements(Array("country", "avg_age1", "avg_age2"))) + // Retrieve the results + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array(Row("USA", 50.0, 50.0), Row("Canada", 22.5, 22.5)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + 
assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + /* + * :- 'SubqueryAlias APPENDCOL_T1 + * : +- 'Project [row_number() windowspecdefinition(1 DESC NULLS LAST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS _row_number_#977, *] + * : +- 'Project ['country, 'avg_age1] + * : +- 'Aggregate ['country AS country#975], ['AVG('age) AS avg_age1#974, 'country AS country#975] + * : +- 'UnresolvedRelation [testTable], [], false + */ + val t1 = SubqueryAlias( + "APPENDCOL_T1", + Project( + Seq(ROW_NUMBER_AGGREGATION, UnresolvedStar(None)), + Project( + Seq(UnresolvedAttribute("country"), UnresolvedAttribute("avg_age1")), + Aggregate( + COUNTRY_ALIAS :: Nil, + Seq( + Alias( + UnresolvedFunction( + Seq("AVG"), + Seq(UnresolvedAttribute("age")), + isDistinct = false), + "avg_age1")(), + COUNTRY_ALIAS), + RELATION_TEST_TABLE)))) + + /* + * +- 'SubqueryAlias APPENDCOL_T2 + * +- 'Project [row_number() windowspecdefinition(1 DESC NULLS LAST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS _row_number_#981, *] + * +- 'Project ['avg_age2] + * +- 'Aggregate ['country AS country#979], ['AVG('age) AS avg_age2#978, 'country AS country#979] + * +- 'UnresolvedRelation [testTable], [], false + */ + val t2 = SubqueryAlias( + "APPENDCOL_T2", + Project( + Seq(ROW_NUMBER_AGGREGATION, UnresolvedStar(None)), + Project( + Seq(UnresolvedAttribute("avg_age2")), + Aggregate( + COUNTRY_ALIAS :: Nil, + Seq( + Alias( + UnresolvedFunction( + Seq("AVG"), + Seq(UnresolvedAttribute("age")), + isDistinct = false), + "avg_age2")(), + COUNTRY_ALIAS), + RELATION_TEST_TABLE)))) + + val expectedPlan = Project( + Seq(UnresolvedStar(None)), + DataFrameDropColumns( + T12_COLUMNS_SEQ, + Join(t1, t2, FullOuter, Some(T12_JOIN_CONDITION), JoinHint.NONE))) + + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + +} diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLParseITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLParseITSuite.scala index e69999a8e..5693f4df1 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLParseITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLParseITSuite.scala @@ -5,15 +5,12 @@ package org.opensearch.flint.spark.ppl -import scala.reflect.internal.Reporter.Count - -import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq - -import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Coalesce, Descending, GreaterThan, Literal, NullsFirst, NullsLast, RegExpExtract, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Cast, Descending, GreaterThan, Literal, NullsFirst, NullsLast, RegExpExtract, SortOrder} import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, GlobalLimit, LocalLimit, LogicalPlan, Project, Sort} import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.sql.types.IntegerType class FlintSparkPPLParseITSuite extends QueryTest @@ -214,10 +211,16 @@ class FlintSparkPPLParseITSuite comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) } - test("test parse email & host 
expressions including cast and sort commands") {
-    val frame = sql(s"""
-         | source = $testTable| parse street_address '(?<streetNumber>\\d+) (?<street>.+)' | where streetNumber > 500 | sort num(streetNumber) | fields streetNumber, street
-         | """.stripMargin)
+  test("test parse street number & street expressions including cast and sort commands") {
+
+    // TODO #963: Implement 'num', 'str', and 'ip' sort syntax
+    val query = s"source = $testTable | " +
+      "parse street_address '(?<streetNumber>\\d+) (?<street>.+)' | " +
+      "eval streetNumberInt = cast(streetNumber as integer) | " +
+      "where streetNumberInt > 500 | " +
+      "sort streetNumberInt | " +
+      "fields streetNumber, street"
+    val frame = sql(query)
     // Retrieve the results
     val results: Array[Row] = frame.collect()
     // Define the expected results
@@ -233,36 +236,36 @@ class FlintSparkPPLParseITSuite
     // Retrieve the logical plan
     val logicalPlan: LogicalPlan = frame.queryExecution.logical
-    val addressAttribute = UnresolvedAttribute("street_address")
+    val streetAddressAttribute = UnresolvedAttribute("street_address")
     val streetNumberAttribute = UnresolvedAttribute("streetNumber")
     val streetAttribute = UnresolvedAttribute("street")
+    val streetNumberIntAttribute = UnresolvedAttribute("streetNumberInt")
 
-    val streetNumberExpression = Alias(
-      RegExpExtract(
-        addressAttribute,
-        Literal("(?<streetNumber>\\d+) (?<street>.+)"),
-        Literal("1")),
-      "streetNumber")()
+    val regexLiteral = Literal("(?<streetNumber>\\d+) (?<street>.+)")
+    val streetNumberExpression =
+      Alias(RegExpExtract(streetAddressAttribute, regexLiteral, Literal("1")), "streetNumber")()
+    val streetExpression =
+      Alias(RegExpExtract(streetAddressAttribute, regexLiteral, Literal("2")), "street")()
 
-    val streetExpression = Alias(
-      RegExpExtract(
-        addressAttribute,
-        Literal("(?<streetNumber>\\d+) (?<street>.+)"),
-        Literal("2")),
-      "street")()
+    val castExpression = Cast(streetNumberAttribute, IntegerType)
 
     val expectedPlan = Project(
       Seq(streetNumberAttribute, streetAttribute),
       Sort(
-        Seq(SortOrder(streetNumberAttribute, Ascending, NullsFirst, Seq.empty)),
+        Seq(SortOrder(streetNumberIntAttribute, Ascending, NullsFirst, Seq.empty)),
         global = true,
         Filter(
-          GreaterThan(streetNumberAttribute, Literal(500)),
+          GreaterThan(streetNumberIntAttribute, Literal(500)),
          Project(
-            Seq(addressAttribute, streetNumberExpression, streetExpression, UnresolvedStar(None)),
-            UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))))))
+            Seq(UnresolvedStar(None), Alias(castExpression, "streetNumberInt")()),
+            Project(
+              Seq(
+                streetAddressAttribute,
+                streetNumberExpression,
+                streetExpression,
+                UnresolvedStar(None)),
+              UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")))))))
 
     assert(compareByString(expectedPlan) === compareByString(logicalPlan))
   }
-
 }
diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4
index 8b762dffa..1cb5e90d8 100644
--- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4
+++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4
@@ -43,6 +43,7 @@ FILLNULL: 'FILLNULL';
 EXPAND: 'EXPAND';
 FLATTEN: 'FLATTEN';
 TRENDLINE: 'TRENDLINE';
+APPENDCOL: 'APPENDCOL';
 
 //Native JOIN KEYWORDS
 JOIN: 'JOIN';
@@ -87,14 +88,14 @@ WITH: 'WITH';
 PARQUET: 'PARQUET';
 CSV: 'CSV';
 TEXT: 'TEXT';
-
-// FIELD KEYWORDS
+
+// SORT FIELD KEYWORDS
+// TODO #963: Implement 'num', 'str', and 'ip' sort syntax
 AUTO: 'AUTO';
 STR: 'STR';
 IP: 'IP';
 NUM: 'NUM';
-
 // FIELDSUMMARY keywords
 FIELDSUMMARY: 'FIELDSUMMARY';
 INCLUDEFIELDS: 'INCLUDEFIELDS';
@@ -104,6 +105,9 @@ NULLS: 'NULLS';
 SMA: 'SMA';
 WMA: 'WMA';
 
+// APPENDCOL options
+OVERRIDE: 'OVERRIDE';
+
 // ARGUMENT KEYWORDS
 KEEPEMPTY: 'KEEPEMPTY';
 CONSECUTIVE: 'CONSECUTIVE';
diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4
index c4e30f0d3..3a2694c1a 100644
--- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4
+++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4
@@ -56,6 +56,7 @@ commands
    | flattenCommand
    | expandCommand
    | trendlineCommand
+   | appendcolCommand
    ;
 
 commandName
@@ -90,6 +91,7 @@ commandName
    | FIELDSUMMARY
    | FLATTEN
    | TRENDLINE
+   | APPENDCOL
    ;
 
 searchCommand
@@ -298,6 +300,10 @@ trendlineType
    | WMA
    ;
 
+appendcolCommand
+   : APPENDCOL (OVERRIDE EQUAL override = booleanLiteral)? LT_SQR_PRTHS commands (PIPE commands)* RT_SQR_PRTHS
+   ;
+
 kmeansCommand
    : KMEANS (kmeansParameter)*
    ;
@@ -552,6 +558,8 @@ sortField
 
 sortFieldExpression
    : fieldExpression
+
+   // TODO #963: Implement 'num', 'str', and 'ip' sort syntax
    | AUTO LT_PRTHS fieldExpression RT_PRTHS
    | STR LT_PRTHS fieldExpression RT_PRTHS
    | IP LT_PRTHS fieldExpression RT_PRTHS
@@ -1132,10 +1140,6 @@ keywordsCanBeId
    | INDEX
    | DESC
    | DATASOURCES
-   | AUTO
-   | STR
-   | IP
-   | NUM
    | FROM
    | PATTERN
    | NEW_FIELD
@@ -1219,4 +1223,9 @@ keywordsCanBeId
    | BETWEEN
    | CIDRMATCH
    | trendlineType
+   // SORT FIELD KEYWORDS
+   | AUTO
+   | STR
+   | IP
+   | NUM
    ;
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java
index 31841430c..7a15b47bc 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java
@@ -122,6 +122,10 @@ public T visitTrendline(Trendline node, C context) {
     return visitChildren(node, context);
   }
 
+  public T visitAppendCol(AppendCol node, C context) {
+    return visitChildren(node, context);
+  }
+
   public T visitCorrelation(Correlation node, C context) {
     return visitChildren(node, context);
   }
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/AppendCol.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/AppendCol.java
new file mode 100644
index 000000000..421821739
--- /dev/null
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/AppendCol.java
@@ -0,0 +1,55 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+package org.opensearch.sql.ast.tree;
+
+import com.google.common.collect.ImmutableList;
+import lombok.AllArgsConstructor;
+import lombok.EqualsAndHashCode;
+import lombok.Getter;
+import lombok.RequiredArgsConstructor;
+import lombok.Setter;
+import lombok.ToString;
+import org.opensearch.sql.ast.AbstractNodeVisitor;
+import org.opensearch.sql.ast.Node;
+
+import java.util.List;
+
+/**
+ * A composite object which stores the sub-query along with additional ad-hoc options such as override.
+ */
+@Getter
+@Setter
+@ToString
+@EqualsAndHashCode(callSuper = false)
+@AllArgsConstructor
+public class AppendCol extends UnresolvedPlan {
+
+    private final boolean override;
+
+    private final UnresolvedPlan subSearch;
+
+    private UnresolvedPlan child;
+
+    public AppendCol(UnresolvedPlan subSearch, boolean override) {
+        this.override = override;
+        this.subSearch = subSearch;
+    }
+
+    @Override
+    public UnresolvedPlan attach(UnresolvedPlan child) {
+        this.child = child;
+        return this;
+    }
+
+    @Override
+    public List<? extends Node> getChild() {
+        return ImmutableList.of(child);
+    }
+
+    @Override
+    public <T, C> T accept(AbstractNodeVisitor<T, C> visitor, C context) {
+        return visitor.visitAppendCol(this, context);
+    }
+}
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java
index 22beab605..8d2248e73 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java
@@ -18,6 +18,7 @@
 import org.apache.spark.sql.catalyst.expressions.SortDirection;
 import org.apache.spark.sql.catalyst.expressions.SortOrder;
 import org.apache.spark.sql.catalyst.plans.logical.Aggregate;
+import org.apache.spark.sql.catalyst.plans.logical.DataFrameDropColumns;
 import org.apache.spark.sql.catalyst.plans.logical.DataFrameDropColumns$;
 import org.apache.spark.sql.catalyst.plans.logical.DescribeRelation$;
 import org.apache.spark.sql.catalyst.plans.logical.Generate;
@@ -48,6 +49,7 @@
 import org.opensearch.sql.ast.statement.Query;
 import org.opensearch.sql.ast.statement.Statement;
 import org.opensearch.sql.ast.tree.Aggregation;
+import org.opensearch.sql.ast.tree.AppendCol;
 import org.opensearch.sql.ast.tree.Correlation;
 import org.opensearch.sql.ast.tree.CountedAggregation;
 import org.opensearch.sql.ast.tree.Dedupe;
@@ -74,6 +76,7 @@
 import org.opensearch.sql.ast.tree.UnresolvedPlan;
 import org.opensearch.sql.ast.tree.Window;
 import org.opensearch.sql.common.antlr.SyntaxCheckException;
+import org.opensearch.sql.ppl.utils.AppendColCatalystUtils;
 import org.opensearch.sql.ppl.utils.FieldSummaryTransformer;
 import org.opensearch.sql.ppl.utils.GeoIpCatalystLogicalPlanTranslator;
 import org.opensearch.sql.ppl.utils.ParseTransformer;
@@ -94,6 +97,15 @@
 import static java.util.Collections.emptyList;
 import static java.util.List.of;
+import static org.opensearch.sql.ppl.utils.AppendColCatalystUtils.TABLE_LHS;
+import static org.opensearch.sql.ppl.utils.AppendColCatalystUtils.TABLE_RHS;
+import static org.opensearch.sql.ppl.utils.AppendColCatalystUtils.appendRelationClause;
+import static org.opensearch.sql.ppl.utils.AppendColCatalystUtils.combineQueriesWithJoin;
+import static org.opensearch.sql.ppl.utils.AppendColCatalystUtils.getOverridedList;
+import static org.opensearch.sql.ppl.utils.AppendColCatalystUtils.getRowNumStarProjection;
+import static org.opensearch.sql.ppl.utils.AppendColCatalystUtils.isValidOverrideList;
+import static org.opensearch.sql.ppl.utils.AppendColCatalystUtils.t1Attr;
+import static org.opensearch.sql.ppl.utils.AppendColCatalystUtils.t2Attr;
 import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq;
 import static org.opensearch.sql.ppl.utils.DedupeTransformer.retainMultipleDuplicateEvents;
 import static org.opensearch.sql.ppl.utils.DedupeTransformer.retainMultipleDuplicateEventsAndKeepEmpty;
@@ -264,6 +276,46 @@ public LogicalPlan visitTrendline(Trendline node, CatalystPlanContext context) {
     return context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Project(seq(trendlineProjectExpressions), p));
   }
 
+  @Override
+  public LogicalPlan visitAppendCol(AppendCol node, CatalystPlanContext context) {
+
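+    // The translation below produces a plan of roughly the following shape (compare the
+    // expected plans in the AppendCol translator test suite; attribute numbering omitted):
+    //   'Project [*]
+    //   +- 'DataFrameDropColumns ['T1._row_number_, 'T2._row_number_ (+ overridden T1 fields)]
+    //      +- 'Join FullOuter, ('T1._row_number_ = 'T2._row_number_)
+    //         :- 'SubqueryAlias APPENDCOL_T1   <- main search wrapped with row_number()
+    //         +- 'SubqueryAlias APPENDCOL_T2   <- sub-search wrapped with row_number()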
+    // Apply an additional projection layer on the main search to provide natural ordering.
+    LogicalPlan mainSearch = visitFirstChild(node, context);
+    var mainSearchWithRowNumber = getRowNumStarProjection(context, mainSearch, TABLE_LHS);
+    context.withSubqueryAlias(mainSearchWithRowNumber);
+
+    // Duplicate the relation clause from the main search onto the sub-search.
+    final Node subSearchNode = appendRelationClause(node.getSubSearch(), context.getRelations());
+
+    context.apply(left -> {
+
+      // Apply an additional projection layer on the sub-search to provide natural ordering.
+      LogicalPlan subSearch = subSearchNode.accept(this, context);
+      var subSearchWithRowNumber = getRowNumStarProjection(context, subSearch, TABLE_RHS);
+
+      context.withSubqueryAlias(subSearchWithRowNumber);
+      context.retainAllNamedParseExpressions(p -> p);
+      context.retainAllPlans(p -> p);
+
+      // Join the main and sub searches on the _row_number_ column, then drop the helper columns.
+      List<Expression> fieldsToRemove = new ArrayList<>(List.of(t1Attr, t2Attr));
+      // Additionally remove the duplicated fields on T1 when the override option is present.
+      if (node.isOverride()) {
+        final List<Expression> attrToOverride = getOverridedList(subSearch, TABLE_LHS);
+        if (isValidOverrideList(attrToOverride)) {
+          fieldsToRemove.addAll(attrToOverride);
+        } else {
+          throw new IllegalStateException("Not Supported operation: " +
+              "APPENDCOL should specify the output fields");
+        }
+      }
+      return new DataFrameDropColumns(seq(fieldsToRemove),
+          combineQueriesWithJoin(mainSearchWithRowNumber, subSearchWithRowNumber));
+    });
+    return context.getPlan();
+  }
+
+
   @Override
   public LogicalPlan visitCorrelation(Correlation node, CatalystPlanContext context) {
     visitFirstChild(node, context);
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java
index bfc45f50e..37f41bc58 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java
@@ -428,6 +428,17 @@ public UnresolvedPlan visitTrendlineCommand(OpenSearchPPLParser.TrendlineCommand
         .orElse(new Trendline(Optional.empty(), trendlineComputations));
   }
 
+  @Override
+  public UnresolvedPlan visitAppendcolCommand(OpenSearchPPLParser.AppendcolCommandContext ctx) {
+    final Optional<UnresolvedPlan> pplCmd = ctx.commands().stream()
+        .map(this::visit)
+        .reduce((r, e) -> e.attach(r));
+    final boolean override = (ctx.override != null &&
+        Boolean.parseBoolean(ctx.override.getText()));
+    // The ANTLR parser checks guarantee that pplCmd is present (never null).
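+    // Note: the reduce above folds the bracketed commands left to right, attaching each
+    // visited command to the previous result. For "[ stats count() by age | eval m = 1 ]"
+    // this yields eval(stats(...)), preserving the pipeline order of the sub-search.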
+ return new AppendCol(pplCmd.get(), override); + } + private Trendline.TrendlineComputation toTrendlineComputation(OpenSearchPPLParser.TrendlineClauseContext ctx) { int numberOfDataPoints = Integer.parseInt(ctx.numberOfDataPoints.getText()); if (numberOfDataPoints < 1) { diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index da1fa40aa..45253aec0 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -47,8 +47,6 @@ import org.opensearch.sql.ast.expression.subquery.ExistsSubquery; import org.opensearch.sql.ast.expression.subquery.InSubquery; import org.opensearch.sql.ast.expression.subquery.ScalarSubquery; -import org.opensearch.sql.ast.tree.Trendline; -import org.opensearch.sql.common.antlr.SyntaxCheckException; import org.opensearch.sql.common.utils.StringUtils; import org.opensearch.sql.ppl.utils.ArgumentFactory; import org.opensearch.sql.ppl.utils.GeoIpCatalystLogicalPlanTranslator; @@ -188,6 +186,8 @@ public UnresolvedExpression visitWcFieldExpression(OpenSearchPPLParser.WcFieldEx @Override public UnresolvedExpression visitSortField(OpenSearchPPLParser.SortFieldContext ctx) { + + // TODO #963: Implement 'num', 'str', and 'ip' sort syntax return new Field((QualifiedName) visit(ctx.sortFieldExpression().fieldExpression().qualifiedName()), ArgumentFactory.getArgumentList(ctx)); diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AppendColCatalystUtils.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AppendColCatalystUtils.java new file mode 100644 index 000000000..a8373867c --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AppendColCatalystUtils.java @@ -0,0 +1,145 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.sql.ppl.utils; + +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute; +import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation; +import org.apache.spark.sql.catalyst.analysis.UnresolvedStar; +import org.apache.spark.sql.catalyst.expressions.EqualTo; +import org.apache.spark.sql.catalyst.expressions.Expression; +import org.apache.spark.sql.catalyst.expressions.Literal; +import org.apache.spark.sql.catalyst.expressions.NamedExpression; +import org.apache.spark.sql.catalyst.expressions.SortOrder; +import org.apache.spark.sql.catalyst.plans.logical.Aggregate; +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; +import org.apache.spark.sql.catalyst.plans.logical.Project; +import org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias; +import org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias$; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.unsafe.types.UTF8String; +import org.opensearch.sql.ast.Node; +import org.opensearch.sql.ast.expression.QualifiedName; +import org.opensearch.sql.ast.expression.UnresolvedExpression; +import org.opensearch.sql.ast.tree.Join; +import org.opensearch.sql.ast.tree.Relation; +import org.opensearch.sql.ast.tree.UnresolvedPlan; +import org.opensearch.sql.ppl.CatalystPlanContext; +import scala.Option; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Optional; +import 
java.util.stream.Collectors;
+
+import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq;
+import static org.opensearch.sql.ppl.utils.JoinSpecTransformer.join;
+import static scala.collection.JavaConverters.seqAsJavaList;
+
+/**
+ * Util class to facilitate logical plan composition for the APPENDCOL command.
+ */
+public interface AppendColCatalystUtils {
+
+    String TABLE_LHS = "APPENDCOL_T1";
+    String TABLE_RHS = "APPENDCOL_T2";
+    UnresolvedAttribute t1Attr = new UnresolvedAttribute(seq(TABLE_LHS, WindowSpecTransformer.ROW_NUMBER_COLUMN_NAME));
+    UnresolvedAttribute t2Attr = new UnresolvedAttribute(seq(TABLE_RHS, WindowSpecTransformer.ROW_NUMBER_COLUMN_NAME));
+
+
+    /**
+     * Traverses the given subSearch node down to its last child, then appends the Relation clause
+     * in order to specify the data source (index) that the sub-search runs against.
+     * @param subSearch User-provided sub-search from the APPENDCOL command.
+     * @param relation Relation clause which represents the data source that this sub-search executes upon.
+     */
+    static Node appendRelationClause(Node subSearch, List<LogicalPlan> relation) {
+        final List<UnresolvedExpression> unresolvedExpressionList = relation.stream()
+                .map(r -> {
+                    UnresolvedRelation unresolvedRelation = (UnresolvedRelation) r;
+                    List<String> multipartId = seqAsJavaList(unresolvedRelation.multipartIdentifier());
+                    return (UnresolvedExpression) new QualifiedName(multipartId);
+                })
+                // To avoid stack overflow in the case of chained AppendCol commands.
+                .distinct()
+                .collect(Collectors.toList());
+        final Relation table = new Relation(unresolvedExpressionList);
+        final Node head = subSearch;
+        while (subSearch != null) {
+            try {
+                subSearch = subSearch.getChild().get(0);
+            } catch (NullPointerException ex) {
+                ((UnresolvedPlan) subSearch).attach(table);
+                break;
+            }
+        }
+        return head;
+    }
+
+
+    /**
+     * Util method to extract output fields from the given LogicalPlan instance in a non-recursive manner,
+     * returning null for unsupported LogicalPlan types.
+     * @param lp LogicalPlan instance to extract the projection fields from.
+     * @param tableName the table (schema) name prepended to the returned fields.
+     * @return A list of Expression instances qualified with the given tableName (schema) information.
+     */
+    static List<Expression> getOverridedList(LogicalPlan lp, String tableName) {
+        // Extract the output from supported LogicalPlan types.
+        if (lp instanceof Project || lp instanceof Aggregate) {
+            return seqAsJavaList(lp.output()).stream()
+                    .map(attr -> (Expression) new UnresolvedAttribute(seq(tableName, attr.name())))
+                    .collect(Collectors.toList());
+        }
+        return null;
+    }
+
+    /**
+     * Performs checks against the given list of expressions to override.
+     * @param attrToOverride List of Expression instances to be checked.
+     * @return boolean value indicating whether the incoming list is suitable for the DataFrameDropColumns action.
+     */
+    static boolean isValidOverrideList(List<Expression> attrToOverride) {
+        return attrToOverride != null &&
+                !attrToOverride.isEmpty() &&
+                attrToOverride.stream().noneMatch(UnresolvedStar.class::isInstance);
+    }
+
+    /**
+     * Helper method to first add an additional projection clause providing row_number,
+     * then wrap it in a SubqueryAlias and return it.
+     * @param context Context object of the current parser.
+     * @param lp The LogicalPlan instance which contains the query.
+     * @param alias The name of the alias clause.
+     * @return A SubqueryAlias instance which carries row_number for natural-ordering purposes.
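+     *
+     * Note: the row_number() window below sorts on the constant literal "1", so the
+     * generated numbers simply follow the order in which rows arrive, i.e. the natural
+     * order of the underlying search results (which is why the tests treat row order
+     * as non-deterministic).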
+     */
+    static SubqueryAlias getRowNumStarProjection(CatalystPlanContext context, LogicalPlan lp, String alias) {
+        final SortOrder sortOrder = SortUtils.sortOrder(
+                new Literal(
+                        UTF8String.fromString("1"), DataTypes.StringType), false);
+
+        final NamedExpression appendCol = WindowSpecTransformer.buildRowNumber(seq(), seq(sortOrder));
+        final List<NamedExpression> projectList = (context.getNamedParseExpressions().isEmpty())
+                ? List.of(appendCol, new UnresolvedStar(Option.empty()))
+                : List.of(appendCol);
+
+        final LogicalPlan lpWithProjection = new Project(seq(
+                projectList), lp);
+        return SubqueryAlias$.MODULE$.apply(alias, lpWithProjection);
+    }
+
+    /**
+     * Util method to return a joined logical plan built from the given SubqueryAlias(es).
+     * @param lhs Left-hand side query (main query) for the AppendCol logical plan.
+     * @param rhs Right-hand side query (sub-query) for the AppendCol logical plan.
+     * @return A joined logical plan which combines the given SubqueryAlias(es).
+     */
+    static LogicalPlan combineQueriesWithJoin(SubqueryAlias lhs, SubqueryAlias rhs) {
+        return join(
+                lhs, rhs,
+                Join.JoinType.FULL,
+                Optional.of(new EqualTo(t1Attr, t2Attr)),
+                new Join.JoinHint());
+    }
+}
diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAppendColCommandTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAppendColCommandTranslatorTestSuite.scala
new file mode 100644
index 000000000..747d9a305
--- /dev/null
+++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAppendColCommandTranslatorTestSuite.scala
@@ -0,0 +1,410 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.flint.spark.ppl
+
+import org.opensearch.flint.spark.ppl.PlaneUtils.plan
+import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor}
+import org.opensearch.sql.ppl.utils.SortUtils
+import org.scalatest.matchers.should.Matchers
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
+import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentRow, EqualTo, Literal, RowFrame, RowNumber, SpecifiedWindowFrame, UnboundedPreceding, WindowExpression, WindowSpecDefinition}
+import org.apache.spark.sql.catalyst.plans.{FullOuter, PlanTest}
+import org.apache.spark.sql.catalyst.plans.logical._
+
+class PPLLogicalPlanAppendColCommandTranslatorTestSuite
+    extends SparkFunSuite
+    with PlanTest
+    with LogicalPlanTestUtils
+    with Matchers {
+
+  private val planTransformer = new CatalystQueryPlanVisitor()
+  private val pplParser = new PPLSyntaxParser()
+
+  private val ROW_NUMBER_AGGREGATION = Alias(
+    WindowExpression(
+      RowNumber(),
+      WindowSpecDefinition(
+        Nil,
+        SortUtils.sortOrder(Literal("1"), false) :: Nil,
+        SpecifiedWindowFrame(RowFrame, UnboundedPreceding, CurrentRow))),
+    "_row_number_")()
+
+  private val COUNT_STAR = Alias(
+    UnresolvedFunction(Seq("COUNT"), Seq(UnresolvedStar(None)), isDistinct = false),
+    "count()")()
+
+  private val AGE_ALIAS = Alias(UnresolvedAttribute("age"), "age")()
+
+  private val RELATION_EMPLOYEES = UnresolvedRelation(Seq("employees"))
+
+  private val T12_JOIN_CONDITION =
+    EqualTo(
+      UnresolvedAttribute("APPENDCOL_T1._row_number_"),
+      UnresolvedAttribute("APPENDCOL_T2._row_number_"))
+
+  private val T12_COLUMNS_SEQ =
+    Seq(
+      UnresolvedAttribute("APPENDCOL_T1._row_number_"),
+
UnresolvedAttribute("APPENDCOL_T2._row_number_")) + + // @formatter:off + /** + * Expected: + 'Project [*] + +- 'DataFrameDropColumns ['T1._row_number_, 'T2._row_number_] + +- 'Join FullOuter, ('T1._row_number_ = 'T2._row_number_) + :- 'SubqueryAlias T1 + : +- 'Project [row_number() windowspecdefinition(1 DESC NULLS LAST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS _row_number_#1, *] + : +- 'UnresolvedRelation [employees], [], false + +- 'SubqueryAlias T2 + +- 'Project [row_number() windowspecdefinition(1 DESC NULLS LAST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS _row_number_#5, *] + +- 'Aggregate ['age AS age#3], ['COUNT(*) AS count()#2, 'age AS age#3] + +- 'UnresolvedRelation [employees], [], false + */ + // @formatter:on + test("test AppendCol with NO transformation on main") { + val context = new CatalystPlanContext + val logicalPlan = planTransformer.visit( + plan(pplParser, "source=employees | APPENDCOL [stats count() by age];"), + context) + + /* + :- 'SubqueryAlias T1 + : +- 'Project [row_number() windowspecdefinition(1 DESC NULLS LAST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS _row_number_#7, *] + : +- 'UnresolvedRelation [relation], [], false + */ + val t1 = SubqueryAlias( + "APPENDCOL_T1", + Project(Seq(ROW_NUMBER_AGGREGATION, UnresolvedStar(None)), RELATION_EMPLOYEES)) + + /* + +- 'SubqueryAlias T2 + +- 'Project [row_number() windowspecdefinition(1 DESC NULLS LAST, + specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS _row_number_#11, *] + +- 'Aggregate ['age AS age#9], ['COUNT(*) AS count()#8, 'age AS age#10] + +- 'UnresolvedRelation [relation], [], false + */ + val t2 = SubqueryAlias( + "APPENDCOL_T2", + Project( + Seq(ROW_NUMBER_AGGREGATION, UnresolvedStar(None)), + Aggregate(AGE_ALIAS :: Nil, Seq(COUNT_STAR, AGE_ALIAS), RELATION_EMPLOYEES))) + + val expectedPlan = Project( + Seq(UnresolvedStar(None)), + DataFrameDropColumns( + T12_COLUMNS_SEQ, + Join(t1, t2, FullOuter, Some(T12_JOIN_CONDITION), JoinHint.NONE))) + + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + // @formatter:off + /** + * 'Project [*] + * +- 'DataFrameDropColumns ['T1._row_number_, 'T2._row_number_] + * +- 'Join FullOuter, ('T1._row_number_ = 'T2._row_number_) + * :- 'SubqueryAlias T1 + * : +- 'Project [row_number() windowspecdefinition(1 DESC NULLS LAST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS _row_number_#11, *] + * : +- 'Project ['age, 'dept, 'salary] + * : +- 'UnresolvedRelation [relation], [], false + * +- 'SubqueryAlias T2 + * +- 'Project [row_number() windowspecdefinition(1 DESC NULLS LAST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS _row_number_#15, *] + * +- 'Aggregate ['age AS age#13], ['COUNT(*) AS count()#12, 'age AS age#13] + * +- 'UnresolvedRelation [relation], [], false + */ + // @formatter:on + test("test AppendCol with transformation on main-search") { + val context = new CatalystPlanContext + val logicalPlan = planTransformer.visit( + plan( + pplParser, + "source=employees | FIELDS age, dept, salary | APPENDCOL [stats count() by age];"), + context) + + /* + :- 'SubqueryAlias T1 + : +- 'Project [row_number() windowspecdefinition(1 DESC NULLS LAST, + specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS _row_number_#11, *] + : +- 'Project ['age, 'dept, 'salary] + : +- 'UnresolvedRelation [relation], [], false + */ + val t1 = SubqueryAlias( + "APPENDCOL_T1", + Project( + 
Seq(ROW_NUMBER_AGGREGATION, UnresolvedStar(None)), + Project( + Seq( + UnresolvedAttribute("age"), + UnresolvedAttribute("dept"), + UnresolvedAttribute("salary")), + RELATION_EMPLOYEES))) + + /* + +- 'SubqueryAlias T2 + +- 'Project [row_number() windowspecdefinition(1 DESC NULLS LAST, + specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS _row_number_#11, *] + +- 'Aggregate ['age AS age#9], ['COUNT(*) AS count()#8, 'age AS age#10] + +- 'UnresolvedRelation [relation], [], false + */ + val t2 = SubqueryAlias( + "APPENDCOL_T2", + Project( + Seq(ROW_NUMBER_AGGREGATION, UnresolvedStar(None)), + Aggregate(AGE_ALIAS :: Nil, Seq(COUNT_STAR, AGE_ALIAS), RELATION_EMPLOYEES))) + + val expectedPlan = Project( + Seq(UnresolvedStar(None)), + DataFrameDropColumns( + T12_COLUMNS_SEQ, + Join(t1, t2, FullOuter, Some(T12_JOIN_CONDITION), JoinHint.NONE))) + + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + // @formatter:off + /** + * 'Project [*] + * +- 'DataFrameDropColumns ['T1._row_number_, 'T2._row_number_] + * +- 'Join FullOuter, ('T1._row_number_ = 'T2._row_number_) + * :- 'SubqueryAlias T1 + * : +- 'Project [row_number() windowspecdefinition(1 DESC NULLS LAST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS _row_number_#427, *] + * : +- 'Project ['age, 'dept, 'salary] + * : +- 'UnresolvedRelation [employees], [], false + * +- 'SubqueryAlias T2 + * +- 'Project [row_number() windowspecdefinition(1 DESC NULLS LAST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS _row_number_#432, *] + * +- 'DataFrameDropColumns ['m] + * +- 'Project [*, 1 AS m#430] + * +- 'Aggregate ['age AS age#429], ['COUNT(*) AS count()#428, 'age AS age#429] + * +- 'UnresolvedRelation [employees], [], false + */ + // @formatter:on + test("test AppendCol with chained sub-search") { + val context = new CatalystPlanContext + val logicalPlan = planTransformer.visit( + plan( + pplParser, + "source=employees | FIELDS age, dept, salary | APPENDCOL [ stats count() by age | eval m = 1 | FIELDS -m ];"), + context) + + /* + :- 'SubqueryAlias T1 + : +- 'Project [row_number() windowspecdefinition(1 DESC NULLS LAST, + specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS _row_number_#11, *] + : +- 'Project ['age, 'dept, 'salary] + : +- 'UnresolvedRelation [relation], [], false + */ + val t1 = SubqueryAlias( + "APPENDCOL_T1", + Project( + Seq(ROW_NUMBER_AGGREGATION, UnresolvedStar(None)), + Project( + Seq( + UnresolvedAttribute("age"), + UnresolvedAttribute("dept"), + UnresolvedAttribute("salary")), + RELATION_EMPLOYEES))) + + /* + +- 'SubqueryAlias T2 + +- 'Project [row_number() windowspecdefinition(1 DESC NULLS LAST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS _row_number_#432, *] + +- 'DataFrameDropColumns ['m] + +- 'Project [*, 1 AS m#430] + +- 'Aggregate ['age AS age#429], ['COUNT(*) AS count()#428, 'age AS age#429] + +- 'UnresolvedRelation [employees], [], false + */ + val t2 = SubqueryAlias( + "APPENDCOL_T2", + Project( + Seq(ROW_NUMBER_AGGREGATION, UnresolvedStar(None)), + DataFrameDropColumns( + Seq(UnresolvedAttribute("m")), + Project( + Seq(UnresolvedStar(None), Alias(Literal(1), "m")()), + Aggregate(AGE_ALIAS :: Nil, Seq(COUNT_STAR, AGE_ALIAS), RELATION_EMPLOYEES))))) + + val expectedPlan = Project( + Seq(UnresolvedStar(None)), + DataFrameDropColumns( + T12_COLUMNS_SEQ, + Join(t1, t2, FullOuter, Some(T12_JOIN_CONDITION), JoinHint.NONE))) + + comparePlans(logicalPlan, expectedPlan, checkAnalysis = 
false)
+  }
+
+  // @formatter:off
+  /**
+   * == Parsed Logical Plan ==
+   * 'Project [*]
+   * +- 'DataFrameDropColumns ['T1._row_number_, 'T2._row_number_]
+   *    +- 'Join FullOuter, ('T1._row_number_ = 'T2._row_number_)
+   *       :- 'SubqueryAlias T1
+   *       :  +- 'Project [row_number() windowspecdefinition(1 DESC NULLS LAST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS _row_number_#551, *]
+   *       :     +- 'DataFrameDropColumns ['T1._row_number_, 'T2._row_number_]
+   *       :        +- 'Join FullOuter, ('T1._row_number_ = 'T2._row_number_)
+   *       :           :- 'SubqueryAlias T1
+   *       :           :  +- 'Project [row_number() windowspecdefinition(1 DESC NULLS LAST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS _row_number_#544, *]
+   *       :           :     +- 'Project ['name, 'age]
+   *       :           :        +- 'UnresolvedRelation [employees], [], false
+   *       :           +- 'SubqueryAlias T2
+   *       :              +- 'Project [row_number() windowspecdefinition(1 DESC NULLS LAST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS _row_number_#549, *]
+   *       :                 +- 'DataFrameDropColumns ['m]
+   *       :                    +- 'Project [*, 1 AS m#547]
+   *       :                       +- 'Aggregate ['age AS age#546], ['COUNT(*) AS count()#545, 'age AS age#546]
+   *       :                          +- 'UnresolvedRelation [employees], [], false
+   *       +- 'SubqueryAlias T2
+   *          +- 'Project [row_number() windowspecdefinition(1 DESC NULLS LAST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS _row_number_#553, *]
+   *             +- 'Project ['dept]
+   *                +- 'UnresolvedRelation [employees], [], false
+   */
+  // @formatter:on
+  test("test multiple AppendCol clauses") {
+    val context = new CatalystPlanContext
+    val logicalPlan = planTransformer.visit(
+      plan(
+        pplParser,
+        "source=employees | FIELDS name, age | APPENDCOL [ stats count() by age | eval m = 1 | FIELDS -m ] | APPENDCOL [FIELDS dept];"),
+      context)
+
+    /*
+     :- 'SubqueryAlias T1
+     :  +- 'Project [row_number() windowspecdefinition(1 DESC NULLS LAST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS _row_number_#544, *]
+     :     +- 'Project ['name, 'age]
+     :        +- 'UnresolvedRelation [employees], [], false
+     */
+    val mainSearch = SubqueryAlias(
+      "APPENDCOL_T1",
+      Project(
+        Seq(ROW_NUMBER_AGGREGATION, UnresolvedStar(None)),
+        Project(
+          Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")),
+          RELATION_EMPLOYEES)))
+
+    /*
+     +- 'SubqueryAlias T2
+        +- 'Project [row_number() windowspecdefinition(1 DESC NULLS LAST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS _row_number_#432, *]
+           +- 'DataFrameDropColumns ['m]
+              +- 'Project [*, 1 AS m#430]
+                 +- 'Aggregate ['age AS age#429], ['COUNT(*) AS count()#428, 'age AS age#429]
+                    +- 'UnresolvedRelation [employees], [], false
+     */
+    val firstAppendCol = SubqueryAlias(
+      "APPENDCOL_T2",
+      Project(
+        Seq(ROW_NUMBER_AGGREGATION, UnresolvedStar(None)),
+        DataFrameDropColumns(
+          Seq(UnresolvedAttribute("m")),
+          Project(
+            Seq(UnresolvedStar(None), Alias(Literal(1), "m")()),
+            Aggregate(AGE_ALIAS :: Nil, Seq(COUNT_STAR, AGE_ALIAS), RELATION_EMPLOYEES)))))
+
+    val joinWithFirstAppendCol = SubqueryAlias(
+      "APPENDCOL_T1",
+      Project(
+        Seq(ROW_NUMBER_AGGREGATION, UnresolvedStar(None)),
+        DataFrameDropColumns(
+          T12_COLUMNS_SEQ,
+          Join(mainSearch, firstAppendCol, FullOuter, Some(T12_JOIN_CONDITION), JoinHint.NONE))))
+
+    /*
+     +- 'SubqueryAlias T2
+        +- 'Project [row_number() windowspecdefinition(1 DESC NULLS LAST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS _row_number_#553, *]
+           +- 'Project ['dept]
+              +- 'UnresolvedRelation [employees], [], false
+     */
+    val secondAppendCol =
SubqueryAlias( + "APPENDCOL_T2", + Project( + Seq(ROW_NUMBER_AGGREGATION, UnresolvedStar(None)), + Project(Seq(UnresolvedAttribute("dept")), RELATION_EMPLOYEES))) + + val expectedPlan = Project( + Seq(UnresolvedStar(None)), + DataFrameDropColumns( + T12_COLUMNS_SEQ, + Join( + joinWithFirstAppendCol, + secondAppendCol, + FullOuter, + Some(T12_JOIN_CONDITION), + JoinHint.NONE))) + + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test invalid override sub-search") { + val context = new CatalystPlanContext + val exception = intercept[IllegalStateException]( + planTransformer + .visit( + plan( + pplParser, + "source=relation | FIELDS name, age | APPENDCOL override=true [ where age > 10]"), + context)) + assert(exception.getMessage startsWith "Not Supported operation") + } + + + // @formatter:off + /** + * 'Project [*] + * +- 'DataFrameDropColumns ['T1._row_number_, 'T2._row_number_, 'T1.age] + * +- 'Join FullOuter, ('T1._row_number_ = 'T2._row_number_) + * :- 'SubqueryAlias T1 + * : +- 'Project [row_number() windowspecdefinition(1 DESC NULLS LAST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS _row_number_#383, *] + * : +- 'UnresolvedRelation [employees], [], false + * +- 'SubqueryAlias T2 + * +- 'Project [row_number() windowspecdefinition(1 DESC NULLS LAST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS _row_number_#386, *] + * +- 'Aggregate ['COUNT(*) AS age#384] + * +- 'UnresolvedRelation [employees], [], false + */ + // @formatter:on + test("test override with Supported sub-search") { + val context = new CatalystPlanContext + val logicalPlan = planTransformer.visit( + plan(pplParser, "source=employees | APPENDCOL OVERRIDE=true [stats count() as age];"), + context) + + /* + :- 'SubqueryAlias T1 + : +- 'Project [row_number() windowspecdefinition(1 DESC NULLS LAST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS _row_number_#7, *] + : +- 'UnresolvedRelation [relation], [], false + */ + val t1 = SubqueryAlias( + "APPENDCOL_T1", + Project(Seq(ROW_NUMBER_AGGREGATION, UnresolvedStar(None)), RELATION_EMPLOYEES)) + + /* + +- 'SubqueryAlias T2 + +- 'Project [row_number() windowspecdefinition(1 DESC NULLS LAST, + specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS _row_number_#11, *] + +- 'Aggregate ['COUNT(*) AS age#8] + +- 'UnresolvedRelation [relation], [], false + */ + val t2 = SubqueryAlias( + "APPENDCOL_T2", + Project( + Seq(ROW_NUMBER_AGGREGATION, UnresolvedStar(None)), + Aggregate( + Nil, + Seq( + Alias( + UnresolvedFunction(Seq("COUNT"), Seq(UnresolvedStar(None)), isDistinct = false), + "age")()), + RELATION_EMPLOYEES))) + + val expectedPlan = Project( + Seq(UnresolvedStar(None)), + DataFrameDropColumns( + T12_COLUMNS_SEQ :+ UnresolvedAttribute("APPENDCOL_T1.age"), + Join(t1, t2, FullOuter, Some(T12_JOIN_CONDITION), JoinHint.NONE))) + + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + +} diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanParseTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanParseTranslatorTestSuite.scala index 1d00b9484..4cde6c994 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanParseTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanParseTranslatorTestSuite.scala @@ -13,9 +13,10 @@ import 
org.scalatest.matchers.should.Matchers
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.ScalaReflection.universe.Star
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
-import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Coalesce, Descending, GreaterThan, Literal, NamedExpression, NullsFirst, NullsLast, RegExpExtract, SortOrder}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Cast, Coalesce, Descending, GreaterThan, Literal, NamedExpression, NullsFirst, NullsLast, RegExpExtract, SortOrder}
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, GlobalLimit, LocalLimit, Project, Sort}
+import org.apache.spark.sql.types.IntegerType
 
 class PPLLogicalPlanParseTranslatorTestSuite
   extends SparkFunSuite
@@ -120,43 +121,49 @@ class PPLLogicalPlanParseTranslatorTestSuite
     assert(compareByString(expectedPlan) === compareByString(logPlan))
   }
 
-  test("test parse email & host expressions including cast and sort commands") {
+  test("test parse street number & street expressions including cast and sort commands") {
     val context = new CatalystPlanContext
-    val logPlan =
-      planTransformer.visit(
-        plan(
-          pplParser,
-          "source=t | parse address '(?<streetNumber>\\d+) (?<street>.+)' | where streetNumber > 500 | sort num(streetNumber) | fields streetNumber, street"),
-        context)
+
+    // TODO #963: Implement 'num', 'str', and 'ip' sort syntax
+    val query =
+      "source=t" +
+        " | parse address '(?<streetNumber>\\d+) (?<street>.+)'" +
+        " | eval streetNumberInt = cast(streetNumber as integer)" +
+        " | where streetNumberInt > 500" +
+        " | sort streetNumberInt" +
+        " | fields streetNumber, street"
+
+    val logPlan = planTransformer.visit(plan(pplParser, query), context)
 
     val addressAttribute = UnresolvedAttribute("address")
     val streetNumberAttribute = UnresolvedAttribute("streetNumber")
     val streetAttribute = UnresolvedAttribute("street")
+    val streetNumberIntAttribute = UnresolvedAttribute("streetNumberInt")
 
-    val streetNumberExpression = Alias(
-      RegExpExtract(
-        addressAttribute,
-        Literal("(?<streetNumber>\\d+) (?<street>.+)"),
-        Literal("1")),
-      "streetNumber")()
+    val regexLiteral = Literal("(?<streetNumber>\\d+) (?<street>.+)")
+    val streetNumberExpression =
+      Alias(RegExpExtract(addressAttribute, regexLiteral, Literal("1")), "streetNumber")()
+    val streetExpression =
+      Alias(RegExpExtract(addressAttribute, regexLiteral, Literal("2")), "street")()
 
-    val streetExpression = Alias(
-      RegExpExtract(
-        addressAttribute,
-        Literal("(?<streetNumber>\\d+) (?<street>.+)"),
-        Literal("2")),
-      "street")()
+    val castExpression = Cast(streetNumberAttribute, IntegerType)
 
     val expectedPlan = Project(
       Seq(streetNumberAttribute, streetAttribute),
       Sort(
-        Seq(SortOrder(streetNumberAttribute, Ascending, NullsFirst, Seq.empty)),
+        Seq(SortOrder(streetNumberIntAttribute, Ascending, NullsFirst, Seq.empty)),
        global = true,
        Filter(
-          GreaterThan(streetNumberAttribute, Literal(500)),
+          GreaterThan(streetNumberIntAttribute, Literal(500)),
          Project(
-            Seq(addressAttribute, streetNumberExpression, streetExpression, UnresolvedStar(None)),
-            UnresolvedRelation(Seq("t"))))))
+            Seq(UnresolvedStar(None), Alias(castExpression, "streetNumberInt")()),
+            Project(
+              Seq(
+                addressAttribute,
+                streetNumberExpression,
+                streetExpression,
+                UnresolvedStar(None)),
+              UnresolvedRelation(Seq("t")))))))
 
     assert(compareByString(expectedPlan) === compareByString(logPlan))
   }
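Beyond the plan-level assertions above, the core technique in this change — aligning a main search and a sub-search row by row via a synthetic row number and a full outer join — can be illustrated with the plain Spark DataFrame API. The following is a minimal, self-contained sketch, not part of the patch itself; the SparkSession setup and the employees sample data are illustrative assumptions:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{avg, lit, row_number}

object AppendColSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("appendcol-sketch")
      .getOrCreate()
    import spark.implicits._

    val employees = Seq(("Jane", 20, "Quebec"), ("John", 25, "Ontario"))
      .toDF("name", "age", "state")

    // Tag each row with row_number() over a constant sort key, mirroring the
    // _row_number_ projection that the APPENDCOL translation wraps around both searches.
    val byConstant = Window.orderBy(lit("1").desc)
    val t1 = employees.select($"name", $"age")
      .withColumn("_row_number_", row_number().over(byConstant))
    val t2 = employees.agg(avg($"age").as("avg_age"))
      .withColumn("_row_number_", row_number().over(byConstant))

    // Full outer join on the synthetic row number, then drop the helper column.
    // Rows without a counterpart on the other side are padded with nulls,
    // matching the null-padding behaviour exercised by the tests above.
    val appended = t1.join(t2, Seq("_row_number_"), "full_outer").drop("_row_number_")

    appended.show()
    spark.stop()
  }
}

As the integration tests note, row_number() over a constant sort key yields an ordering that is not deterministic across partitions, which is why the tests either assert only the sub-search position or sort their results before comparison rather than relying on full row order.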