diff --git a/integration_tests/src/main/python/parquet_test.py b/integration_tests/src/main/python/parquet_test.py
index 74f1708aa84..dc959fe64cb 100644
--- a/integration_tests/src/main/python/parquet_test.py
+++ b/integration_tests/src/main/python/parquet_test.py
@@ -18,6 +18,7 @@
 from asserts import assert_cpu_and_gpu_are_equal_collect_with_capture, assert_cpu_and_gpu_are_equal_sql_with_capture, assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_row_counts_equal, \
     assert_gpu_fallback_collect, assert_gpu_and_cpu_are_equal_sql, assert_gpu_and_cpu_error, assert_spark_exception
 from data_gen import *
+from parquet_write_test import parquet_nested_datetime_gen, parquet_ts_write_options
 from marks import *
 import pyarrow as pa
 import pyarrow.parquet as pa_pq
@@ -310,42 +311,16 @@ def test_parquet_pred_push_round_trip(spark_tmp_path, parquet_gen, read_func, v1
             lambda spark: rf(spark).select(f.col('a') >= s0),
             conf=all_confs)
 
-
-parquet_ts_write_options = ['INT96', 'TIMESTAMP_MICROS', 'TIMESTAMP_MILLIS']
-
-# Once https://github.com/NVIDIA/spark-rapids/issues/1126 is fixed delete this test and merge it
-# into test_parquet_read_roundtrip_datetime
-@pytest.mark.parametrize('gen', [ArrayGen(TimestampGen(start=datetime(1900, 1, 1, tzinfo=timezone.utc))),
-                                 ArrayGen(ArrayGen(TimestampGen(start=datetime(1900, 1, 1, tzinfo=timezone.utc))))], ids=idfn)
-@pytest.mark.parametrize('ts_write', parquet_ts_write_options)
-@pytest.mark.parametrize('ts_rebase', ['CORRECTED', 'LEGACY'])
-@pytest.mark.parametrize('reader_confs', reader_opt_confs)
-@pytest.mark.parametrize('v1_enabled_list', ["", "parquet"])
-@pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/1126')
-def test_parquet_ts_read_round_trip_nested(gen, spark_tmp_path, ts_write, ts_rebase, v1_enabled_list, reader_confs):
-    data_path = spark_tmp_path + '/PARQUET_DATA'
-    with_cpu_session(
-        lambda spark : unary_op_df(spark, gen).write.parquet(data_path),
-        conf={'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': ts_rebase,
-              'spark.sql.legacy.parquet.int96RebaseModeInWrite': ts_rebase,
-              'spark.sql.parquet.outputTimestampType': ts_write})
-    all_confs = copy_and_update(reader_confs, {'spark.sql.sources.useV1SourceList': v1_enabled_list})
-    assert_gpu_and_cpu_are_equal_collect(
-        lambda spark : spark.read.parquet(data_path),
-        conf=all_confs)
-
-parquet_gens_legacy_list = [[byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,
-                             string_gen, boolean_gen, date_gen, timestamp_gen]]
-
-@pytest.mark.parametrize('parquet_gens', parquet_gens_legacy_list, ids=idfn)
+@datagen_overrides(seed=0, reason='https://github.com/NVIDIA/spark-rapids/issues/9701')
+@pytest.mark.parametrize('parquet_gens', [parquet_nested_datetime_gen], ids=idfn)
 @pytest.mark.parametrize('ts_type', parquet_ts_write_options)
 @pytest.mark.parametrize('ts_rebase_write', [('CORRECTED', 'LEGACY'), ('LEGACY', 'CORRECTED')])
 @pytest.mark.parametrize('ts_rebase_read', [('CORRECTED', 'LEGACY'), ('LEGACY', 'CORRECTED')])
 @pytest.mark.parametrize('reader_confs', reader_opt_confs)
 @pytest.mark.parametrize('v1_enabled_list', ["", "parquet"])
-def test_parquet_read_roundtrip_datetime(spark_tmp_path, parquet_gens, ts_type,
-                                         ts_rebase_write, ts_rebase_read,
-                                         reader_confs, v1_enabled_list):
+def test_parquet_read_roundtrip_datetime_with_legacy_rebase(spark_tmp_path, parquet_gens, ts_type,
+                                                            ts_rebase_write, ts_rebase_read,
+                                                            reader_confs, v1_enabled_list):
     gen_list = [('_c' + str(i), gen) for i, gen in enumerate(parquet_gens)]
     data_path = spark_tmp_path + '/PARQUET_DATA'
     write_confs = {'spark.sql.parquet.outputTimestampType': ts_type,
diff --git a/integration_tests/src/main/python/parquet_write_test.py b/integration_tests/src/main/python/parquet_write_test.py
index dc296b2492f..3e9a8d90f39 100644
--- a/integration_tests/src/main/python/parquet_write_test.py
+++ b/integration_tests/src/main/python/parquet_write_test.py
@@ -75,11 +75,12 @@
                                TimestampGen(start=datetime(1, 1, 1, tzinfo=timezone.utc),
                                             end=datetime(2000, 1, 1, tzinfo=timezone.utc))
                                    .with_special_case(datetime(1000, 1, 1, tzinfo=timezone.utc), weight=10.0)]
-parquet_datetime_in_struct_gen = [StructGen([['child' + str(ind), sub_gen] for ind, sub_gen in enumerate(parquet_datetime_gen_simple)]),
-                                  StructGen([['child0', StructGen([['child' + str(ind), sub_gen] for ind, sub_gen in enumerate(parquet_datetime_gen_simple)])]])]
-parquet_datetime_in_array_gen = [ArrayGen(sub_gen, max_length=10) for sub_gen in parquet_datetime_gen_simple + parquet_datetime_in_struct_gen] + [
-    ArrayGen(ArrayGen(sub_gen, max_length=10), max_length=10) for sub_gen in parquet_datetime_gen_simple + parquet_datetime_in_struct_gen]
-parquet_nested_datetime_gen = parquet_datetime_gen_simple + parquet_datetime_in_struct_gen + parquet_datetime_in_array_gen
+parquet_datetime_in_struct_gen = [
+    StructGen([['child' + str(ind), sub_gen] for ind, sub_gen in enumerate(parquet_datetime_gen_simple)])]
+parquet_datetime_in_array_gen = [ArrayGen(sub_gen, max_length=10) for sub_gen in
+                                 parquet_datetime_gen_simple + parquet_datetime_in_struct_gen]
+parquet_nested_datetime_gen = parquet_datetime_gen_simple + parquet_datetime_in_struct_gen + \
+                              parquet_datetime_in_array_gen
 
 parquet_map_gens = parquet_map_gens_sample + [
     MapGen(StructGen([['child0', StringGen()], ['child1', StringGen()]], nullable=False), FloatGen()),
@@ -460,15 +461,35 @@ def generate_map_with_empty_validity(spark, path):
 @datagen_overrides(seed=0, reason='https://github.com/NVIDIA/spark-rapids/issues/9701')
 @pytest.mark.parametrize('data_gen', parquet_nested_datetime_gen, ids=idfn)
 @pytest.mark.parametrize('ts_write', parquet_ts_write_options)
-@pytest.mark.parametrize('ts_rebase_write', ['CORRECTED', 'LEGACY'])
-@pytest.mark.parametrize('ts_rebase_read', ['CORRECTED', 'LEGACY'])
-def test_datetime_roundtrip_with_legacy_rebase(spark_tmp_path, data_gen, ts_write, ts_rebase_write, ts_rebase_read):
+@pytest.mark.parametrize('ts_rebase_write', ['EXCEPTION'])
+def test_parquet_write_fails_legacy_datetime(spark_tmp_path, data_gen, ts_write, ts_rebase_write):
     data_path = spark_tmp_path + '/PARQUET_DATA'
     all_confs = {'spark.sql.parquet.outputTimestampType': ts_write,
                  'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': ts_rebase_write,
-                 'spark.sql.legacy.parquet.int96RebaseModeInWrite': ts_rebase_write,
-                 'spark.sql.legacy.parquet.datetimeRebaseModeInRead': ts_rebase_read,
-                 'spark.sql.legacy.parquet.int96RebaseModeInRead': ts_rebase_read}
+                 'spark.sql.legacy.parquet.int96RebaseModeInWrite': ts_rebase_write}
+    def writeParquetCatchException(spark, data_gen, data_path):
+        with pytest.raises(Exception) as e_info:
+            unary_op_df(spark, data_gen).coalesce(1).write.parquet(data_path)
+        assert e_info.match(r".*SparkUpgradeException.*")
+    with_gpu_session(
+        lambda spark: writeParquetCatchException(spark, data_gen, data_path),
+        conf=all_confs)
+
+@datagen_overrides(seed=0, reason='https://github.com/NVIDIA/spark-rapids/issues/9701')
+@pytest.mark.parametrize('data_gen', parquet_nested_datetime_gen, ids=idfn)
+@pytest.mark.parametrize('ts_write', parquet_ts_write_options)
+@pytest.mark.parametrize('ts_rebase_write', [('CORRECTED', 'LEGACY'), ('LEGACY', 'CORRECTED')])
+@pytest.mark.parametrize('ts_rebase_read', [('CORRECTED', 'LEGACY'), ('LEGACY', 'CORRECTED')])
+def test_parquet_write_roundtrip_datetime_with_legacy_rebase(spark_tmp_path, data_gen, ts_write,
+                                                             ts_rebase_write, ts_rebase_read):
+    data_path = spark_tmp_path + '/PARQUET_DATA'
+    all_confs = {'spark.sql.parquet.outputTimestampType': ts_write,
+                 'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': ts_rebase_write[0],
+                 'spark.sql.legacy.parquet.int96RebaseModeInWrite': ts_rebase_write[1],
+                 # The rebase modes in read configs should be ignored and overridden by the same
+                 # modes in write configs, which are retrieved from the written files.
+                 'spark.sql.legacy.parquet.datetimeRebaseModeInRead': ts_rebase_read[0],
+                 'spark.sql.legacy.parquet.int96RebaseModeInRead': ts_rebase_read[1]}
     assert_gpu_and_cpu_writes_are_equal_collect(
         lambda spark, path: unary_op_df(spark, data_gen).coalesce(1).write.parquet(path),
         lambda spark, path: spark.read.parquet(path),
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala
index c2249beb3f8..f34737baff8 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala
@@ -158,8 +158,6 @@ object GpuParquetScan {
   def tagSupport(sparkSession: SparkSession,
       readSchema: StructType,
       meta: RapidsMeta[_, _, _]): Unit = {
-    val sqlConf = sparkSession.conf
-
     if (ParquetLegacyNanoAsLongShims.legacyParquetNanosAsLong) {
       meta.willNotWorkOnGpu("GPU does not support spark.sql.legacy.parquet.nanosAsLong")
     }
@@ -176,25 +174,6 @@ object GpuParquetScan {
 
     FileFormatChecks.tag(meta, readSchema, ParquetFormatType, ReadFileOp)
 
-    val schemaHasTimestamps = readSchema.exists { field =>
-      TrampolineUtil.dataTypeExistsRecursively(field.dataType, _.isInstanceOf[TimestampType])
-    }
-
-    def isTsOrDate(dt: DataType): Boolean = dt match {
-      case TimestampType | DateType => true
-      // Timestamp without timezone (TimestampNTZType, since Spark 3.4) is not yet supported
-      // See https://github.com/NVIDIA/spark-rapids/issues/9707.
-      case _ => false
-    }
-
-    val schemaMightNeedNestedRebase = readSchema.exists { field =>
-      if (DataTypeUtils.isNestedType(field.dataType)) {
-        TrampolineUtil.dataTypeExistsRecursively(field.dataType, isTsOrDate)
-      } else {
-        false
-      }
-    }
-
     // Currently timestamp conversion is not supported.
     // If support needs to be added then we need to follow the logic in Spark's
     // ParquetPartitionReaderFactory and VectorizedColumnReader which essentially
@@ -204,43 +183,12 @@
     // were written in that timezone and convert them to UTC timestamps.
     // Essentially this should boil down to a vector subtract of the scalar delta
     // between the configured timezone's delta from UTC on the timestamp data.
+    val schemaHasTimestamps = readSchema.exists { field =>
+      TrampolineUtil.dataTypeExistsRecursively(field.dataType, _.isInstanceOf[TimestampType])
+    }
     if (schemaHasTimestamps && sparkSession.sessionState.conf.isParquetINT96TimestampConversion) {
       meta.willNotWorkOnGpu("GpuParquetScan does not support int96 timestamp conversion")
     }
-
-    DateTimeRebaseMode.fromName(sqlConf.get(SparkShimImpl.int96ParquetRebaseReadKey)) match {
-      case DateTimeRebaseException => if (schemaMightNeedNestedRebase) {
-        meta.willNotWorkOnGpu("Nested timestamp and date values are not supported when " +
-          s"${SparkShimImpl.int96ParquetRebaseReadKey} is EXCEPTION")
-      }
-      case DateTimeRebaseCorrected => // Good
-      case DateTimeRebaseLegacy =>
-        if (schemaMightNeedNestedRebase) {
-          meta.willNotWorkOnGpu("Nested timestamp and date values are not supported when " +
-            s"${SparkShimImpl.int96ParquetRebaseReadKey} is LEGACY")
-        }
-      // This should never be reached out, since invalid mode is handled in
-      // `DateTimeRebaseMode.fromName`.
-      case other => meta.willNotWorkOnGpu(
-        DateTimeRebaseUtils.invalidRebaseModeMessage(other.getClass.getName))
-    }
-
-    DateTimeRebaseMode.fromName(sqlConf.get(SparkShimImpl.parquetRebaseReadKey)) match {
-      case DateTimeRebaseException => if (schemaMightNeedNestedRebase) {
-        meta.willNotWorkOnGpu("Nested timestamp and date values are not supported when " +
-          s"${SparkShimImpl.parquetRebaseReadKey} is EXCEPTION")
-      }
-      case DateTimeRebaseCorrected => // Good
-      case DateTimeRebaseLegacy =>
-        if (schemaMightNeedNestedRebase) {
-          meta.willNotWorkOnGpu("Nested timestamp and date values are not supported when " +
-            s"${SparkShimImpl.parquetRebaseReadKey} is LEGACY")
-        }
-      // This should never be reached out, since invalid mode is handled in
-      // `DateTimeRebaseMode.fromName`.
-      case other => meta.willNotWorkOnGpu(
-        DateTimeRebaseUtils.invalidRebaseModeMessage(other.getClass.getName))
-    }
   }
 
   /**
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/datetimeRebaseUtils.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/datetimeRebaseUtils.scala
index 78703104fb4..ebcee60b0fb 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/datetimeRebaseUtils.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/datetimeRebaseUtils.scala
@@ -18,7 +18,7 @@ package com.nvidia.spark.rapids
 
 import java.util.TimeZone
 
-import ai.rapids.cudf.{ColumnVector, DType, Scalar}
+import ai.rapids.cudf.{ColumnView, DType, Scalar}
 import com.nvidia.spark.rapids.Arm.withResource
 import com.nvidia.spark.rapids.shims.SparkShimImpl
 
@@ -117,54 +117,50 @@ object DateTimeRebaseUtils {
       SPARK_LEGACY_INT96_METADATA_KEY)
   }
 
-  private[this] def isDateRebaseNeeded(column: ColumnVector,
-      startDay: Int): Boolean = {
-    // TODO update this for nested column checks
-    // https://github.com/NVIDIA/spark-rapids/issues/1126
+  private[this] def isRebaseNeeded(column: ColumnView, checkType: DType,
+      minGood: Scalar): Boolean = {
     val dtype = column.getType
-    if (dtype == DType.TIMESTAMP_DAYS) {
-      val hasBad = withResource(Scalar.timestampDaysFromInt(startDay)) {
-        column.lessThan
-      }
-      val anyBad = withResource(hasBad) {
-        _.any()
-      }
-      withResource(anyBad) { _ =>
-        anyBad.isValid && anyBad.getBoolean
-      }
-    } else {
-      false
-    }
-  }
+    require(!dtype.hasTimeResolution || dtype == DType.TIMESTAMP_MICROSECONDS)
 
-  private[this] def isTimeRebaseNeeded(column: ColumnVector,
-      startTs: Long): Boolean = {
-    val dtype = column.getType
-    if (dtype.hasTimeResolution) {
-      require(dtype == DType.TIMESTAMP_MICROSECONDS)
-      withResource(
-        Scalar.timestampFromLong(DType.TIMESTAMP_MICROSECONDS, startTs)) { minGood =>
+    dtype match {
+      case `checkType` =>
         withResource(column.lessThan(minGood)) { hasBad =>
-          withResource(hasBad.any()) { a =>
-            a.isValid && a.getBoolean
+          withResource(hasBad.any()) { anyBad =>
+            anyBad.isValid && anyBad.getBoolean
           }
         }
-      }
-    } else {
-      false
+
+      case DType.LIST | DType.STRUCT => (0 until column.getNumChildren).exists(i =>
+        withResource(column.getChildColumnView(i)) { child =>
+          isRebaseNeeded(child, checkType, minGood)
+        })
+
+      case _ => false
+    }
+  }
+
+  private[this] def isDateRebaseNeeded(column: ColumnView, startDay: Int): Boolean = {
+    withResource(Scalar.timestampDaysFromInt(startDay)) { minGood =>
+      isRebaseNeeded(column, DType.TIMESTAMP_DAYS, minGood)
+    }
+  }
+
+  private[this] def isTimeRebaseNeeded(column: ColumnView, startTs: Long): Boolean = {
+    withResource(Scalar.timestampFromLong(DType.TIMESTAMP_MICROSECONDS, startTs)) { minGood =>
+      isRebaseNeeded(column, DType.TIMESTAMP_MICROSECONDS, minGood)
     }
   }
 
-  def isDateRebaseNeededInRead(column: ColumnVector): Boolean =
+  def isDateRebaseNeededInRead(column: ColumnView): Boolean =
     isDateRebaseNeeded(column, RebaseDateTime.lastSwitchJulianDay)
 
-  def isTimeRebaseNeededInRead(column: ColumnVector): Boolean =
+  def isTimeRebaseNeededInRead(column: ColumnView): Boolean =
     isTimeRebaseNeeded(column, RebaseDateTime.lastSwitchJulianTs)
 
-  def isDateRebaseNeededInWrite(column: ColumnVector): Boolean =
+  def isDateRebaseNeededInWrite(column: ColumnView): Boolean =
    isDateRebaseNeeded(column, RebaseDateTime.lastSwitchGregorianDay)
 
-  def isTimeRebaseNeededInWrite(column: ColumnVector): Boolean =
+  def isTimeRebaseNeededInWrite(column: ColumnView): Boolean =
     isTimeRebaseNeeded(column, RebaseDateTime.lastSwitchGregorianTs)
 
   def newRebaseExceptionInRead(format: String): Exception = {
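
Note for reviewers: the scenario covered by the new round-trip tests can be reproduced outside the test harness. The following is a minimal sketch in plain PySpark (CPU Spark only, no plugin required); the local path `/tmp/parquet_rebase_sketch`, the schema, and the sample value are made up for illustration and are not taken from the test suite.

```python
from datetime import datetime, timezone

from pyspark.sql import Row, SparkSession

spark = SparkSession.builder.master("local[*]").appName("rebase-sketch").getOrCreate()

# Write an array<struct<ts: timestamp>> column holding a pre-Gregorian-switch timestamp
# so the legacy calendar rebase path is actually exercised on the write side.
spark.conf.set("spark.sql.parquet.outputTimestampType", "INT96")
spark.conf.set("spark.sql.legacy.parquet.datetimeRebaseModeInWrite", "LEGACY")
spark.conf.set("spark.sql.legacy.parquet.int96RebaseModeInWrite", "LEGACY")

df = spark.createDataFrame(
    [([Row(ts=datetime(1500, 1, 1, tzinfo=timezone.utc))],)],
    "a array<struct<ts: timestamp>>")
df.coalesce(1).write.mode("overwrite").parquet("/tmp/parquet_rebase_sketch")  # illustrative path

# Read back with deliberately different rebase modes configured. The modes recorded in the
# file footer by the writer should take precedence over the read-side configs, so the
# nested timestamps are expected to come back unchanged.
spark.conf.set("spark.sql.legacy.parquet.datetimeRebaseModeInRead", "CORRECTED")
spark.conf.set("spark.sql.legacy.parquet.int96RebaseModeInRead", "CORRECTED")
spark.read.parquet("/tmp/parquet_rebase_sketch").show(truncate=False)
```

This is the same write-then-read-back pattern that `test_parquet_write_roundtrip_datetime_with_legacy_rebase` asserts via `assert_gpu_and_cpu_writes_are_equal_collect`, just without the parametrized data generators.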