Skip to content

Commit

Permalink
add tests with inner table tablesample(? percent)
Browse files Browse the repository at this point in the history
Signed-off-by: YANGDB <[email protected]>
  • Loading branch information
YANG-DB committed Oct 21, 2024
1 parent 1326858 commit 26a6599
Show file tree
Hide file tree
Showing 4 changed files with 488 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,46 @@ class FlintSparkPPLAggregationWithSpanITSuite
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}

test(
"create ppl average age by span of interval of 10 years group by country head (limit) 2 query test with tablesample(100 percent)") {
val frame = sql(s"""
| source = $testTable tablesample(100 percent)| stats avg(age) by span(age, 10) as age_span, country | head 3
| """.stripMargin)
// Retrieve the results
val results: Array[Row] = frame.collect()
// Define the expected results
val expectedResults: Array[Row] =
Array(Row(70.0d, "USA", 70L), Row(30.0d, "USA", 30L), Row(22.5d, "Canada", 20L))

// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](1))
assert(results.sorted.sameElements(expectedResults.sorted))
// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
// Define the expected logical plan
val star = Seq(UnresolvedStar(None))
val ageField = UnresolvedAttribute("age")
val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
val countryField = UnresolvedAttribute("country")
val countryAlias = Alias(countryField, "country")()

val aggregateExpressions =
Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")()
val span = Alias(
Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)),
"age_span")()
val aggregatePlan =
Aggregate(
Seq(countryAlias, span),
Seq(aggregateExpressions, countryAlias, span),
Sample(0, 1, withReplacement = false, 0, table))
val limitPlan = Limit(Literal(3), aggregatePlan)
val expectedPlan = Project(star, limitPlan)

// Compare the two plans
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}

test(
"create ppl average age by span of interval of 10 years group by country head (limit) 2 query and sort by test ") {
val frame = sql(s"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,32 @@ class FlintSparkPPLAggregationsITSuite
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}

test("create ppl simple age avg query test with tablesample(75 percent)") {
val frame = sql(s"""
| source = $testTable tablesample(75 percent)| stats avg(age)
| """.stripMargin)

// Retrieve the results
val results: Array[Row] = frame.collect()
// Define the expected results
assert(results.length == 1)

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
// Define the expected logical plan
val star = Seq(UnresolvedStar(None))
val ageField = UnresolvedAttribute("age")
val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
val aggregateExpressions =
Seq(Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")())
val aggregatePlan =
Aggregate(Seq(), aggregateExpressions, Sample(0, 0.75, withReplacement = false, 0, table))
val expectedPlan = Project(star, aggregatePlan)

// Compare the two plans
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}

test("create ppl simple age avg query with filter test") {
val frame = sql(s"""
| source = $testTable| where age < 50 | stats avg(age)
Expand Down Expand Up @@ -161,6 +187,40 @@ class FlintSparkPPLAggregationsITSuite
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}

test(
"create ppl simple age avg group by country head (limit) query test with tablesample(75 percent) ") {
val frame = sql(s"""
| source = $testTable tablesample(75 percent) | stats avg(age) by country | head 1
| """.stripMargin)

// Retrieve the results
val results: Array[Row] = frame.collect()
assert(results.length == 1)

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
// Define the expected logical plan
val countryField = UnresolvedAttribute("country")
val ageField = UnresolvedAttribute("age")
val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))

val groupByAttributes = Seq(Alias(countryField, "country")())
val aggregateExpressions =
Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")()
val productAlias = Alias(countryField, "country")()

val aggregatePlan =
Aggregate(
groupByAttributes,
Seq(aggregateExpressions, productAlias),
Sample(0, 0.75, withReplacement = false, 0, table))
val projectPlan = Limit(Literal(1), aggregatePlan)
val expectedPlan = Project(Seq(UnresolvedStar(None)), projectPlan)

// Compare the two plans
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}

test("create ppl simple age max group by country query test ") {
val frame = sql(s"""
| source = $testTable| stats max(age) by country
Expand Down Expand Up @@ -343,6 +403,46 @@ class FlintSparkPPLAggregationsITSuite
s"Expected plan: ${compareByString(expectedPlan)}, but got: ${compareByString(logicalPlan)}")
}

test(" count * query test with tablesample(50 percent) ") {
val frame = sql(s"""
| source = $testTable tablesample(75 percent) | stats count()
| """.stripMargin)

// Retrieve the results
val results: Array[Row] = frame.collect()
// Define the expected results
val expectedResults: Array[Row] = Array(Row(3L))

// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](1))
assert(
results.sorted.sameElements(expectedResults.sorted),
s"Expected: ${expectedResults.mkString(", ")}, but got: ${results.mkString(", ")}")

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
// Define the expected logical plan
val star = Seq(UnresolvedStar(None))
val countryField = UnresolvedAttribute("country")
val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))

val groupByAttributes = Seq(Alias(countryField, "country")())
val aggregateExpressions =
Alias(UnresolvedFunction(Seq("COUNT"), star, isDistinct = false), "count")()

val aggregatePlan =
Aggregate(
groupByAttributes,
Seq(aggregateExpressions),
Sample(0, 0.75, withReplacement = false, 0, table))
val expectedPlan = Project(star, aggregatePlan)

// Compare the two plans
assert(
compareByString(expectedPlan) === compareByString(logicalPlan),
s"Expected plan: ${compareByString(expectedPlan)}, but got: ${compareByString(logicalPlan)}")
}

