From e4af3583793ff492701a8a2a1e5922383c01b3e0 Mon Sep 17 00:00:00 2001
From: zouyunhe <811-zouyunhe@users.noreply.git.sysop.bigo.sg>
Date: Fri, 30 Aug 2024 15:47:01 +0800
Subject: [PATCH] support nested column pruning

---
 .../backendsapi/clickhouse/CHBackend.scala    |   4 +
 .../hive/GlutenClickHouseHiveTableSuite.scala |  53 ++++
 .../backendsapi/BackendSettingsApi.scala      |   2 +
 .../heuristic/OffloadSingleNode.scala         |   8 +-
 .../hive/HiveTableScanExecTransformer.scala   |  11 +-
 .../HiveTableScanNestedColumnPruning.scala    | 252 ++++++++++++++++++
 .../org/apache/gluten/GlutenConfig.scala      |  10 +
 .../execution/AbstractHiveTableScanExec.scala |   7 +-
 .../execution/AbstractHiveTableScanExec.scala |   7 +-
 .../execution/AbstractHiveTableScanExec.scala |   7 +-
 .../execution/AbstractHiveTableScanExec.scala |   7 +-
 11 files changed, 360 insertions(+), 8 deletions(-)
 create mode 100644 gluten-substrait/src/main/scala/org/apache/spark/sql/hive/HiveTableScanNestedColumnPruning.scala

diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
index 54ab38569bb8a..6e73ff6b29d1b 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
@@ -396,4 +396,8 @@ object CHBackendSettings extends BackendSettingsApi with Logging {
   }
 
   override def supportWindowGroupLimitExec(rankLikeFunction: Expression): Boolean = true
+
+  override def supportHiveTableScanNestedColumnPruning: Boolean =
+    GlutenConfig.getConf.enableColumnarHiveTableScanNestedColumnPruning
+
 }
diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/hive/GlutenClickHouseHiveTableSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/hive/GlutenClickHouseHiveTableSuite.scala
index ff2d13996dc6f..7e31e73040d49 100644
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/hive/GlutenClickHouseHiveTableSuite.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/hive/GlutenClickHouseHiveTableSuite.scala
@@ -1499,4 +1499,57 @@ class GlutenClickHouseHiveTableSuite
     spark.sql("drop table if exists aj")
   }
 
+  test("test hive table scan nested column pruning") {
+    val json_table_name = "test_tbl_7267_json"
+    val pq_table_name = "test_tbl_7267_pq"
+    val create_table_sql =
+      s"""
+         | create table if not exists %s(
+         | id bigint,
+         | d1 STRUCT<c: STRING, d: ARRAY<STRUCT<x: STRING, y: STRING>>>,
+         | d2 STRUCT<c: STRING, d: MAP<STRING, STRUCT<x: STRING, y: STRING>>>,
+         | day string,
+         | hour string
+         | ) partitioned by(day, hour)
+         |""".stripMargin
+    val create_table_json = create_table_sql.format(json_table_name) +
+      s"""
+         | ROW FORMAT SERDE 'org.apache.hive.hcatalog.data.JsonSerDe'
+         | STORED AS INPUTFORMAT 'org.apache.hadoop.mapred.TextInputFormat'
+         | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat'
+         |""".stripMargin
+    val create_table_pq = create_table_sql.format(pq_table_name) + " Stored as PARQUET"
+    val insert_sql =
+      """
+        | insert into %s values(1,
+        | named_struct('c', 'c123', 'd', array(named_struct('x', 'x123', 'y', 'y123'))),
+        | named_struct('c', 'c124', 'd', map('m124', named_struct('x', 'x124', 'y', 'y124'))),
+        | '2024-09-26', '12'
+        | )
+        |""".stripMargin
+    val select_sql =
+      "select id, d1.c, d1.d[0].x, d2.d['m124'].y from %s where day = '2024-09-26' and hour = '12'"
+    val table_names = Array.apply(json_table_name, pq_table_name)
+    val create_table_sqls = Array.apply(create_table_json, create_table_pq)
+    for (i <- table_names.indices) {
+      val table_name = table_names(i)
+      val create_table = create_table_sqls(i)
+      spark.sql(create_table)
+      spark.sql(insert_sql.format(table_name))
+      withSQLConf(("spark.sql.hive.convertMetastoreParquet" -> "false")) {
+        compareResultsAgainstVanillaSpark(
+          select_sql.format(table_name),
+          compareResult = true,
+          df => {
+            val scan = collect(df.queryExecution.executedPlan) {
+              case l: HiveTableScanExecTransformer => l
+            }
+            assert(scan.size == 1)
+          }
+        )
+      }
+      spark.sql("drop table if exists %s".format(table_name))
+    }
+  }
+
 }
diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala
index 177d19c0c7091..700571fd2815a 100644
--- a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala
+++ b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala
@@ -131,4 +131,6 @@ trait BackendSettingsApi {
   def supportColumnarArrowUdf(): Boolean = false
 
   def needPreComputeRangeFrameBoundary(): Boolean = false
+
+  def supportHiveTableScanNestedColumnPruning(): Boolean = false
 }
diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/heuristic/OffloadSingleNode.scala b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/heuristic/OffloadSingleNode.scala
index a8c200e9be449..bae98bec2ec68 100644
--- a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/heuristic/OffloadSingleNode.scala
+++ b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/heuristic/OffloadSingleNode.scala
@@ -37,7 +37,7 @@ import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.joins._
 import org.apache.spark.sql.execution.python.{ArrowEvalPythonExec, BatchEvalPythonExec}
 import org.apache.spark.sql.execution.window.{WindowExec, WindowGroupLimitExecShim}
-import org.apache.spark.sql.hive.HiveTableScanExecTransformer
+import org.apache.spark.sql.hive.{HiveTableScanExecTransformer, HiveTableScanNestedColumnPruning}
 
 /**
  * Converts a vanilla Spark plan node into Gluten plan node. Gluten plan is supposed to be executed
@@ -226,7 +226,11 @@ object OffloadOthers {
       case plan: ProjectExec =>
         val columnarChild = plan.child
        logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
-        ProjectExecTransformer(plan.projectList, columnarChild)
+        if (HiveTableScanNestedColumnPruning.supportNestedColumnPruning(plan)) {
+          HiveTableScanNestedColumnPruning.apply(plan)
+        } else {
+          ProjectExecTransformer(plan.projectList, columnarChild)
+        }
       case plan: HashAggregateExec =>
         logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
         HashAggregateExecBaseTransformer.from(plan)
diff --git a/gluten-substrait/src/main/scala/org/apache/spark/sql/hive/HiveTableScanExecTransformer.scala b/gluten-substrait/src/main/scala/org/apache/spark/sql/hive/HiveTableScanExecTransformer.scala
index 85432350d4a24..f701c76b18134 100644
--- a/gluten-substrait/src/main/scala/org/apache/spark/sql/hive/HiveTableScanExecTransformer.scala
+++ b/gluten-substrait/src/main/scala/org/apache/spark/sql/hive/HiveTableScanExecTransformer.scala
@@ -45,7 +45,8 @@ import java.net.URI
 
 case class HiveTableScanExecTransformer(
     requestedAttributes: Seq[Attribute],
     relation: HiveTableRelation,
-    partitionPruningPred: Seq[Expression])(@transient session: SparkSession)
+    partitionPruningPred: Seq[Expression],
+    prunedOutput: Seq[Attribute] = Seq.empty[Attribute])(@transient session: SparkSession)
   extends AbstractHiveTableScanExec(requestedAttributes, relation, partitionPruningPred)(session)
   with BasicScanExecTransformer {
@@ -63,7 +64,13 @@ case class HiveTableScanExecTransformer(
 
   override def getMetadataColumns(): Seq[AttributeReference] = Seq.empty
 
-  override def outputAttributes(): Seq[Attribute] = output
+  override def outputAttributes(): Seq[Attribute] = {
+    if (prunedOutput.nonEmpty) {
+      prunedOutput
+    } else {
+      output
+    }
+  }
 
   override def getPartitions: Seq[InputPartition] = partitions
diff --git a/gluten-substrait/src/main/scala/org/apache/spark/sql/hive/HiveTableScanNestedColumnPruning.scala b/gluten-substrait/src/main/scala/org/apache/spark/sql/hive/HiveTableScanNestedColumnPruning.scala
new file mode 100644
index 0000000000000..7a20e5c37da57
--- /dev/null
+++ b/gluten-substrait/src/main/scala/org/apache/spark/sql/hive/HiveTableScanNestedColumnPruning.scala
@@ -0,0 +1,252 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.hive
+
+import org.apache.gluten.backendsapi.BackendsApiManager
+import org.apache.gluten.execution.ProjectExecTransformer
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.catalog.HiveTableRelation
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.execution.{FilterExec, LeafExecNode, ProjectExec, SparkPlan}
+import org.apache.spark.sql.hive.HiveTableScanExecTransformer.{ORC_INPUT_FORMAT_CLASS, PARQUET_INPUT_FORMAT_CLASS, TEXT_INPUT_FORMAT_CLASS}
+import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType}
+import org.apache.spark.sql.util.SchemaUtils._
+import org.apache.spark.util.Utils
+
+object HiveTableScanNestedColumnPruning extends Logging {
+  import org.apache.spark.sql.catalyst.expressions.SchemaPruning._
+
+  def supportNestedColumnPruning(projectExec: ProjectExec): Boolean = {
+    if (BackendsApiManager.getSettings.supportHiveTableScanNestedColumnPruning()) {
+      projectExec.child match {
+        case HiveTableScanExecTransformer(_, relation, _, _) =>
+          relation.tableMeta.storage.inputFormat match {
+            case Some(inputFormat)
+                if TEXT_INPUT_FORMAT_CLASS.isAssignableFrom(Utils.classForName(inputFormat)) =>
+              relation.tableMeta.storage.serde match {
+                case Some("org.openx.data.jsonserde.JsonSerDe") | Some(
+                      "org.apache.hive.hcatalog.data.JsonSerDe") =>
+                  return true
+                case _ =>
+              }
+            case Some(inputFormat)
+                if ORC_INPUT_FORMAT_CLASS.isAssignableFrom(Utils.classForName(inputFormat)) =>
+              return true
+            case Some(inputFormat)
+                if PARQUET_INPUT_FORMAT_CLASS.isAssignableFrom(Utils.classForName(inputFormat)) =>
+              return true
+            case _ =>
+          }
+        case _ =>
+      }
+    }
+    false
+  }
+
+  def apply(plan: SparkPlan): SparkPlan = {
+    plan match {
+      case ProjectExec(projectList, child) =>
+        child match {
+          case h: HiveTableScanExecTransformer =>
+            val newPlan = prunePhysicalColumns(
+              h.relation,
+              projectList,
+              Seq.empty[Expression],
+              (prunedDataSchema, prunedMetadataSchema) => {
+                buildNewHiveTableScan(h, prunedDataSchema, prunedMetadataSchema)
+              },
+              (schema, requestFields) => {
+                h.pruneSchema(schema, requestFields)
+              }
+            )
+            if (newPlan.nonEmpty) {
+              return newPlan.get
+            } else {
+              return ProjectExecTransformer(projectList, child)
+            }
+          case _ =>
+            return ProjectExecTransformer(projectList, child)
+        }
+      case _ =>
+    }
+    plan
+  }
+
+  private def prunePhysicalColumns(
+      relation: HiveTableRelation,
+      projects: Seq[NamedExpression],
+      filters: Seq[Expression],
+      leafNodeBuilder: (StructType, StructType) => LeafExecNode,
+      pruneSchemaFunc: (StructType, Seq[SchemaPruning.RootField]) => StructType)
+      : Option[SparkPlan] = {
+    val (normalizedProjects, normalizedFilters) =
+      normalizeAttributeRefNames(relation.output, projects, filters)
+    val requestedRootFields = identifyRootFields(normalizedProjects, normalizedFilters)
+    // If requestedRootFields includes a nested field, continue.
+    // Otherwise, return None.
+    if (requestedRootFields.exists { root: RootField => !root.derivedFromAtt }) {
+      val prunedDataSchema = pruneSchemaFunc(relation.tableMeta.dataSchema, requestedRootFields)
+      val metaFieldNames = relation.tableMeta.schema.fieldNames
+      val metadataSchema = relation.output.collect {
+        case attr: AttributeReference if metaFieldNames.contains(attr.name) => attr
+      }.toStructType
+      val prunedMetadataSchema = if (metadataSchema.nonEmpty) {
+        pruneSchemaFunc(metadataSchema, requestedRootFields)
+      } else {
+        metadataSchema
+      }
+      // If the data schema is different from the pruned data schema
+      // OR
+      // the metadata schema is different from the pruned metadata schema, continue.
+      // Otherwise, return None.
+      if (
+        countLeaves(relation.tableMeta.dataSchema) > countLeaves(prunedDataSchema) ||
+        countLeaves(metadataSchema) > countLeaves(prunedMetadataSchema)
+      ) {
+        val leafNode = leafNodeBuilder(prunedDataSchema, prunedMetadataSchema)
+        val projectionOverSchema = ProjectionOverSchema(
+          prunedDataSchema.merge(prunedMetadataSchema),
+          AttributeSet(relation.output))
+        Some(
+          buildNewProjection(
+            projects,
+            normalizedProjects,
+            normalizedFilters,
+            leafNode,
+            projectionOverSchema))
+      } else {
+        None
+      }
+    } else {
+      None
+    }
+  }
+
+  /**
+   * Normalizes the names of the attribute references in the given projects and filters to reflect
+   * the names in the given logical relation. This makes it possible to compare attributes and
+   * fields by name. Returns a tuple with the normalized projects and filters, respectively.
+   */
+  private def normalizeAttributeRefNames(
+      output: Seq[AttributeReference],
+      projects: Seq[NamedExpression],
+      filters: Seq[Expression]): (Seq[NamedExpression], Seq[Expression]) = {
+    val normalizedAttNameMap = output.map(att => (att.exprId, att.name)).toMap
+    val normalizedProjects = projects
+      .map(_.transform {
+        case att: AttributeReference if normalizedAttNameMap.contains(att.exprId) =>
+          att.withName(normalizedAttNameMap(att.exprId))
+      })
+      .map { case expr: NamedExpression => expr }
+    val normalizedFilters = filters.map(_.transform {
+      case att: AttributeReference if normalizedAttNameMap.contains(att.exprId) =>
+        att.withName(normalizedAttNameMap(att.exprId))
+    })
+    (normalizedProjects, normalizedFilters)
+  }
+
+  /** Builds the new output [[Project]] Spark SQL operator that has the `leafNode`. */
+  private def buildNewProjection(
+      projects: Seq[NamedExpression],
+      normalizedProjects: Seq[NamedExpression],
+      filters: Seq[Expression],
+      leafNode: LeafExecNode,
+      projectionOverSchema: ProjectionOverSchema): ProjectExecTransformer = {
+    // Construct a new target for our projection by rewriting and
+    // including the original filters where available
+    val projectionChild =
+      if (filters.nonEmpty) {
+        val projectedFilters = filters.map(_.transformDown {
+          case projectionOverSchema(expr) => expr
+        })
+        val newFilterCondition = projectedFilters.reduce(And)
+        FilterExec(newFilterCondition, leafNode)
+      } else {
+        leafNode
+      }
+
+    // Construct the new projections of our Project by
+    // rewriting the original projections
+    val newProjects =
+      normalizedProjects.map(_.transformDown { case projectionOverSchema(expr) => expr }).map {
+        case expr: NamedExpression => expr
+      }
+
+    ProjectExecTransformer(
+      restoreOriginalOutputNames(newProjects, projects.map(_.name)),
+      projectionChild)
+  }
+
+  private def buildNewHiveTableScan(
+      hiveTableScan: HiveTableScanExecTransformer,
+      prunedDataSchema: StructType,
+      prunedMetadataSchema: StructType): HiveTableScanExecTransformer = {
+    val relation = hiveTableScan.relation
+    val partitionSchema = relation.tableMeta.partitionSchema
+    val prunedBaseSchema = StructType(
+      prunedDataSchema.fields.filterNot(
+        f => partitionSchema.fieldNames.contains(f.name)) ++ partitionSchema.fields)
+    val finalSchema = prunedBaseSchema.merge(prunedMetadataSchema)
+    val prunedOutput = getPrunedOutput(relation.output, finalSchema)
+    var finalOutput = Seq.empty[Attribute]
+    for (p <- hiveTableScan.output) {
+      var flag = false
+      for (q <- prunedOutput if !flag) {
+        if (p.name.equals(q.name)) {
+          finalOutput :+= q
+          flag = true
+        }
+      }
+    }
+    HiveTableScanExecTransformer(
+      hiveTableScan.requestedAttributes,
+      relation,
+      hiveTableScan.partitionPruningPred,
+      finalOutput)(hiveTableScan.session)
+  }
+
+  // Prune the given output to make it consistent with `requiredSchema`.
+  private def getPrunedOutput(
+      output: Seq[AttributeReference],
+      requiredSchema: StructType): Seq[Attribute] = {
+    // We need to update the data type of the output attributes to use the pruned ones,
+    // so that references to the original relation's output are not broken.
+    val nameAttributeMap = output.map(att => (att.name, att)).toMap
+    val requiredAttributes =
+      requiredSchema.map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)())
+    requiredAttributes.map {
+      case att if nameAttributeMap.contains(att.name) =>
+        nameAttributeMap(att.name).withDataType(att.dataType)
+      case att => att
+    }
+  }
+
+  /**
+   * Counts the "leaf" fields of the given dataType. Informally, this is the number of fields of
+   * non-complex data type in the tree representation of [[DataType]].
+   */
+  private def countLeaves(dataType: DataType): Int = {
+    dataType match {
+      case array: ArrayType => countLeaves(array.elementType)
+      case map: MapType => countLeaves(map.keyType) + countLeaves(map.valueType)
+      case struct: StructType =>
+        struct.map(field => countLeaves(field.dataType)).sum
+      case _ => 1
+    }
+  }
+}
diff --git a/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala b/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala
index 0e41fc65376ca..3d8be4daea9be 100644
--- a/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala
+++ b/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala
@@ -50,6 +50,9 @@ class GlutenConfig(conf: SQLConf) extends Logging {
 
   def enableColumnarHiveTableScan: Boolean = conf.getConf(COLUMNAR_HIVETABLESCAN_ENABLED)
 
+  def enableColumnarHiveTableScanNestedColumnPruning: Boolean =
+    conf.getConf(COLUMNAR_HIVETABLESCAN_NESTED_COLUMN_PRUNING_ENABLED)
+
   def enableVanillaVectorizedReaders: Boolean = conf.getConf(VANILLA_VECTORIZED_READERS_ENABLED)
 
   def enableColumnarHashAgg: Boolean = conf.getConf(COLUMNAR_HASHAGG_ENABLED)
@@ -859,6 +862,13 @@ object GlutenConfig {
       .booleanConf
       .createWithDefault(true)
 
+  val COLUMNAR_HIVETABLESCAN_NESTED_COLUMN_PRUNING_ENABLED =
+    buildConf("spark.gluten.sql.columnar.enableNestedColumnPruningInHiveTableScan")
+      .internal()
+      .doc("Enable or disable nested column pruning in hivetablescan.")
+      .booleanConf
+      .createWithDefault(true)
+
   val VANILLA_VECTORIZED_READERS_ENABLED =
     buildStaticConf("spark.gluten.sql.columnar.enableVanillaVectorizedReaders")
      .internal()
diff --git a/shims/spark32/src/main/scala/org/apache/spark/sql/hive/execution/AbstractHiveTableScanExec.scala b/shims/spark32/src/main/scala/org/apache/spark/sql/hive/execution/AbstractHiveTableScanExec.scala
index 46b59ac306c21..f38c85a49ddea 100644
--- a/shims/spark32/src/main/scala/org/apache/spark/sql/hive/execution/AbstractHiveTableScanExec.scala
+++ b/shims/spark32/src/main/scala/org/apache/spark/sql/hive/execution/AbstractHiveTableScanExec.scala
@@ -22,12 +22,13 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.CastSupport
 import org.apache.spark.sql.catalyst.catalog.HiveTableRelation
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.SchemaPruning.RootField
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.hive._
 import org.apache.spark.sql.hive.client.HiveClientImpl
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{BooleanType, DataType}
+import org.apache.spark.sql.types.{BooleanType, DataType, StructType}
 import org.apache.spark.util.Utils
 
 import org.apache.hadoop.conf.Configuration
@@ -232,4 +233,8 @@ abstract private[hive] class AbstractHiveTableScanExec(
   }
 
   override def otherCopyArgs: Seq[AnyRef] = Seq(sparkSession)
+
+  def pruneSchema(schema: StructType, requestedFields: Seq[RootField]): StructType = {
+    SchemaPruning.pruneDataSchema(schema, requestedFields)
+  }
 }
diff --git a/shims/spark33/src/main/scala/org/apache/spark/sql/hive/execution/AbstractHiveTableScanExec.scala b/shims/spark33/src/main/scala/org/apache/spark/sql/hive/execution/AbstractHiveTableScanExec.scala
index dd095f0ff2472..d9b6bb936f673 100644
--- a/shims/spark33/src/main/scala/org/apache/spark/sql/hive/execution/AbstractHiveTableScanExec.scala
+++ b/shims/spark33/src/main/scala/org/apache/spark/sql/hive/execution/AbstractHiveTableScanExec.scala
@@ -22,12 +22,13 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.CastSupport
 import org.apache.spark.sql.catalyst.catalog.HiveTableRelation
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.SchemaPruning.RootField
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.hive._
 import org.apache.spark.sql.hive.client.HiveClientImpl
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{BooleanType, DataType}
+import org.apache.spark.sql.types.{BooleanType, DataType, StructType}
 import org.apache.spark.util.Utils
 
 import org.apache.hadoop.conf.Configuration
@@ -239,4 +240,8 @@ abstract private[hive] class AbstractHiveTableScanExec(
   }
 
   override def otherCopyArgs: Seq[AnyRef] = Seq(sparkSession)
+
+  def pruneSchema(schema: StructType, requestedFields: Seq[RootField]): StructType = {
+    SchemaPruning.pruneSchema(schema, requestedFields)
+  }
 }
diff --git a/shims/spark34/src/main/scala/org/apache/spark/sql/hive/execution/AbstractHiveTableScanExec.scala b/shims/spark34/src/main/scala/org/apache/spark/sql/hive/execution/AbstractHiveTableScanExec.scala
index 87aba00b0f593..3521d496546d2 100644
--- a/shims/spark34/src/main/scala/org/apache/spark/sql/hive/execution/AbstractHiveTableScanExec.scala
+++ b/shims/spark34/src/main/scala/org/apache/spark/sql/hive/execution/AbstractHiveTableScanExec.scala
@@ -22,12 +22,13 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.CastSupport
 import org.apache.spark.sql.catalyst.catalog.HiveTableRelation
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.SchemaPruning.RootField
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.hive._
 import org.apache.spark.sql.hive.client.HiveClientImpl
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{BooleanType, DataType}
+import org.apache.spark.sql.types.{BooleanType, DataType, StructType}
 import org.apache.spark.util.Utils
 
 import org.apache.hadoop.conf.Configuration
@@ -257,4 +258,8 @@ abstract private[hive] class AbstractHiveTableScanExec(
   }
 
   override def otherCopyArgs: Seq[AnyRef] = Seq(sparkSession)
+
+  def pruneSchema(schema: StructType, requestedFields: Seq[RootField]): StructType = {
+    SchemaPruning.pruneSchema(schema, requestedFields)
+  }
 }
diff --git a/shims/spark35/src/main/scala/org/apache/spark/sql/hive/execution/AbstractHiveTableScanExec.scala b/shims/spark35/src/main/scala/org/apache/spark/sql/hive/execution/AbstractHiveTableScanExec.scala
index 87aba00b0f593..3521d496546d2 100644
--- a/shims/spark35/src/main/scala/org/apache/spark/sql/hive/execution/AbstractHiveTableScanExec.scala
+++ b/shims/spark35/src/main/scala/org/apache/spark/sql/hive/execution/AbstractHiveTableScanExec.scala
@@ -22,12 +22,13 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.CastSupport
 import org.apache.spark.sql.catalyst.catalog.HiveTableRelation
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.SchemaPruning.RootField
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.hive._
 import org.apache.spark.sql.hive.client.HiveClientImpl
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{BooleanType, DataType}
+import org.apache.spark.sql.types.{BooleanType, DataType, StructType}
 import org.apache.spark.util.Utils
 
 import org.apache.hadoop.conf.Configuration
@@ -257,4 +258,8 @@ abstract private[hive] class AbstractHiveTableScanExec(
   }
 
   override def otherCopyArgs: Seq[AnyRef] = Seq(sparkSession)
+
+  def pruneSchema(schema: StructType, requestedFields: Seq[RootField]): StructType = {
+    SchemaPruning.pruneSchema(schema, requestedFields)
+  }
 }
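
Usage sketch (appended for illustration, not part of the diff): the Scala snippet below shows how the configuration flag introduced in GlutenConfig could be exercised once this patch is applied. It assumes a SparkSession named `spark` that already runs with Gluten and the ClickHouse backend, and a Hive table shaped like `test_tbl_7267_json` from the test above; the table name and query are illustrative only.

  // The flag added in GlutenConfig defaults to true; it can be toggled per session.
  spark.conf.set("spark.gluten.sql.columnar.enableNestedColumnPruningInHiveTableScan", "true")

  // A query that touches only a few leaves of the nested structs. With pruning enabled,
  // OffloadOthers should rewrite the ProjectExec over the Hive scan via
  // HiveTableScanNestedColumnPruning, so the scan's output carries only the pruned fields.
  val df = spark.sql(
    "select id, d1.c, d1.d[0].x from test_tbl_7267_json where day = '2024-09-26' and hour = '12'")

  // Inspect the physical plan to verify which attributes the HiveTableScanExecTransformer outputs.
  println(df.queryExecution.executedPlan.treeString)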