diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/DataTypeUtils.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/DataTypeUtils.scala index cfd6c2b3abd..a031a2aaeed 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/DataTypeUtils.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/DataTypeUtils.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,8 @@ package com.nvidia.spark.rapids -import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType} +import org.apache.spark.sql.rapids.execution.TrampolineUtil +import org.apache.spark.sql.types._ object DataTypeUtils { def isNestedType(dataType: DataType): Boolean = dataType match { @@ -26,4 +27,15 @@ object DataTypeUtils { def hasNestedTypes(schema: StructType): Boolean = schema.exists(f => isNestedType(f.dataType)) + + /** + * If `t` is date/timestamp type or its children have a date/timestamp type. + * + * @param t input date type. + * @return if contains date type. + */ + def hasDateOrTimestampType(t: DataType): Boolean = { + TrampolineUtil.dataTypeExistsRecursively(t, e => + e.isInstanceOf[DateType] || e.isInstanceOf[TimestampType]) + } } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala index fcc6c20a42c..3a2d68aa5d0 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -774,9 +774,11 @@ private case class GpuParquetFileFilterHandler( val clipped = GpuParquetUtils.clipBlocksToSchema(clippedSchema, blocks, isCaseSensitive) (clipped, clippedSchema) } - + val hasDateTimeInReadSchema = DataTypeUtils.hasDateOrTimestampType(readDataSchema) val dateRebaseModeForThisFile = DateTimeRebaseUtils.datetimeRebaseMode( - footer.getFileMetaData.getKeyValueMetaData.get, datetimeRebaseMode) + footer.getFileMetaData.getKeyValueMetaData.get, + datetimeRebaseMode, + hasDateTimeInReadSchema) val hasInt96Timestamps = isParquetTimeInInt96(fileSchema) val timestampRebaseModeForThisFile = if (hasInt96Timestamps) { DateTimeRebaseUtils.int96RebaseMode( diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/datetimeRebaseUtils.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/datetimeRebaseUtils.scala index ebcee60b0fb..c7e947d7afb 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/datetimeRebaseUtils.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/datetimeRebaseUtils.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -72,7 +72,8 @@ object DateTimeRebaseUtils { private def rebaseModeFromFileMeta(lookupFileMeta: String => String, modeByConfig: String, minVersion: String, - metadataKey: String): DateTimeRebaseMode = { + metadataKey: String, + hasDateTimeInReadSchema: Boolean = true): DateTimeRebaseMode = { // If there is no version, we return the mode specified by the config. val mode = Option(lookupFileMeta(SPARK_VERSION_METADATA_KEY)).map { version => @@ -95,7 +96,7 @@ object DateTimeRebaseUtils { // Use the default JVM time zone for backward compatibility TimeZone.getDefault.toZoneId } - if (fileTimeZoneId.normalized() != GpuOverrides.UTC_TIMEZONE_ID) { + if (hasDateTimeInReadSchema && fileTimeZoneId.normalized() != GpuOverrides.UTC_TIMEZONE_ID) { throw new UnsupportedOperationException( "LEGACY datetime rebase mode is only supported for files written in UTC timezone. " + s"Actual file timezone: $fileTimeZoneId") @@ -106,9 +107,10 @@ object DateTimeRebaseUtils { } def datetimeRebaseMode(lookupFileMeta: String => String, - modeByConfig: String): DateTimeRebaseMode = { + modeByConfig: String, + hasDateTimeInReadSchema: Boolean = true): DateTimeRebaseMode = { rebaseModeFromFileMeta(lookupFileMeta, modeByConfig, "3.0.0", - SPARK_LEGACY_DATETIME_METADATA_KEY) + SPARK_LEGACY_DATETIME_METADATA_KEY, hasDateTimeInReadSchema) } def int96RebaseMode(lookupFileMeta: String => String,