test("create ppl simple age avg group by country with state filter query test ") {
val frame = sql(s"""
| source = $testTable| where state != 'Quebec' | stats avg(age) by country
Expand Down Expand Up @@ -460,6 +560,57 @@ class FlintSparkPPLAggregationsITSuite
s"Expected plan: ${compareByString(expectedPlan)}, but got: ${compareByString(logicalPlan)}")
}

test(
"create ppl age sample stddev group by country query test with sort with tablesample(75 percent)") {
val frame = sql(s"""
| source = $testTable tablesample(100 percent)| stats stddev_samp(age) by country | sort country
| """.stripMargin)

// Retrieve the results
val results: Array[Row] = frame.collect()
// Define the expected results
val expectedResults: Array[Row] =
Array(Row(3.5355339059327378d, "Canada"), Row(28.284271247461902d, "USA"))

// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](1))
assert(
results.sorted.sameElements(expectedResults.sorted),
s"Expected: ${expectedResults.mkString(", ")}, but got: ${results.mkString(", ")}")

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
// Define the expected logical plan
val star = Seq(UnresolvedStar(None))
val countryField = UnresolvedAttribute("country")
val ageField = UnresolvedAttribute("age")
val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))

val groupByAttributes = Seq(Alias(countryField, "country")())
val aggregateExpressions =
Alias(
UnresolvedFunction(Seq("STDDEV_SAMP"), Seq(ageField), isDistinct = false),
"stddev_samp(age)")()
val productAlias = Alias(countryField, "country")()

val aggregatePlan =
Aggregate(
groupByAttributes,
Seq(aggregateExpressions, productAlias),
Sample(0, 1, withReplacement = false, 0, table))
val sortedPlan: LogicalPlan =
Sort(
Seq(SortOrder(UnresolvedAttribute("country"), Ascending)),
global = true,
aggregatePlan)
val expectedPlan = Project(star, sortedPlan)

// Compare the two plans
assert(
compareByString(expectedPlan) === compareByString(logicalPlan),
s"Expected plan: ${compareByString(expectedPlan)}, but got: ${compareByString(logicalPlan)}")
}

test("create ppl age sample stddev group by country with state filter query test") {
val frame = sql(s"""
| source = $testTable | where state != 'Ontario' | stats stddev_samp(age) by country
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation}
import org.apache.spark.sql.catalyst.expressions.{And, Descending, EqualTo, InSubquery, ListQuery, Literal, Not, SortOrder}
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.logical.{Filter, Join, JoinHint, LogicalPlan, Project, Sort, SubqueryAlias}
import org.apache.spark.sql.catalyst.plans.logical.{Filter, Join, JoinHint, LogicalPlan, Project, Sample, Sort, SubqueryAlias}
import org.apache.spark.sql.streaming.StreamTest

class FlintSparkPPLInSubqueryITSuite
Expand Down Expand Up @@ -126,6 +126,46 @@ class FlintSparkPPLInSubqueryITSuite
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}

test("test filter id in (select uid from inner) with outer table tablesample(100 percent)") {
val frame = sql(s"""
source = $outerTable tablesample(100 percent) | where (id) in [ source = $innerTable | fields uid ]
| | sort - salary
| | fields id, name, salary
| """.stripMargin)
val results: Set[Row] = frame.collect().toSet
val expectedResults: Set[Row] = Set(
Row(1003, "David", 120000),
Row(1002, "John", 120000),
Row(1000, "Jake", 100000),
Row(1005, "Jane", 90000),
Row(1006, "Tommy", 30000))
assert(
results == expectedResults,
s"The first two results do not match the expected rows. Expected: $expectedResults, Actual: $results")

val logicalPlan: LogicalPlan = frame.queryExecution.logical

val outer = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1"))
val inner = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2"))
val inSubquery =
Filter(
InSubquery(
Seq(UnresolvedAttribute("id")),
ListQuery(Project(Seq(UnresolvedAttribute("uid")), inner))),
Sample(0, 1, withReplacement = false, 0, outer))
val sortedPlan: LogicalPlan =
Sort(Seq(SortOrder(UnresolvedAttribute("salary"), Descending)), global = true, inSubquery)
val expectedPlan =
Project(
Seq(
UnresolvedAttribute("id"),
UnresolvedAttribute("name"),
UnresolvedAttribute("salary")),
sortedPlan)

comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}

