Skip to content

Commit

Permalink
fix nested fields query and filter issues as specified in the linked issue
Browse files Browse the repository at this point in the history
  • Loading branch information
YANG-DB committed Aug 13, 2024
1 parent 21ec4cd commit 1e6dcd6
Show file tree
Hide file tree
Showing 4 changed files with 157 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ import org.antlr.v4.runtime.atn.PredictionMode
import org.antlr.v4.runtime.misc.{Interval, ParseCancellationException}
import org.antlr.v4.runtime.tree.TerminalNodeImpl
import org.opensearch.flint.spark.sql.FlintSparkSqlExtensionsParser._

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.catalyst.expressions.Expression
Expand All @@ -41,6 +40,7 @@ import org.apache.spark.sql.catalyst.parser.ParserUtils.withOrigin
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.trees.Origin
import org.apache.spark.sql.types.{DataType, StructType}
import org.opensearch.flint.core.logging.CustomLogging.{logError, logInfo}

/**
* Flint SQL parser that extends Spark SQL parser with Flint SQL statements.
Expand All @@ -61,7 +61,12 @@ class FlintSparkSqlParser(sparkParser: ParserInterface) extends ParserInterface
}
} catch {
// Fall back to Spark parse plan logic if flint cannot parse
case _: ParseException => sparkParser.parsePlan(sqlText)
case e: ParseException =>
// Log the issue
logInfo(s"Failed to parse PPL with PPL parser. Falling back to Spark parser. PPL: $sqlText", e)
// Fall back to Spark parse plan logic
sparkParser.parsePlan(sqlText)

}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,8 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit
| CREATE TABLE $testTable
| (
| int_col INT,
| struct_col STRUCT<field1: STRUCT<subfield:STRING>, field2: INT>
| struct_col STRUCT<field1: STRUCT<subfield:STRING>, field2: INT>,
| struct_col2 STRUCT<field1: STRUCT<subfield:STRING>, field2: INT>
| )
| USING JSON
|""".stripMargin)
Expand All @@ -405,14 +406,14 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit
| INSERT INTO $testTable
| SELECT /*+ COALESCE(1) */ *
| FROM VALUES
| ( 30, STRUCT(STRUCT("value1"),123) ),
| ( 40, STRUCT(STRUCT("value5"),123) ),
| ( 30, STRUCT(STRUCT("value4"),823) ),
| ( 40, STRUCT(STRUCT("value2"),456) )
| ( 30, STRUCT(STRUCT("value1"),123), STRUCT(STRUCT("valueA"),23) ),
| ( 40, STRUCT(STRUCT("value5"),123), STRUCT(STRUCT("valueB"),33) ),
| ( 30, STRUCT(STRUCT("value4"),823), STRUCT(STRUCT("valueC"),83) ),
| ( 40, STRUCT(STRUCT("value2"),456), STRUCT(STRUCT("valueD"),46) )
|""".stripMargin)
sql(s"""
| INSERT INTO $testTable
| VALUES ( 50, STRUCT(STRUCT("value3"),789) )
| VALUES ( 50, STRUCT(STRUCT("value3"),789), STRUCT(STRUCT("valueE"),89) )
|""".stripMargin)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ package org.opensearch.flint.spark.ppl

import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.{Ascending, Literal, SortOrder}
import org.apache.spark.sql.catalyst.expressions.{And, Ascending, Descending, EqualTo, GreaterThan, Literal, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.command.DescribeTableCommand
import org.apache.spark.sql.streaming.StreamTest
Expand Down Expand Up @@ -49,7 +49,8 @@ class FlintSparkPPLNestedFieldsITSuite
// Define the expected results
val expectedResults: Array[Row] = Array(
Row("int_col", "int", null),
Row("struct_col", "struct<field1:struct<subfield:string>,field2:int>", null))
Row("struct_col", "struct<field1:struct<subfield:string>,field2:int>", null),
Row("struct_col2", "struct<field1:struct<subfield:string>,field2:int>", null))

// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
Expand Down Expand Up @@ -80,11 +81,11 @@ class FlintSparkPPLNestedFieldsITSuite
val results: Array[Row] = frame.collect()
// Define the expected results
val expectedResults: Array[Row] = Array(
Row(30, Row(Row("value1"), 123)),
Row(40, Row(Row("value5"), 123)),
Row(30, Row(Row("value4"), 823)),
Row(40, Row(Row("value2"), 456)),
Row(50, Row(Row("value3"), 789)))
Row(30, Row(Row("value1"), 123), Row(Row("valueA"), 23)),
Row(40, Row(Row("value5"), 123), Row(Row("valueB"), 33)),
Row(30, Row(Row("value4"), 823), Row(Row("valueC"), 83)),
Row(40, Row(Row("value2"), 456), Row(Row("valueD"), 46)),
Row(50, Row(Row("value3"), 789), Row(Row("valueE"), 89)))
// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Int](_.getAs[Int](0))
assert(results.sorted.sameElements(expectedResults.sorted))
Expand Down Expand Up @@ -175,120 +176,222 @@ class FlintSparkPPLNestedFieldsITSuite

test("create ppl simple query two with fields result test") {
val frame = sql(s"""
| source = $testTable| fields int_col, struct_col.field2
| source = $testTable| fields int_col, struct_col.field2, struct_col2.field2
| """.stripMargin)

// Retrieve the results
val results: Array[Row] = frame.collect()
// Define the expected results
val expectedResults: Array[Row] =
Array(Row( 30, 123),
Row( 30, 823),
Row( 40, 123),
Row( 40, 456),
Row( 50, 789))
Array(Row( 30, 123, 23),
Row( 30, 823, 83),
Row( 40, 123, 33),
Row( 40, 456, 46),
Row( 50, 789, 89))
// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Int](_.getAs[Int](0))
assert(results.sorted.sameElements(expectedResults.sorted))

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
// Define the expected logical plan
val expectedPlan: LogicalPlan = Project(
Seq(UnresolvedAttribute("int_col"), UnresolvedAttribute("struct_col.field2")),
Seq(UnresolvedAttribute("int_col"), UnresolvedAttribute("struct_col.field2"), UnresolvedAttribute("struct_col2.field2")),
UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")))
// Compare the two plans
assert(expectedPlan === logicalPlan)
}

test("create ppl simple sorted query two with fields result test sorted") {
val frame = sql(s"""
| source = $testTable| sort int_col | fields int_col, struct_col.field2
| source = $testTable| sort - struct_col2.field2 | fields int_col, struct_col2.field2
| """.stripMargin)

// Retrieve the results
val results: Array[Row] = frame.collect()
// Define the expected results
val expectedResults: Array[Row] =
Array(Row(30, 123),
Row(30, 823),
Row(40, 123),
Row(40, 456),
Row(50, 789))
Array(Row(50, 89),
Row(30, 83),
Row(40, 46),
Row(40, 33),
Row(30, 23))
assert(results === expectedResults)

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical

val sortedPlan: LogicalPlan =
Sort(
Seq(SortOrder(UnresolvedAttribute("age"), Ascending)),
Seq(SortOrder(UnresolvedAttribute("struct_col2.field2"), Descending)),
global = true,
UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")))

// Define the expected logical plan
val expectedPlan: LogicalPlan =
Project(Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")), sortedPlan)
Project(Seq(UnresolvedAttribute("int_col"), UnresolvedAttribute("struct_col2.field2")), sortedPlan)

// Compare the two plans
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}

test("create ppl simple query with nested field range filter test") {
test("create ppl simple sorted by nested field query with two with fields result test ") {
val frame = sql(s"""
| source = $testTable| where struct_col.field2 > 200 | fields int_col, struct_col.field2
| source = $testTable| sort - struct_col.field2 , - int_col | fields int_col, struct_col.field2
| """.stripMargin)

// Retrieve the results
val results: Array[Row] = frame.collect()
// Define the expected results
val expectedResults: Array[Row] =
Array(Row(30, 823),
Row(40, 456),
Row(50, 789))
assert(results === expectedResults)
Row(50, 789),
Row(40, 456),
Row(40, 123),
Row(30, 123))
assert(results === expectedResults)

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical

val sortedPlan: LogicalPlan =
Sort(
Seq(SortOrder(UnresolvedAttribute("age"), Ascending)),
Seq(SortOrder(UnresolvedAttribute("struct_col.field2"), Descending),
SortOrder(UnresolvedAttribute("int_col"), Descending)),
global = true,
UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")))

// Define the expected logical plan
val expectedPlan: LogicalPlan =
Project(Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")), sortedPlan)
Project(Seq(UnresolvedAttribute("int_col"), UnresolvedAttribute("struct_col.field2")), sortedPlan)

// Compare the two plans
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}

test("create ppl simple query with nested field 1 range filter test") {
val frame = sql(s"""
| source = $testTable| where struct_col.field2 > 200 | sort - struct_col.field2 | fields int_col, struct_col.field2
| """.stripMargin)

// Retrieve the results
val results: Array[Row] = frame.collect()
// Define the expected results
val expectedResults: Array[Row] =
Array(Row(30, 823),
Row(50, 789),
Row(40, 456))
assert(results === expectedResults)

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical

// Define the expected logical plan
val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
// Define the expected logical plan components
val filterPlan = Filter(
GreaterThan(UnresolvedAttribute("struct_col.field2"), Literal(200)),
table)
val sortedPlan: LogicalPlan =
Sort(Seq(SortOrder(UnresolvedAttribute("struct_col.field2"), Descending)), global = true, filterPlan)
val expectedPlan =
Project(Seq(UnresolvedAttribute("int_col"), UnresolvedAttribute("struct_col.field2")), sortedPlan)

test("create ppl simple query with nested field string filter test") {
// Compare the two plans
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}

test("create ppl simple query with nested field 2 range filter test") {
val frame = sql(s"""
| source = $testTable| where struct_col.field1.subfield = `value1` | fields int_col, struct_col.field1.subfield
| source = $testTable| where struct_col2.field2 > 50 | sort - struct_col2.field2 | fields int_col, struct_col2.field2
| """.stripMargin)

// Retrieve the results
val results: Array[Row] = frame.collect()
// Define the expected results
val expectedResults: Array[Row] =
Array(Row(30, "value1"))
Array(Row(50, 89),
Row(30, 83))
assert(results === expectedResults)

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical

// Define the expected logical plan
val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
// Define the expected logical plan components
val filterPlan = Filter(
GreaterThan(UnresolvedAttribute("struct_col2.field2"), Literal(50)),
table)
val sortedPlan: LogicalPlan =
Sort(Seq(SortOrder(UnresolvedAttribute("struct_col2.field2"), Descending)), global = true, filterPlan)
val expectedPlan =
Project(Seq(UnresolvedAttribute("int_col"), UnresolvedAttribute("struct_col2.field2")), sortedPlan)

// Compare the two plans
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}

test("create ppl simple query with nested field string match test") {
val frame = sql(s"""
| source = $testTable| where struct_col.field1.subfield = 'value1' | sort int_col | fields int_col, struct_col.field1.subfield, struct_col2.field1.subfield
| """.stripMargin)

// Retrieve the results
val results: Array[Row] = frame.collect()
// Define the expected results
val expectedResults: Array[Row] =
Array(Row(30, "value1", "valueA"))
assert(results === expectedResults)

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical


// Define the expected logical plan
val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
// Define the expected logical plan components
val filterPlan = Filter(
EqualTo(UnresolvedAttribute("struct_col.field1.subfield"), Literal("value1")),
table)
val sortedPlan: LogicalPlan =
Sort(
Seq(SortOrder(UnresolvedAttribute("age"), Ascending)),
global = true,
UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")))
Sort(Seq(SortOrder(UnresolvedAttribute("int_col"), Ascending)), global = true, filterPlan)
val expectedPlan =
Project(Seq(UnresolvedAttribute("int_col"), UnresolvedAttribute("struct_col.field1.subfield"), UnresolvedAttribute("struct_col2.field1.subfield")), sortedPlan)

// Compare the two plans
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}

test("create ppl simple query with nested field string filter test") {
val frame = sql(s"""
| source = $testTable| where struct_col2.field1.subfield > 'valueA' | sort int_col | fields int_col, struct_col.field1.subfield, struct_col2.field1.subfield
| """.stripMargin)

// Retrieve the results
val results: Array[Row] = frame.collect()
// Define the expected results
val expectedResults: Array[Row] =
Array(Row(30, "value4", "valueC"),
Row(40, "value5", "valueB"),
Row(40, "value2", "valueD"),
Row(50, "value3", "valueE"))
assert(results === expectedResults)

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical


// Define the expected logical plan
val expectedPlan: LogicalPlan =
Project(Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")), sortedPlan)
val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
// Define the expected logical plan components
val filterPlan = Filter(
GreaterThan(UnresolvedAttribute("struct_col2.field1.subfield"), Literal("valueA")),
table)
val sortedPlan: LogicalPlan =
Sort(Seq(SortOrder(UnresolvedAttribute("int_col"), Ascending)), global = true, filterPlan)
val expectedPlan =
Project(Seq(UnresolvedAttribute("int_col"), UnresolvedAttribute("struct_col.field1.subfield"), UnresolvedAttribute("struct_col2.field1.subfield")), sortedPlan)

// Compare the two plans
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ public interface RelationUtils {
* @return
*/
static Optional<QualifiedName> resolveField(List<UnresolvedRelation> relations, QualifiedName node) {
if (relations.size() == 1) return Optional.of(node);
return relations.stream()
.map(rel -> {
//if the name doesn't contain a table prefix, add the current relation's prefix to the field name - returns true
Expand Down

0 comments on commit 1e6dcd6

Please sign in to comment.