diff --git a/integration_tests/src/main/python/parquet_test.py b/integration_tests/src/main/python/parquet_test.py
index bf312e2dd81..74f1708aa84 100644
--- a/integration_tests/src/main/python/parquet_test.py
+++ b/integration_tests/src/main/python/parquet_test.py
@@ -310,11 +310,11 @@ def test_parquet_pred_push_round_trip(spark_tmp_path, parquet_gen, read_func, v1
             lambda spark: rf(spark).select(f.col('a') >= s0),
             conf=all_confs)
 
-parquet_ts_write_options = ['INT96', 'TIMESTAMP_MICROS', 'TIMESTAMP_MILLIS']
+parquet_ts_write_options = ['INT96', 'TIMESTAMP_MICROS', 'TIMESTAMP_MILLIS']
 
 # Once https://github.com/NVIDIA/spark-rapids/issues/1126 is fixed delete this test and merge it
-# into test_ts_read_round_trip nested timestamps and dates are not supported right now.
+# into test_parquet_read_roundtrip_datetime
 @pytest.mark.parametrize('gen', [ArrayGen(TimestampGen(start=datetime(1900, 1, 1, tzinfo=timezone.utc))), ArrayGen(ArrayGen(TimestampGen(start=datetime(1900, 1, 1, tzinfo=timezone.utc))))], ids=idfn)
 @pytest.mark.parametrize('ts_write', parquet_ts_write_options)
@@ -334,50 +334,36 @@ def test_parquet_ts_read_round_trip_nested(gen, spark_tmp_path, ts_write, ts_reb
         lambda spark : spark.read.parquet(data_path),
         conf=all_confs)
 
-# Once https://github.com/NVIDIA/spark-rapids/issues/132 is fixed replace this with
-# timestamp_gen
-@pytest.mark.parametrize('gen', [TimestampGen(start=datetime(1900, 1, 1, tzinfo=timezone.utc))], ids=idfn)
-@pytest.mark.parametrize('ts_write', parquet_ts_write_options)
-@pytest.mark.parametrize('ts_rebase', ['CORRECTED', 'LEGACY'])
-@pytest.mark.parametrize('reader_confs', reader_opt_confs)
-@pytest.mark.parametrize('v1_enabled_list', ["", "parquet"])
-def test_ts_read_round_trip(gen, spark_tmp_path, ts_write, ts_rebase, v1_enabled_list, reader_confs):
-    data_path = spark_tmp_path + '/PARQUET_DATA'
-    with_cpu_session(
-        lambda spark : unary_op_df(spark, gen).write.parquet(data_path),
-        conf={'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': ts_rebase,
-              'spark.sql.legacy.parquet.int96RebaseModeInWrite': ts_rebase,
-              'spark.sql.parquet.outputTimestampType': ts_write})
-    all_confs = copy_and_update(reader_confs, {'spark.sql.sources.useV1SourceList': v1_enabled_list})
-    assert_gpu_and_cpu_are_equal_collect(
-        lambda spark : spark.read.parquet(data_path),
-        conf=all_confs)
-
-def readParquetCatchException(spark, data_path):
-    with pytest.raises(Exception) as e_info:
-        df = spark.read.parquet(data_path).collect()
-    assert e_info.match(r".*SparkUpgradeException.*")
+parquet_gens_legacy_list = [[byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,
+                             string_gen, boolean_gen, date_gen, timestamp_gen]]
 
-# Once https://github.com/NVIDIA/spark-rapids/issues/1126 is fixed nested timestamps and dates should be added in
-# Once https://github.com/NVIDIA/spark-rapids/issues/132 is fixed replace this with
-# timestamp_gen
-@pytest.mark.parametrize('gen', [TimestampGen(start=datetime(1590, 1, 1, tzinfo=timezone.utc))], ids=idfn)
-@pytest.mark.parametrize('ts_write', parquet_ts_write_options)
-@pytest.mark.parametrize('ts_rebase', ['LEGACY'])
+@pytest.mark.parametrize('parquet_gens', parquet_gens_legacy_list, ids=idfn)
+@pytest.mark.parametrize('ts_type', parquet_ts_write_options)
+@pytest.mark.parametrize('ts_rebase_write', [('CORRECTED', 'LEGACY'), ('LEGACY', 'CORRECTED')])
+@pytest.mark.parametrize('ts_rebase_read', [('CORRECTED', 'LEGACY'), ('LEGACY', 'CORRECTED')])
 @pytest.mark.parametrize('reader_confs', reader_opt_confs)
 @pytest.mark.parametrize('v1_enabled_list', ["", "parquet"])
-def test_ts_read_fails_datetime_legacy(gen, spark_tmp_path, ts_write, ts_rebase, v1_enabled_list, reader_confs):
+def test_parquet_read_roundtrip_datetime(spark_tmp_path, parquet_gens, ts_type,
+                                         ts_rebase_write, ts_rebase_read,
+                                         reader_confs, v1_enabled_list):
+    gen_list = [('_c' + str(i), gen) for i, gen in enumerate(parquet_gens)]
     data_path = spark_tmp_path + '/PARQUET_DATA'
+    write_confs = {'spark.sql.parquet.outputTimestampType': ts_type,
+                   'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': ts_rebase_write[0],
+                   'spark.sql.legacy.parquet.int96RebaseModeInWrite': ts_rebase_write[1]}
     with_cpu_session(
-        lambda spark : unary_op_df(spark, gen).write.parquet(data_path),
-        conf={'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': ts_rebase,
-              'spark.sql.legacy.parquet.int96RebaseModeInWrite': ts_rebase,
-              'spark.sql.parquet.outputTimestampType': ts_write})
-    all_confs = copy_and_update(reader_confs, {'spark.sql.sources.useV1SourceList': v1_enabled_list})
-    with_gpu_session(
-        lambda spark : readParquetCatchException(spark, data_path),
-        conf=all_confs)
+        lambda spark: gen_df(spark, gen_list).write.parquet(data_path),
+        conf=write_confs)
+    # The rebase modes in the read configs should be ignored; they are overridden by the
+    # write-time modes recorded in the metadata of the written files.
+    read_confs = copy_and_update(reader_confs, {'spark.sql.sources.useV1SourceList': v1_enabled_list,
+                                                'spark.sql.legacy.parquet.datetimeRebaseModeInRead': ts_rebase_read[0],
+                                                'spark.sql.legacy.parquet.int96RebaseModeInRead': ts_rebase_read[1]})
+    assert_gpu_and_cpu_are_equal_collect(
+        lambda spark: spark.read.parquet(data_path),
+        conf=read_confs)
 
+# This exercises the legacy Parquet format, which is entirely different from the datetime LEGACY rebase mode.
 @pytest.mark.parametrize('parquet_gens', [[byte_gen, short_gen, decimal_gen_32bit], decimal_gens, [ArrayGen(decimal_gen_32bit, max_length=10)], [StructGen([['child0', decimal_gen_32bit]])]], ids=idfn)
@@ -388,32 +374,11 @@ def test_parquet_decimal_read_legacy(spark_tmp_path, parquet_gens, read_func, re
     gen_list = [('_c' + str(i), gen) for i, gen in enumerate(parquet_gens)]
     data_path = spark_tmp_path + '/PARQUET_DATA'
     with_cpu_session(
-        lambda spark : gen_df(spark, gen_list).write.parquet(data_path),
-        conf={'spark.sql.parquet.writeLegacyFormat': 'true'})
+        lambda spark : gen_df(spark, gen_list).write.parquet(data_path),
+        conf={'spark.sql.parquet.writeLegacyFormat': 'true'})
     all_confs = copy_and_update(reader_confs, {'spark.sql.sources.useV1SourceList': v1_enabled_list})
     assert_gpu_and_cpu_are_equal_collect(read_func(data_path), conf=all_confs)
-
-parquet_gens_legacy_list = [[byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,
-                             string_gen, boolean_gen, DateGen(start=date(1590, 1, 1)),
-                             TimestampGen(start=datetime(1900, 1, 1, tzinfo=timezone.utc))] + decimal_gens,
-                            pytest.param([timestamp_gen], marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/133')),
-                            pytest.param([date_gen], marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/133'))]
-
-@pytest.mark.parametrize('parquet_gens', parquet_gens_legacy_list, ids=idfn)
-@pytest.mark.parametrize('reader_confs', reader_opt_confs)
-@pytest.mark.parametrize('v1_enabled_list', ["", "parquet"])
-def test_parquet_read_round_trip_legacy(spark_tmp_path, parquet_gens, v1_enabled_list, reader_confs):
-    gen_list = [('_c' + str(i), gen) for i, gen in enumerate(parquet_gens)]
-    data_path = spark_tmp_path + '/PARQUET_DATA'
-    with_cpu_session(
-        lambda spark : gen_df(spark, gen_list).write.parquet(data_path),
-        conf=rebase_write_legacy_conf)
-    all_confs = copy_and_update(reader_confs, {'spark.sql.sources.useV1SourceList': v1_enabled_list})
-    assert_gpu_and_cpu_are_equal_collect(
-        lambda spark : spark.read.parquet(data_path),
-        conf=all_confs)
-
 @pytest.mark.parametrize('reader_confs', reader_opt_confs)
 @pytest.mark.parametrize('v1_enabled_list', ["", "parquet"])
 @pytest.mark.parametrize('batch_size', [100, INT_MAX])
@@ -1004,7 +969,7 @@ def test_parquet_reading_from_unaligned_pages_basic_filters_with_nulls(spark_tmp
 
 conf_for_parquet_aggregate_pushdown = {
-    "spark.sql.parquet.aggregatePushdown": "true", 
+    "spark.sql.parquet.aggregatePushdown": "true",
     "spark.sql.sources.useV1SourceList": ""
 }
@@ -1491,7 +1456,7 @@ def test_parquet_read_count(spark_tmp_path):
 def test_read_case_col_name(spark_tmp_path, read_func, v1_enabled_list, reader_confs, col_name):
     all_confs = copy_and_update(reader_confs, {
         'spark.sql.sources.useV1SourceList': v1_enabled_list})
-    gen_list =[('k0', LongGen(nullable=False, min_val=0, max_val=0)),
+    gen_list =[('k0', LongGen(nullable=False, min_val=0, max_val=0)),
             ('k1', LongGen(nullable=False, min_val=1, max_val=1)),
             ('k2', LongGen(nullable=False, min_val=2, max_val=2)),
             ('k3', LongGen(nullable=False, min_val=3, max_val=3)),
@@ -1499,7 +1464,7 @@ def test_read_case_col_name(spark_tmp_path, read_func, v1_enabled_list, reader_c
             ('v1', LongGen()),
             ('v2', LongGen()),
             ('v3', LongGen())]
-    
+
     gen = StructGen(gen_list, nullable=False)
     data_path = spark_tmp_path + '/PAR_DATA'
     reader = read_func(data_path)
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala
index ac2734a9a9c..c2249beb3f8 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala
@@ -37,7 +37,7 @@ import com.nvidia.spark.rapids.ParquetPartitionReader.{CopyRange, LocalCopy}
 import com.nvidia.spark.rapids.RapidsConf.ParquetFooterReaderType
 import com.nvidia.spark.rapids.RapidsPluginImplicits._
 import com.nvidia.spark.rapids.filecache.FileCache
-import com.nvidia.spark.rapids.jni.{ParquetFooter, SplitAndRetryOOM}
+import com.nvidia.spark.rapids.jni.{DateTimeRebase, ParquetFooter, SplitAndRetryOOM}
 import com.nvidia.spark.rapids.shims.{GpuParquetCrypto, GpuTypeShims, ParquetLegacyNanoAsLongShims, ParquetSchemaClipShims, ParquetStringPredShims, ReaderUtils, ShimFilePartitionReaderFactory, SparkShimImpl}
 import org.apache.commons.io.IOUtils
 import org.apache.commons.io.output.{CountingOutputStream, NullOutputStream}
@@ -156,24 +156,7 @@ object GpuParquetScan {
     tagSupport(scan.sparkSession, schema, scanMeta)
   }
 
-  def throwIfRebaseNeeded(table: Table, dateRebaseMode: DateTimeRebaseMode,
-      timestampRebaseMode: DateTimeRebaseMode): Unit = {
-    (0 until table.getNumberOfColumns).foreach { i =>
-      val col = table.getColumn(i)
-      if (dateRebaseMode != DateTimeRebaseCorrected &&
-          DateTimeRebaseUtils.isDateRebaseNeededInRead(col)) {
-        throw DataSourceUtils.newRebaseExceptionInRead("Parquet")
-      }
-      else if (timestampRebaseMode != DateTimeRebaseCorrected &&
-          DateTimeRebaseUtils.isTimeRebaseNeededInRead(col)) {
-        throw DataSourceUtils.newRebaseExceptionInRead("Parquet")
-      }
-    }
-  }
-
-  def tagSupport(
-      sparkSession: SparkSession,
-      readSchema: StructType,
+  def tagSupport(sparkSession: SparkSession, readSchema: StructType,
       meta: RapidsMeta[_, _, _]): Unit = {
     val sqlConf = sparkSession.conf
@@ -196,10 +179,14 @@ object GpuParquetScan {
     val schemaHasTimestamps = readSchema.exists { field =>
       TrampolineUtil.dataTypeExistsRecursively(field.dataType, _.isInstanceOf[TimestampType])
     }
-    def isTsOrDate(dt: DataType) : Boolean = dt match {
+
+    def isTsOrDate(dt: DataType): Boolean = dt match {
       case TimestampType | DateType => true
+      // Timestamp without timezone (TimestampNTZType, since Spark 3.4) is not yet supported
+      // See https://github.com/NVIDIA/spark-rapids/issues/9707.
       case _ => false
     }
+
     val schemaMightNeedNestedRebase = readSchema.exists { field =>
       if (DataTypeUtils.isNestedType(field.dataType)) {
         TrampolineUtil.dataTypeExistsRecursively(field.dataType, isTsOrDate)
@@ -316,17 +303,90 @@ object GpuParquetScan {
    * @return the updated target batch size.
    */
  def splitTargetBatchSize(targetBatchSize: Long, useChunkedReader: Boolean): Long = {
-    if (!useChunkedReader) {
+    if (!useChunkedReader) {
       throw new SplitAndRetryOOM("GPU OutOfMemory: could not split inputs " +
-        "chunked parquet reader is configured off")
+          "chunked parquet reader is configured off")
     }
     val ret = targetBatchSize / 2
     if (targetBatchSize < minTargetBatchSizeMiB * 1024 * 1024) {
-      throw new SplitAndRetryOOM("GPU OutOfMemory: could not split input " +
-        s"target batch size to less than $minTargetBatchSizeMiB MiB")
+      throw new SplitAndRetryOOM("GPU OutOfMemory: could not split input " +
+          s"target batch size to less than $minTargetBatchSizeMiB MiB")
     }
     ret
   }
+
+  def throwIfRebaseNeededInExceptionMode(table: Table, dateRebaseMode: DateTimeRebaseMode,
+      timestampRebaseMode: DateTimeRebaseMode): Unit = {
+    (0 until table.getNumberOfColumns).foreach { i =>
+      val col = table.getColumn(i)
+      if (dateRebaseMode == DateTimeRebaseException &&
+          DateTimeRebaseUtils.isDateRebaseNeededInRead(col)) {
+        throw DataSourceUtils.newRebaseExceptionInRead("Parquet")
+      } else if (timestampRebaseMode == DateTimeRebaseException &&
+          DateTimeRebaseUtils.isTimeRebaseNeededInRead(col)) {
+        throw DataSourceUtils.newRebaseExceptionInRead("Parquet")
+      }
+    }
+  }
+
+  def rebaseDateTime(table: Table, dateRebaseMode: DateTimeRebaseMode,
+      timestampRebaseMode: DateTimeRebaseMode): Table = {
+    val dateRebaseNeeded = dateRebaseMode == DateTimeRebaseLegacy
+    val timeRebaseNeeded = timestampRebaseMode == DateTimeRebaseLegacy
+
+    lazy val tableHasDate = (0 until table.getNumberOfColumns).exists { i =>
+      checkTypeRecursively(table.getColumn(i), { dt => dt == DType.TIMESTAMP_DAYS })
+    }
+    lazy val tableHasTimestamp = (0 until table.getNumberOfColumns).exists { i =>
+      checkTypeRecursively(table.getColumn(i), { dt => dt == DType.TIMESTAMP_MICROSECONDS })
+    }
+
+    if ((dateRebaseNeeded && tableHasDate) || (timeRebaseNeeded && tableHasTimestamp)) {
+      // Need to close the input table when returning a new table.
+      withResource(table) { tmpTable =>
+        val newColumns = (0 until tmpTable.getNumberOfColumns).map { i =>
+          deepTransformRebaseDateTime(tmpTable.getColumn(i), dateRebaseNeeded, timeRebaseNeeded)
+        }
+        withResource(newColumns) { newCols =>
+          new Table(newCols: _*)
+        }
+      }
+    } else {
+      table
+    }
+  }
+
+  private def checkTypeRecursively(input: ColumnView, f: DType => Boolean): Boolean = {
+    val dt = input.getType
+    if (dt.isTimestampType && dt != DType.TIMESTAMP_DAYS && dt != DType.TIMESTAMP_MICROSECONDS) {
+      // Something is wrong here, since timestamps other than DAYS should already have been
+      // converted into MICROSECONDS when reading Parquet files.
+      throw new IllegalStateException(s"Unexpected date/time type: $dt " +
+        "(expected TIMESTAMP_DAYS or TIMESTAMP_MICROSECONDS)")
+    }
+    dt match {
+      case DType.LIST | DType.STRUCT => (0 until input.getNumChildren).exists(i =>
+        withResource(input.getChildColumnView(i)) { child =>
+          checkTypeRecursively(child, f)
+        })
+      case t: DType => f(t)
+    }
+  }
+
+  private def deepTransformRebaseDateTime(cv: ColumnVector, dateRebaseNeeded: Boolean,
+      timeRebaseNeeded: Boolean): ColumnVector = {
+    ColumnCastUtil.deepTransform(cv) {
+      case (cv, _) if cv.getType.isTimestampType =>
+        // cv type is guaranteed to be either TIMESTAMP_DAYS or TIMESTAMP_MICROSECONDS,
+        // since we already checked it in `checkTypeRecursively`.
+        if ((cv.getType == DType.TIMESTAMP_DAYS && dateRebaseNeeded) ||
+            (cv.getType == DType.TIMESTAMP_MICROSECONDS && timeRebaseNeeded)) {
+          DateTimeRebase.rebaseJulianToGregorian(cv)
+        } else {
+          cv.copyToColumnVector()
+        }
+    }
+  }
 }
 
 // contains meta about all the blocks in a file
@@ -2627,7 +2687,7 @@ object MakeParquetTableProducer extends Logging {
           logWarning(s"Wrote data for ${splits.mkString(", ")} to $p")
         }
       }
-      GpuParquetScan.throwIfRebaseNeeded(table, dateRebaseMode,
+      GpuParquetScan.throwIfRebaseNeededInExceptionMode(table, dateRebaseMode,
         timestampRebaseMode)
       if (readDataSchema.length < table.getNumberOfColumns) {
         throw new QueryExecutionException(s"Expected ${readDataSchema.length} columns " +
@@ -2635,9 +2695,11 @@
        }
      }
      metrics(NUM_OUTPUT_BATCHES) += 1
-      val ret = ParquetSchemaUtils.evolveSchemaIfNeededAndClose(table,
+      val evolvedSchemaTable = ParquetSchemaUtils.evolveSchemaIfNeededAndClose(table,
        clippedParquetSchema, readDataSchema, isSchemaCaseSensitive, useFieldId)
-      new SingleGpuDataProducer(ret)
+      val outputTable = GpuParquetScan.rebaseDateTime(evolvedSchemaTable, dateRebaseMode,
+        timestampRebaseMode)
+      new SingleGpuDataProducer(outputTable)
     }
   }
 }
@@ -2682,15 +2744,16 @@ case class ParquetTableReader(
     }
 
     closeOnExcept(table) { _ =>
-      GpuParquetScan.throwIfRebaseNeeded(table, dateRebaseMode, timestampRebaseMode)
+      GpuParquetScan.throwIfRebaseNeededInExceptionMode(table, dateRebaseMode, timestampRebaseMode)
       if (readDataSchema.length < table.getNumberOfColumns) {
         throw new QueryExecutionException(s"Expected ${readDataSchema.length} columns " +
           s"but read ${table.getNumberOfColumns} from $splitsString")
       }
     }
     metrics(NUM_OUTPUT_BATCHES) += 1
-    ParquetSchemaUtils.evolveSchemaIfNeededAndClose(table,
+    val evolvedSchemaTable = ParquetSchemaUtils.evolveSchemaIfNeededAndClose(table,
       clippedParquetSchema, readDataSchema, isSchemaCaseSensitive, useFieldId)
+    GpuParquetScan.rebaseDateTime(evolvedSchemaTable, dateRebaseMode, timestampRebaseMode)
   }
 
   override def close(): Unit = {
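
For reviewers who want to poke at the behavior the new `test_parquet_read_roundtrip_datetime` pins down without the integration-test harness, here is a minimal PySpark sketch. It is not part of this PR; the local session setup and scratch path are assumptions. It writes a pre-Gregorian timestamp with the LEGACY rebase mode and then reads it back with the opposite mode configured, showing that the read-side configs are effectively ignored for files that carry Spark's rebase metadata.

```python
# Minimal sketch, assuming a local SparkSession and a scratch path; the integration
# tests instead drive this through with_cpu_session / assert_gpu_and_cpu_are_equal_collect.
from datetime import datetime

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").appName("rebase-sketch").getOrCreate()
path = "/tmp/rebase_sketch_parquet"  # hypothetical scratch location

# Write a pre-1582 timestamp with LEGACY rebase, mirroring one ts_rebase_write combination.
spark.conf.set("spark.sql.parquet.outputTimestampType", "TIMESTAMP_MICROS")
spark.conf.set("spark.sql.legacy.parquet.datetimeRebaseModeInWrite", "LEGACY")
spark.conf.set("spark.sql.legacy.parquet.int96RebaseModeInWrite", "LEGACY")
spark.createDataFrame([(datetime(1500, 1, 1, 12, 0, 0),)], "ts timestamp") \
    .write.mode("overwrite").parquet(path)

# Read it back with the opposite mode configured. The value comes back unchanged because
# the rebase mode actually applied is taken from the metadata of the written file, not
# from the read-side configs -- the behavior the new test asserts on both CPU and GPU.
spark.conf.set("spark.sql.legacy.parquet.datetimeRebaseModeInRead", "CORRECTED")
spark.conf.set("spark.sql.legacy.parquet.int96RebaseModeInRead", "CORRECTED")
print(spark.read.parquet(path).collect())

spark.stop()
```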
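As a follow-up, the sketch below (again hypothetical, using pyarrow purely for inspection) shows where that write-time mode lives: Spark tags files written with LEGACY rebase via a footer key-value entry, which is what `DateTimeRebaseUtils` consults on the GPU read path before `throwIfRebaseNeededInExceptionMode` or `rebaseDateTime` ever run. The exact key name is my assumption based on upstream Spark, not something stated in this diff.

```python
# Inspect the Parquet footer of a file written by the previous sketch. The key name
# "org.apache.spark.legacyDateTime" is assumed from upstream Spark's writer.
import glob

import pyarrow.parquet as pq

part_file = glob.glob("/tmp/rebase_sketch_parquet/part-*.parquet")[0]
kv = pq.ParquetFile(part_file).metadata.metadata or {}
print(b"org.apache.spark.legacyDateTime" in kv)  # expected: True for a LEGACY-mode write
```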