test("test where (id) in (select uid from inner)") {
// id (0, 1, 2, 3, 4, 5, 6), uid (0, 2, 3, 5, 6)
// InSubquery: (0, 2, 3, 5, 6)
Expand Down Expand Up @@ -170,6 +210,54 @@ class FlintSparkPPLInSubqueryITSuite
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}

test("test where (id) in (select uid from inner) with inner table tablesample(100 percent)") {
// id (0, 1, 2, 3, 4, 5, 6), uid (0, 2, 3, 5, 6)
// InSubquery: (0, 2, 3, 5, 6)
val frame = sql(s"""
source = $outerTable
| | where (id) in [
| source = $innerTable tablesample(100 percent) | fields uid
| ]
| | sort - salary
| | fields id, name, salary
| """.stripMargin)
val results: Set[Row] = frame.collect().toSet
val expectedResults: Set[Row] = Set(
Row(1003, "David", 120000),
Row(1002, "John", 120000),
Row(1000, "Jake", 100000),
Row(1005, "Jane", 90000),
Row(1006, "Tommy", 30000))
assert(
results == expectedResults,
s"The first two results do not match the expected rows. Expected: $expectedResults, Actual: $results")

val logicalPlan: LogicalPlan = frame.queryExecution.logical

val outer = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1"))
val inner = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2"))
val inSubquery =
Filter(
InSubquery(
Seq(UnresolvedAttribute("id")),
ListQuery(
Project(
Seq(UnresolvedAttribute("uid")),
Sample(0, 1, withReplacement = false, 0, inner)))),
outer)
val sortedPlan: LogicalPlan =
Sort(Seq(SortOrder(UnresolvedAttribute("salary"), Descending)), global = true, inSubquery)
val expectedPlan =
Project(
Seq(
UnresolvedAttribute("id"),
UnresolvedAttribute("name"),
UnresolvedAttribute("salary")),
sortedPlan)

comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}

test("test where (id, name) in (select uid, name from inner)") {
// InSubquery: (0, 2, 3, 5)
val frame = sql(s"""
Expand Down Expand Up @@ -213,6 +301,53 @@ class FlintSparkPPLInSubqueryITSuite
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}

test(
"test where (id, name) in (select uid, name from inner) with both tables tablesample(100 percent)") {
// InSubquery: (0, 2, 3, 5)
val frame = sql(s"""
source = $outerTable tablesample(100 percent)
| | where (id, name) in [
| source = $innerTable tablesample(100 percent)| fields uid, name
| ]
| | sort - salary
| | fields id, name, salary
| """.stripMargin)
val results: Set[Row] = frame.collect().toSet
val expectedResults: Set[Row] = Set(
Row(1003, "David", 120000),
Row(1002, "John", 120000),
Row(1000, "Jake", 100000),
Row(1005, "Jane", 90000))
assert(
results == expectedResults,
s"The first two results do not match the expected rows. Expected: $expectedResults, Actual: $results")

val logicalPlan: LogicalPlan = frame.queryExecution.logical

val outer = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1"))
val inner = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2"))
val inSubquery =
Filter(
InSubquery(
Seq(UnresolvedAttribute("id"), UnresolvedAttribute("name")),
ListQuery(
Project(
Seq(UnresolvedAttribute("uid"), UnresolvedAttribute("name")),
Sample(0, 1, withReplacement = false, 0, inner)))),
Sample(0, 1, withReplacement = false, 0, outer))
val sortedPlan: LogicalPlan =
Sort(Seq(SortOrder(UnresolvedAttribute("salary"), Descending)), global = true, inSubquery)
val expectedPlan =
Project(
Seq(
UnresolvedAttribute("id"),
UnresolvedAttribute("name"),
UnresolvedAttribute("salary")),
sortedPlan)

comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}

test("test where id not in (select uid from inner)") {
// id (0, 1, 2, 3, 4, 5, 6), uid (0, 2, 3, 5, 6)
// Not InSubquery: (1, 4)
Expand Down
Loading

0 comments on commit 26a6599

Please sign in to comment.