Skip to content

Commit

Permalink
update scala fmt style
Browse files Browse the repository at this point in the history
Signed-off-by: YANGDB <[email protected]>

Signed-off-by: YANGDB <[email protected]>
  • Loading branch information
YANG-DB committed Aug 15, 2024
1 parent db2cfe0 commit 5ec92be
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@

package org.opensearch.flint.spark.ppl

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Descending, Literal, NamedExpression, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.command.DescribeTableCommand
import org.apache.spark.sql.streaming.StreamTest
import org.apache.spark.sql.{QueryTest, Row}

class FlintSparkPPLTopAndRareITSuite
extends QueryTest
Expand All @@ -37,7 +37,7 @@ class FlintSparkPPLTopAndRareITSuite
job.awaitTermination()
}
}

test("create ppl rare address field query test") {
val frame = sql(s"""
| source = $testTable| rare address
Expand All @@ -46,20 +46,28 @@ class FlintSparkPPLTopAndRareITSuite
// Retrieve the results
val results: Array[Row] = frame.collect()
assert(results.length == 3)

val expectedRow = Row(1, "Vancouver")
assert(results.head == expectedRow, s"Expected least frequent result to be $expectedRow, but got ${results.head}")

assert(
results.head == expectedRow,
s"Expected least frequent result to be $expectedRow, but got ${results.head}")

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
// Define the expected logical plan
val addressField = UnresolvedAttribute("address")
val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None))

val aggregateExpressions = Seq(
Alias(UnresolvedFunction(Seq("COUNT"), Seq(addressField), isDistinct = false), "count(address)")(), addressField)
Alias(
UnresolvedFunction(Seq("COUNT"), Seq(addressField), isDistinct = false),
"count(address)")(),
addressField)
val aggregatePlan =
Aggregate(Seq(addressField), aggregateExpressions, UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")))
Aggregate(
Seq(addressField),
aggregateExpressions,
UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")))
val sortedPlan: LogicalPlan =
Sort(
Seq(SortOrder(UnresolvedAttribute("address"), Descending)),
Expand All @@ -68,7 +76,7 @@ class FlintSparkPPLTopAndRareITSuite
val expectedPlan = Project(projectList, sortedPlan)
comparePlans(expectedPlan, logicalPlan, false)
}

test("create ppl top address field query test") {
val frame = sql(s"""
| source = $testTable| top address
Expand All @@ -77,15 +85,13 @@ class FlintSparkPPLTopAndRareITSuite
// Retrieve the results
val results: Array[Row] = frame.collect()
assert(results.length == 3)

val expectedRows = Set(
Row(2, "Portland"),
Row(2, "Seattle")
)

val expectedRows = Set(Row(2, "Portland"), Row(2, "Seattle"))
val actualRows = results.take(2).toSet

// Compare the sets
assert(actualRows == expectedRows,
assert(
actualRows == expectedRows,
s"The first two results do not match the expected rows. Expected: $expectedRows, Actual: $actualRows")

// Retrieve the logical plan
Expand All @@ -95,9 +101,15 @@ class FlintSparkPPLTopAndRareITSuite
val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None))

val aggregateExpressions = Seq(
Alias(UnresolvedFunction(Seq("COUNT"), Seq(addressField), isDistinct = false), "count(address)")(), addressField)
Alias(
UnresolvedFunction(Seq("COUNT"), Seq(addressField), isDistinct = false),
"count(address)")(),
addressField)
val aggregatePlan =
Aggregate(Seq(addressField), aggregateExpressions, UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")))
Aggregate(
Seq(addressField),
aggregateExpressions,
UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")))
val sortedPlan: LogicalPlan =
Sort(
Seq(SortOrder(UnresolvedAttribute("address"), Ascending)),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,39 +41,47 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite

comparePlans(expectedPlan, logPlan, false)
}

test("test count price") {
// if successful build ppl logical plan and translate to catalyst logical plan
val context = new CatalystPlanContext
val logPlan =
planTransformer.visit(plan(pplParser, "source = table | stats count(price) ", false), context)
planTransformer.visit(
plan(pplParser, "source = table | stats count(price) ", false),
context)
// SQL: SELECT count(price) AS count_price FROM table
val star = Seq(UnresolvedStar(None))

val priceField = UnresolvedAttribute("price")
val tableRelation = UnresolvedRelation(Seq("table"))
val aggregateExpressions = Seq(
Alias(UnresolvedFunction(Seq("COUNT"), Seq(priceField), isDistinct = false), "count(price)")())
Alias(
UnresolvedFunction(Seq("COUNT"), Seq(priceField), isDistinct = false),
"count(price)")())
val aggregatePlan = Aggregate(Seq(), aggregateExpressions, tableRelation)
val expectedPlan = Project(star, aggregatePlan)

comparePlans(expectedPlan, logPlan, false)
}

test("test count price by country") {
// if successful build ppl logical plan and translate to catalyst logical plan
val context = new CatalystPlanContext
val logPlan =
planTransformer.visit(plan(pplParser, "source = table | stats count(price) by product ", false), context)
planTransformer.visit(
plan(pplParser, "source = table | stats count(price) by product ", false),
context)
// SQL: SELECT count(price) AS count_price FROM table GROUP BY product
val star = Seq(UnresolvedStar(None))
val productField = UnresolvedAttribute("product")
val priceField = UnresolvedAttribute("price")
val tableRelation = UnresolvedRelation(Seq("table"))

val groupByAttributes = Seq(Alias(productField, "product")())
val aggregateExpressions =
Alias(UnresolvedFunction(Seq("COUNT"), Seq(priceField), isDistinct = false), "count(price)")()
Alias(
UnresolvedFunction(Seq("COUNT"), Seq(priceField), isDistinct = false),
"count(price)")()
val productAlias = Alias(productField, "product")()

val aggregatePlan =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,17 @@

package org.opensearch.flint.spark.ppl

import org.opensearch.flint.spark.ppl.PlaneUtils.plan
import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor}
import org.scalatest.matchers.should.Matchers

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Descending, Literal, NamedExpression, SortOrder}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.command.DescribeTableCommand
import org.opensearch.flint.spark.ppl.PlaneUtils.plan
import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor}
import org.scalatest.matchers.should.Matchers

class PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite
extends SparkFunSuite
Expand All @@ -24,20 +25,22 @@ class PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite

private val planTransformer = new CatalystQueryPlanVisitor()
private val pplParser = new PPLSyntaxParser()

test("test simple rare command with a single field") {
// if successful build ppl logical plan and translate to catalyst logical plan
val context = new CatalystPlanContext
val logPlan = planTransformer.visit(plan(pplParser, "source=accounts | rare gender", false), context)
val logPlan =
planTransformer.visit(plan(pplParser, "source=accounts | rare gender", false), context)
val genderField = UnresolvedAttribute("gender")
val tableRelation = UnresolvedRelation(Seq("accounts"))

val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None))

val aggregateExpressions = Seq(
Alias(UnresolvedFunction(Seq("COUNT"), Seq(genderField), isDistinct = false), "count(gender)")(),
genderField
)
Alias(
UnresolvedFunction(Seq("COUNT"), Seq(genderField), isDistinct = false),
"count(gender)")(),
genderField)

val aggregatePlan =
Aggregate(Seq(genderField), aggregateExpressions, tableRelation)
Expand All @@ -50,29 +53,28 @@ class PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite
val expectedPlan = Project(projectList, sortedPlan)
comparePlans(expectedPlan, logPlan, false)
}

test("test simple top command with a single field") {
// if successful build ppl logical plan and translate to catalyst logical plan
val context = new CatalystPlanContext
val logPlan = planTransformer.visit(plan(pplParser, "source=accounts | top gender", false), context)
val logPlan =
planTransformer.visit(plan(pplParser, "source=accounts | top gender", false), context)
val genderField = UnresolvedAttribute("gender")
val tableRelation = UnresolvedRelation(Seq("accounts"))

val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None))

val aggregateExpressions = Seq(
Alias(UnresolvedFunction(Seq("COUNT"), Seq(genderField), isDistinct = false), "count(gender)")(),
genderField
)
Alias(
UnresolvedFunction(Seq("COUNT"), Seq(genderField), isDistinct = false),
"count(gender)")(),
genderField)

val aggregatePlan =
Aggregate(Seq(genderField), aggregateExpressions, tableRelation)

val sortedPlan: LogicalPlan =
Sort(
Seq(SortOrder(UnresolvedAttribute("gender"), Ascending)),
global = true,
aggregatePlan)
Sort(Seq(SortOrder(UnresolvedAttribute("gender"), Ascending)), global = true, aggregatePlan)
val expectedPlan = Project(projectList, sortedPlan)
comparePlans(expectedPlan, logPlan, false)
}
Expand Down

0 comments on commit 5ec92be

Please sign in to comment.