support nested column pruning
zouyunhe authored and KevinyhZou committed Nov 14, 2024
1 parent 21b4e65 commit e4af358
Showing 11 changed files with 360 additions and 8 deletions.
@@ -396,4 +396,8 @@ object CHBackendSettings extends BackendSettingsApi with Logging {
}

override def supportWindowGroupLimitExec(rankLikeFunction: Expression): Boolean = true

override def supportHiveTableScanNestedColumnPruning: Boolean =
GlutenConfig.getConf.enableColumnarHiveTableScanNestedColumnPruning

}
@@ -1499,4 +1499,57 @@ class GlutenClickHouseHiveTableSuite
spark.sql("drop table if exists aj")
}

test("test hive table scan nested column pruning") {
val json_table_name = "test_tbl_7267_json"
val pq_table_name = "test_tbl_7267_pq"
val create_table_sql =
s"""
| create table if not exists %s(
| id bigint,
| d1 STRUCT<c: STRING, d: ARRAY<STRUCT<x: STRING, y: STRING>>>,
| d2 STRUCT<c: STRING, d: Map<STRING, STRUCT<x: STRING, y: STRING>>>,
| day string,
| hour string
| ) partitioned by(day, hour)
|""".stripMargin
val create_table_json = create_table_sql.format(json_table_name) +
s"""
| ROW FORMAT SERDE 'org.apache.hive.hcatalog.data.JsonSerDe'
| STORED AS INPUTFORMAT 'org.apache.hadoop.mapred.TextInputFormat'
| OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat'
|""".stripMargin
val create_table_pq = create_table_sql.format(pq_table_name) + " Stored as PARQUET"
val insert_sql =
"""
| insert into %s values(1,
| named_struct('c', 'c123', 'd', array(named_struct('x', 'x123', 'y', 'y123'))),
| named_struct('c', 'c124', 'd', map('m124', named_struct('x', 'x124', 'y', 'y124'))),
| '2024-09-26', '12'
| )
|""".stripMargin
val select_sql =
"select id, d1.c, d1.d[0].x, d2.d['m124'].y from %s where day = '2024-09-26' and hour = '12'"
val table_names = Array.apply(json_table_name, pq_table_name)
val create_table_sqls = Array.apply(create_table_json, create_table_pq)
for (i <- table_names.indices) {
val table_name = table_names(i)
val create_table = create_table_sqls(i)
spark.sql(create_table)
spark.sql(insert_sql.format(table_name))
withSQLConf(("spark.sql.hive.convertMetastoreParquet" -> "false")) {
compareResultsAgainstVanillaSpark(
select_sql.format(table_name),
compareResult = true,
df => {
val scan = collect(df.queryExecution.executedPlan) {
case l: HiveTableScanExecTransformer => l
}
assert(scan.size == 1)
}
)
}
spark.sql("drop table if exists %s".format(table_name))
}
}

}
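
An editor's note, not part of the diff: for the select_sql above, Spark's SchemaPruning is expected to narrow the read schema to roughly the shape below, keeping only the leaf fields the query touches; the exact result depends on SchemaPruning's rules.

// Hypothetical pruned read schema for
//   select id, d1.c, d1.d[0].x, d2.d['m124'].y
// assuming only the accessed leaf fields survive:
//   id   BIGINT
//   d1   STRUCT<c: STRING, d: ARRAY<STRUCT<x: STRING>>>
//   d2   STRUCT<d: MAP<STRING, STRUCT<y: STRING>>>
//   day  STRING   (partition column)
//   hour STRING   (partition column)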
@@ -131,4 +131,6 @@ trait BackendSettingsApi {
def supportColumnarArrowUdf(): Boolean = false

def needPreComputeRangeFrameBoundary(): Boolean = false

def supportHiveTableScanNestedColumnPruning(): Boolean = false
}
@@ -37,7 +37,7 @@ import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleEx
import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.execution.python.{ArrowEvalPythonExec, BatchEvalPythonExec}
import org.apache.spark.sql.execution.window.{WindowExec, WindowGroupLimitExecShim}
import org.apache.spark.sql.hive.HiveTableScanExecTransformer
import org.apache.spark.sql.hive.{HiveTableScanExecTransformer, HiveTableScanNestedColumnPruning}

/**
* Converts a vanilla Spark plan node into Gluten plan node. Gluten plan is supposed to be executed
@@ -226,7 +226,11 @@ object OffloadOthers {
case plan: ProjectExec =>
val columnarChild = plan.child
logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
ProjectExecTransformer(plan.projectList, columnarChild)
if (HiveTableScanNestedColumnPruning.supportNestedColumnPruning(plan)) {
HiveTableScanNestedColumnPruning.apply(plan)
} else {
ProjectExecTransformer(plan.projectList, columnarChild)
}
case plan: HashAggregateExec =>
logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
HashAggregateExecBaseTransformer.from(plan)
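In rough terms (editor's sketch, not literal plan output), the new branch above hands the Project over a Hive scan to HiveTableScanNestedColumnPruning, which rebuilds the pair with a pruned scan output; filters are passed as empty on this path.

// Sketch of the rewrite shape (names as in the diff):
//   ProjectExec(projectList, HiveTableScanExecTransformer(attrs, relation, preds))
//     ==> ProjectExecTransformer(rewrittenProjectList,
//           HiveTableScanExecTransformer(attrs, relation, preds, prunedOutput))
// If pruning does not apply, the else branch keeps the plain
// ProjectExecTransformer(plan.projectList, columnarChild).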
@@ -45,7 +45,8 @@ import java.net.URI
case class HiveTableScanExecTransformer(
requestedAttributes: Seq[Attribute],
relation: HiveTableRelation,
partitionPruningPred: Seq[Expression])(@transient session: SparkSession)
partitionPruningPred: Seq[Expression],
prunedOutput: Seq[Attribute] = Seq.empty[Attribute])(@transient session: SparkSession)
extends AbstractHiveTableScanExec(requestedAttributes, relation, partitionPruningPred)(session)
with BasicScanExecTransformer {

@@ -63,7 +64,13 @@ case class HiveTableScanExecTransformer(

override def getMetadataColumns(): Seq[AttributeReference] = Seq.empty

override def outputAttributes(): Seq[Attribute] = output
override def outputAttributes(): Seq[Attribute] = {
if (prunedOutput.nonEmpty) {
prunedOutput
} else {
output
}
}

override def getPartitions: Seq[InputPartition] = partitions

@@ -0,0 +1,252 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.hive

import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.execution.ProjectExecTransformer

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.catalog.HiveTableRelation
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.{FilterExec, LeafExecNode, ProjectExec, SparkPlan}
import org.apache.spark.sql.hive.HiveTableScanExecTransformer.{ORC_INPUT_FORMAT_CLASS, PARQUET_INPUT_FORMAT_CLASS, TEXT_INPUT_FORMAT_CLASS}
import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType}
import org.apache.spark.sql.util.SchemaUtils._
import org.apache.spark.util.Utils

object HiveTableScanNestedColumnPruning extends Logging {
import org.apache.spark.sql.catalyst.expressions.SchemaPruning._

def supportNestedColumnPruning(projectExec: ProjectExec): Boolean = {
if (BackendsApiManager.getSettings.supportHiveTableScanNestedColumnPruning()) {
projectExec.child match {
case HiveTableScanExecTransformer(_, relation, _, _) =>
relation.tableMeta.storage.inputFormat match {
case Some(inputFormat)
if TEXT_INPUT_FORMAT_CLASS.isAssignableFrom(Utils.classForName(inputFormat)) =>
relation.tableMeta.storage.serde match {
case Some("org.openx.data.jsonserde.JsonSerDe") | Some(
"org.apache.hive.hcatalog.data.JsonSerDe") =>
return true
case _ =>
}
case Some(inputFormat)
if ORC_INPUT_FORMAT_CLASS.isAssignableFrom(Utils.classForName(inputFormat)) =>
return true
case Some(inputFormat)
if PARQUET_INPUT_FORMAT_CLASS.isAssignableFrom(Utils.classForName(inputFormat)) =>
return true
case _ =>
}
case _ =>
}
}
false
}

def apply(plan: SparkPlan): SparkPlan = {
plan match {
case ProjectExec(projectList, child) =>
child match {
case h: HiveTableScanExecTransformer =>
val newPlan = prunePhysicalColumns(
h.relation,
projectList,
Seq.empty[Expression],
(prunedDataSchema, prunedMetadataSchema) => {
buildNewHiveTableScan(h, prunedDataSchema, prunedMetadataSchema)
},
(schema, requestFields) => {
h.pruneSchema(schema, requestFields)
}
)
if (newPlan.nonEmpty) {
return newPlan.get
} else {
return ProjectExecTransformer(projectList, child)
}
case _ =>
return ProjectExecTransformer(projectList, child)
}
case _ =>
}
plan
}

private def prunePhysicalColumns(
relation: HiveTableRelation,
projects: Seq[NamedExpression],
filters: Seq[Expression],
leafNodeBuilder: (StructType, StructType) => LeafExecNode,
pruneSchemaFunc: (StructType, Seq[SchemaPruning.RootField]) => StructType)
: Option[SparkPlan] = {
val (normalizedProjects, normalizedFilters) =
normalizeAttributeRefNames(relation.output, projects, filters)
val requestedRootFields = identifyRootFields(normalizedProjects, normalizedFilters)
// If requestedRootFields includes a nested field, continue. Otherwise,
// return None.
if (requestedRootFields.exists { root: RootField => !root.derivedFromAtt }) {
val prunedDataSchema = pruneSchemaFunc(relation.tableMeta.dataSchema, requestedRootFields)
val metaFieldNames = relation.tableMeta.schema.fieldNames
val metadataSchema = relation.output.collect {
case attr: AttributeReference if metaFieldNames.contains(attr.name) => attr
}.toStructType
val prunedMetadataSchema = if (metadataSchema.nonEmpty) {
pruneSchemaFunc(metadataSchema, requestedRootFields)
} else {
metadataSchema
}
// If the data schema is different from the pruned data schema
// OR
// the metadata schema is different from the pruned metadata schema, continue.
// Otherwise, return None.
if (
countLeaves(relation.tableMeta.dataSchema) > countLeaves(prunedDataSchema) ||
countLeaves(metadataSchema) > countLeaves(prunedMetadataSchema)
) {
val leafNode = leafNodeBuilder(prunedDataSchema, prunedMetadataSchema)
val projectionOverSchema = ProjectionOverSchema(
prunedDataSchema.merge(prunedMetadataSchema),
AttributeSet(relation.output))
Some(
buildNewProjection(
projects,
normalizedProjects,
normalizedFilters,
leafNode,
projectionOverSchema))
} else {
None
}
} else {
None
}
}

/**
* Normalizes the names of the attribute references in the given projects and filters to reflect
* the names in the given logical relation. This makes it possible to compare attributes and
* fields by name. Returns a tuple with the normalized projects and filters, respectively.
*/
private def normalizeAttributeRefNames(
output: Seq[AttributeReference],
projects: Seq[NamedExpression],
filters: Seq[Expression]): (Seq[NamedExpression], Seq[Expression]) = {
val normalizedAttNameMap = output.map(att => (att.exprId, att.name)).toMap
val normalizedProjects = projects
.map(_.transform {
case att: AttributeReference if normalizedAttNameMap.contains(att.exprId) =>
att.withName(normalizedAttNameMap(att.exprId))
})
.map { case expr: NamedExpression => expr }
val normalizedFilters = filters.map(_.transform {
case att: AttributeReference if normalizedAttNameMap.contains(att.exprId) =>
att.withName(normalizedAttNameMap(att.exprId))
})
(normalizedProjects, normalizedFilters)
}

/** Builds the new output [[ProjectExecTransformer]] operator that wraps the `leafNode`. */
private def buildNewProjection(
projects: Seq[NamedExpression],
normalizedProjects: Seq[NamedExpression],
filters: Seq[Expression],
leafNode: LeafExecNode,
projectionOverSchema: ProjectionOverSchema): ProjectExecTransformer = {
// Construct a new target for our projection by rewriting and
// including the original filters where available
val projectionChild =
if (filters.nonEmpty) {
val projectedFilters = filters.map(_.transformDown {
case projectionOverSchema(expr) => expr
})
val newFilterCondition = projectedFilters.reduce(And)
FilterExec(newFilterCondition, leafNode)
} else {
leafNode
}

// Construct the new projections of our Project by
// rewriting the original projections
val newProjects =
normalizedProjects.map(_.transformDown { case projectionOverSchema(expr) => expr }).map {
case expr: NamedExpression => expr
}

ProjectExecTransformer(
restoreOriginalOutputNames(newProjects, projects.map(_.name)),
projectionChild)
}

private def buildNewHiveTableScan(
hiveTableScan: HiveTableScanExecTransformer,
prunedDataSchema: StructType,
prunedMetadataSchema: StructType): HiveTableScanExecTransformer = {
val relation = hiveTableScan.relation
val partitionSchema = relation.tableMeta.partitionSchema
val prunedBaseSchema = StructType(
prunedDataSchema.fields.filterNot(
f => partitionSchema.fieldNames.contains(f.name)) ++ partitionSchema.fields)
val finalSchema = prunedBaseSchema.merge(prunedMetadataSchema)
val prunedOutput = getPrunedOutput(relation.output, finalSchema)
var finalOutput = Seq.empty[Attribute]
for (p <- hiveTableScan.output) {
var flag = false
for (q <- prunedOutput if !flag) {
if (p.name.equals(q.name)) {
finalOutput :+= q
flag = true
}
}
}
HiveTableScanExecTransformer(
hiveTableScan.requestedAttributes,
relation,
hiveTableScan.partitionPruningPred,
finalOutput)(hiveTableScan.session)
}

// Prune the given output to make it consistent with `requiredSchema`.
private def getPrunedOutput(
output: Seq[AttributeReference],
requiredSchema: StructType): Seq[Attribute] = {
// We need to update the data types of the output attributes to use the pruned ones,
// so that references to the original relation's output are not broken.
val nameAttributeMap = output.map(att => (att.name, att)).toMap
val requiredAttributes =
requiredSchema.map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)())
requiredAttributes.map {
case att if nameAttributeMap.contains(att.name) =>
nameAttributeMap(att.name).withDataType(att.dataType)
case att => att
}
}

/**
* Counts the "leaf" fields of the given dataType. Informally, this is the number of fields of
* non-complex data type in the tree representation of [[DataType]].
*/
private def countLeaves(dataType: DataType): Int = {
dataType match {
case array: ArrayType => countLeaves(array.elementType)
case map: MapType => countLeaves(map.keyType) + countLeaves(map.valueType)
case struct: StructType =>
struct.map(field => countLeaves(field.dataType)).sum
case _ => 1
}
}
}
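
For intuition, a worked example of countLeaves against the d1 column from the test suite above (editor's sketch, not part of the commit):

// d1: STRUCT<c: STRING, d: ARRAY<STRUCT<x: STRING, y: STRING>>>
//   countLeaves(d1) = countLeaves(STRING) + countLeaves(ARRAY<STRUCT<x, y>>)
//                   = 1 + (1 + 1) = 3
// After pruning to STRUCT<c: STRING, d: ARRAY<STRUCT<x: STRING>>>:
//   countLeaves = 1 + 1 = 2
// Since 3 > 2, prunePhysicalColumns accepts the pruned schema and builds the
// new scan; equal leaf counts would fall through to None instead.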
10 changes: 10 additions & 0 deletions shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala
@@ -50,6 +50,9 @@ class GlutenConfig(conf: SQLConf) extends Logging {

def enableColumnarHiveTableScan: Boolean = conf.getConf(COLUMNAR_HIVETABLESCAN_ENABLED)

def enableColumnarHiveTableScanNestedColumnPruning: Boolean =
conf.getConf(COLUMNAR_HIVETABLESCAN_NESTED_COLUMN_PRUNING_ENABLED)

def enableVanillaVectorizedReaders: Boolean = conf.getConf(VANILLA_VECTORIZED_READERS_ENABLED)

def enableColumnarHashAgg: Boolean = conf.getConf(COLUMNAR_HASHAGG_ENABLED)
@@ -859,6 +862,13 @@ object GlutenConfig {
.booleanConf
.createWithDefault(true)

val COLUMNAR_HIVETABLESCAN_NESTED_COLUMN_PRUNING_ENABLED =
buildConf("spark.gluten.sql.columnar.enableNestedColumnPruningInHiveTableScan")
.internal()
.doc("Enable or disable nested column pruning in HiveTableScan.")
.booleanConf
.createWithDefault(true)

val VANILLA_VECTORIZED_READERS_ENABLED =
buildStaticConf("spark.gluten.sql.columnar.enableVanillaVectorizedReaders")
.internal()
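The new flag defaults to true. If the pruning path needs to be disabled for a session, a minimal sketch (assuming a running SparkSession named spark and that the key is settable as an ordinary SQL conf):

// Editor's sketch: toggling the new flag at runtime; key copied from the conf above.
spark.conf.set("spark.gluten.sql.columnar.enableNestedColumnPruningInHiveTableScan", "false")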