Commit

Support `spark.sql.parquet.datetimeRebaseModeInRead=LEGACY` and `spark.sql.parquet.int96RebaseModeInRead=LEGACY` (#9649)

Signed-off-by: Nghia Truong <[email protected]>
ttnghia authored Nov 15, 2023
1 parent 36baf45 commit c1c4708
Showing 2 changed files with 123 additions and 95 deletions.
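
For context on what this change enables, here is a minimal PySpark sketch (not part of this commit; the session setup and file path are placeholders) of reading Parquet data that was written with legacy hybrid Julian/Gregorian dates or INT96 timestamps, using the two configs named in the commit title:

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()  # assumes the RAPIDS plugin is already configured

# Rebase ancient dates/timestamps from the hybrid Julian/Gregorian calendar to the
# proleptic Gregorian calendar while reading.
spark.conf.set('spark.sql.parquet.datetimeRebaseModeInRead', 'LEGACY')
spark.conf.set('spark.sql.parquet.int96RebaseModeInRead', 'LEGACY')

df = spark.read.parquet('/path/to/legacy_written_data')  # placeholder path
df.show()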
97 changes: 31 additions & 66 deletions integration_tests/src/main/python/parquet_test.py
@@ -310,11 +310,11 @@ def test_parquet_pred_push_round_trip(spark_tmp_path, parquet_gen, read_func, v1
lambda spark: rf(spark).select(f.col('a') >= s0),
conf=all_confs)

parquet_ts_write_options = ['INT96', 'TIMESTAMP_MICROS', 'TIMESTAMP_MILLIS']

# Once https://github.com/NVIDIA/spark-rapids/issues/1126 is fixed delete this test and merge it
# into test_ts_read_round_trip nested timestamps and dates are not supported right now.
# into test_parquet_read_roundtrip_datetime
@pytest.mark.parametrize('gen', [ArrayGen(TimestampGen(start=datetime(1900, 1, 1, tzinfo=timezone.utc))),
ArrayGen(ArrayGen(TimestampGen(start=datetime(1900, 1, 1, tzinfo=timezone.utc))))], ids=idfn)
@pytest.mark.parametrize('ts_write', parquet_ts_write_options)
@@ -334,50 +334,36 @@ def test_parquet_ts_read_round_trip_nested(gen, spark_tmp_path, ts_write, ts_reb
lambda spark : spark.read.parquet(data_path),
conf=all_confs)

# Once https://github.com/NVIDIA/spark-rapids/issues/132 is fixed replace this with
# timestamp_gen
@pytest.mark.parametrize('gen', [TimestampGen(start=datetime(1900, 1, 1, tzinfo=timezone.utc))], ids=idfn)
@pytest.mark.parametrize('ts_write', parquet_ts_write_options)
@pytest.mark.parametrize('ts_rebase', ['CORRECTED', 'LEGACY'])
@pytest.mark.parametrize('reader_confs', reader_opt_confs)
@pytest.mark.parametrize('v1_enabled_list', ["", "parquet"])
def test_ts_read_round_trip(gen, spark_tmp_path, ts_write, ts_rebase, v1_enabled_list, reader_confs):
data_path = spark_tmp_path + '/PARQUET_DATA'
with_cpu_session(
lambda spark : unary_op_df(spark, gen).write.parquet(data_path),
conf={'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': ts_rebase,
'spark.sql.legacy.parquet.int96RebaseModeInWrite': ts_rebase,
'spark.sql.parquet.outputTimestampType': ts_write})
all_confs = copy_and_update(reader_confs, {'spark.sql.sources.useV1SourceList': v1_enabled_list})
assert_gpu_and_cpu_are_equal_collect(
lambda spark : spark.read.parquet(data_path),
conf=all_confs)

def readParquetCatchException(spark, data_path):
with pytest.raises(Exception) as e_info:
df = spark.read.parquet(data_path).collect()
assert e_info.match(r".*SparkUpgradeException.*")
parquet_gens_legacy_list = [[byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,
string_gen, boolean_gen, date_gen, timestamp_gen]]

# Once https://github.com/NVIDIA/spark-rapids/issues/1126 is fixed nested timestamps and dates should be added in
# Once https://github.com/NVIDIA/spark-rapids/issues/132 is fixed replace this with
# timestamp_gen
@pytest.mark.parametrize('gen', [TimestampGen(start=datetime(1590, 1, 1, tzinfo=timezone.utc))], ids=idfn)
@pytest.mark.parametrize('ts_write', parquet_ts_write_options)
@pytest.mark.parametrize('ts_rebase', ['LEGACY'])
@pytest.mark.parametrize('parquet_gens', parquet_gens_legacy_list, ids=idfn)
@pytest.mark.parametrize('ts_type', parquet_ts_write_options)
@pytest.mark.parametrize('ts_rebase_write', [('CORRECTED', 'LEGACY'), ('LEGACY', 'CORRECTED')])
@pytest.mark.parametrize('ts_rebase_read', [('CORRECTED', 'LEGACY'), ('LEGACY', 'CORRECTED')])
@pytest.mark.parametrize('reader_confs', reader_opt_confs)
@pytest.mark.parametrize('v1_enabled_list', ["", "parquet"])
def test_ts_read_fails_datetime_legacy(gen, spark_tmp_path, ts_write, ts_rebase, v1_enabled_list, reader_confs):
def test_parquet_read_roundtrip_datetime(spark_tmp_path, parquet_gens, ts_type,
ts_rebase_write, ts_rebase_read,
reader_confs, v1_enabled_list):
gen_list = [('_c' + str(i), gen) for i, gen in enumerate(parquet_gens)]
data_path = spark_tmp_path + '/PARQUET_DATA'
write_confs = {'spark.sql.parquet.outputTimestampType': ts_type,
'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': ts_rebase_write[0],
'spark.sql.legacy.parquet.int96RebaseModeInWrite': ts_rebase_write[1]}
with_cpu_session(
lambda spark : unary_op_df(spark, gen).write.parquet(data_path),
conf={'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': ts_rebase,
'spark.sql.legacy.parquet.int96RebaseModeInWrite': ts_rebase,
'spark.sql.parquet.outputTimestampType': ts_write})
all_confs = copy_and_update(reader_confs, {'spark.sql.sources.useV1SourceList': v1_enabled_list})
with_gpu_session(
lambda spark : readParquetCatchException(spark, data_path),
conf=all_confs)
lambda spark: gen_df(spark, gen_list).write.parquet(data_path),
conf=write_confs)
    # The rebase modes in the read configs should be ignored: they are overridden by the modes used
    # at write time, which are recorded in the metadata of the written files.
read_confs = copy_and_update(reader_confs, {'spark.sql.sources.useV1SourceList': v1_enabled_list,
'spark.sql.legacy.parquet.datetimeRebaseModeInRead': ts_rebase_read[0],
'spark.sql.legacy.parquet.int96RebaseModeInRead': ts_rebase_read[1]})
assert_gpu_and_cpu_are_equal_collect(
lambda spark: spark.read.parquet(data_path),
conf=read_confs)
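
# Not part of this commit: an illustrative aside on why the read-side rebase modes can be ignored
# above. When Spark writes with LEGACY rebase it records that decision in the Parquet footer
# metadata, and readers consult those footer keys rather than the session read configs. A minimal
# pyarrow sketch of inspecting the markers (the key names are assumed from Spark's writer
# behavior, not taken from this repository):
import pyarrow.parquet as pq

def spark_rebase_markers(path):
    # Footer key/value metadata is a dict of bytes, or None when absent.
    footer = pq.ParquetFile(path).metadata.metadata or {}
    wanted = (b'org.apache.spark.legacyDateTime', b'org.apache.spark.legacyINT96')
    return {k.decode(): v.decode() for k, v in footer.items() if k in wanted}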

# This is the legacy Parquet format, which is entirely different from the datetime legacy rebase mode.
@pytest.mark.parametrize('parquet_gens', [[byte_gen, short_gen, decimal_gen_32bit], decimal_gens,
[ArrayGen(decimal_gen_32bit, max_length=10)],
[StructGen([['child0', decimal_gen_32bit]])]], ids=idfn)
@@ -388,32 +374,11 @@ def test_parquet_decimal_read_legacy(spark_tmp_path, parquet_gens, read_func, re
gen_list = [('_c' + str(i), gen) for i, gen in enumerate(parquet_gens)]
data_path = spark_tmp_path + '/PARQUET_DATA'
with_cpu_session(
        lambda spark : gen_df(spark, gen_list).write.parquet(data_path),
        conf={'spark.sql.parquet.writeLegacyFormat': 'true'})
all_confs = copy_and_update(reader_confs, {'spark.sql.sources.useV1SourceList': v1_enabled_list})
assert_gpu_and_cpu_are_equal_collect(read_func(data_path), conf=all_confs)


parquet_gens_legacy_list = [[byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,
string_gen, boolean_gen, DateGen(start=date(1590, 1, 1)),
TimestampGen(start=datetime(1900, 1, 1, tzinfo=timezone.utc))] + decimal_gens,
pytest.param([timestamp_gen], marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/133')),
pytest.param([date_gen], marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/133'))]

@pytest.mark.parametrize('parquet_gens', parquet_gens_legacy_list, ids=idfn)
@pytest.mark.parametrize('reader_confs', reader_opt_confs)
@pytest.mark.parametrize('v1_enabled_list', ["", "parquet"])
def test_parquet_read_round_trip_legacy(spark_tmp_path, parquet_gens, v1_enabled_list, reader_confs):
gen_list = [('_c' + str(i), gen) for i, gen in enumerate(parquet_gens)]
data_path = spark_tmp_path + '/PARQUET_DATA'
with_cpu_session(
lambda spark : gen_df(spark, gen_list).write.parquet(data_path),
conf=rebase_write_legacy_conf)
all_confs = copy_and_update(reader_confs, {'spark.sql.sources.useV1SourceList': v1_enabled_list})
assert_gpu_and_cpu_are_equal_collect(
lambda spark : spark.read.parquet(data_path),
conf=all_confs)

@pytest.mark.parametrize('reader_confs', reader_opt_confs)
@pytest.mark.parametrize('v1_enabled_list', ["", "parquet"])
@pytest.mark.parametrize('batch_size', [100, INT_MAX])
@@ -1004,7 +969,7 @@ def test_parquet_reading_from_unaligned_pages_basic_filters_with_nulls(spark_tmp


conf_for_parquet_aggregate_pushdown = {
"spark.sql.parquet.aggregatePushdown": "true",
"spark.sql.parquet.aggregatePushdown": "true",
"spark.sql.sources.useV1SourceList": ""
}

@@ -1491,15 +1456,15 @@ def test_parquet_read_count(spark_tmp_path):
def test_read_case_col_name(spark_tmp_path, read_func, v1_enabled_list, reader_confs, col_name):
all_confs = copy_and_update(reader_confs, {
'spark.sql.sources.useV1SourceList': v1_enabled_list})
    gen_list =[('k0', LongGen(nullable=False, min_val=0, max_val=0)),
('k1', LongGen(nullable=False, min_val=1, max_val=1)),
('k2', LongGen(nullable=False, min_val=2, max_val=2)),
('k3', LongGen(nullable=False, min_val=3, max_val=3)),
('v0', LongGen()),
('v1', LongGen()),
('v2', LongGen()),
('v3', LongGen())]

gen = StructGen(gen_list, nullable=False)
data_path = spark_tmp_path + '/PAR_DATA'
reader = read_func(data_path)
121 changes: 92 additions & 29 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala
@@ -37,7 +37,7 @@ import com.nvidia.spark.rapids.ParquetPartitionReader.{CopyRange, LocalCopy}
import com.nvidia.spark.rapids.RapidsConf.ParquetFooterReaderType
import com.nvidia.spark.rapids.RapidsPluginImplicits._
import com.nvidia.spark.rapids.filecache.FileCache
import com.nvidia.spark.rapids.jni.{ParquetFooter, SplitAndRetryOOM}
import com.nvidia.spark.rapids.jni.{DateTimeRebase, ParquetFooter, SplitAndRetryOOM}
import com.nvidia.spark.rapids.shims.{GpuParquetCrypto, GpuTypeShims, ParquetLegacyNanoAsLongShims, ParquetSchemaClipShims, ParquetStringPredShims, ReaderUtils, ShimFilePartitionReaderFactory, SparkShimImpl}
import org.apache.commons.io.IOUtils
import org.apache.commons.io.output.{CountingOutputStream, NullOutputStream}
@@ -156,24 +156,7 @@ object GpuParquetScan {
tagSupport(scan.sparkSession, schema, scanMeta)
}

def throwIfRebaseNeeded(table: Table, dateRebaseMode: DateTimeRebaseMode,
timestampRebaseMode: DateTimeRebaseMode): Unit = {
(0 until table.getNumberOfColumns).foreach { i =>
val col = table.getColumn(i)
if (dateRebaseMode != DateTimeRebaseCorrected &&
DateTimeRebaseUtils.isDateRebaseNeededInRead(col)) {
throw DataSourceUtils.newRebaseExceptionInRead("Parquet")
}
else if (timestampRebaseMode != DateTimeRebaseCorrected &&
DateTimeRebaseUtils.isTimeRebaseNeededInRead(col)) {
throw DataSourceUtils.newRebaseExceptionInRead("Parquet")
}
}
}

def tagSupport(
sparkSession: SparkSession,
readSchema: StructType,
def tagSupport(sparkSession: SparkSession, readSchema: StructType,
meta: RapidsMeta[_, _, _]): Unit = {
val sqlConf = sparkSession.conf

@@ -196,10 +179,14 @@ object GpuParquetScan {
val schemaHasTimestamps = readSchema.exists { field =>
TrampolineUtil.dataTypeExistsRecursively(field.dataType, _.isInstanceOf[TimestampType])
}
def isTsOrDate(dt: DataType) : Boolean = dt match {

def isTsOrDate(dt: DataType): Boolean = dt match {
case TimestampType | DateType => true
// Timestamp without timezone (TimestampNTZType, since Spark 3.4) is not yet supported
// See https://github.com/NVIDIA/spark-rapids/issues/9707.
case _ => false
}

val schemaMightNeedNestedRebase = readSchema.exists { field =>
if (DataTypeUtils.isNestedType(field.dataType)) {
TrampolineUtil.dataTypeExistsRecursively(field.dataType, isTsOrDate)
@@ -316,17 +303,90 @@ object GpuParquetScan {
* @return the updated target batch size.
*/
def splitTargetBatchSize(targetBatchSize: Long, useChunkedReader: Boolean): Long = {
    if (!useChunkedReader) {
throw new SplitAndRetryOOM("GPU OutOfMemory: could not split inputs " +
"chunked parquet reader is configured off")
"chunked parquet reader is configured off")
}
val ret = targetBatchSize / 2
if (targetBatchSize < minTargetBatchSizeMiB * 1024 * 1024) {
throw new SplitAndRetryOOM("GPU OutOfMemory: could not split input " +
s"target batch size to less than $minTargetBatchSizeMiB MiB")
throw new SplitAndRetryOOM("GPU OutOfMemory: could not split input " +
s"target batch size to less than $minTargetBatchSizeMiB MiB")
}
ret
}

def throwIfRebaseNeededInExceptionMode(table: Table, dateRebaseMode: DateTimeRebaseMode,
timestampRebaseMode: DateTimeRebaseMode): Unit = {
(0 until table.getNumberOfColumns).foreach { i =>
val col = table.getColumn(i)
if (dateRebaseMode == DateTimeRebaseException &&
DateTimeRebaseUtils.isDateRebaseNeededInRead(col)) {
throw DataSourceUtils.newRebaseExceptionInRead("Parquet")
} else if (timestampRebaseMode == DateTimeRebaseException &&
DateTimeRebaseUtils.isTimeRebaseNeededInRead(col)) {
throw DataSourceUtils.newRebaseExceptionInRead("Parquet")
}
}
}

def rebaseDateTime(table: Table, dateRebaseMode: DateTimeRebaseMode,
timestampRebaseMode: DateTimeRebaseMode): Table = {
val dateRebaseNeeded = dateRebaseMode == DateTimeRebaseLegacy
val timeRebaseNeeded = timestampRebaseMode == DateTimeRebaseLegacy

lazy val tableHasDate = (0 until table.getNumberOfColumns).exists { i =>
checkTypeRecursively(table.getColumn(i), { dt => dt == DType.TIMESTAMP_DAYS })
}
lazy val tableHasTimestamp = (0 until table.getNumberOfColumns).exists { i =>
checkTypeRecursively(table.getColumn(i), { dt => dt == DType.TIMESTAMP_MICROSECONDS })
}

if ((dateRebaseNeeded && tableHasDate) || (timeRebaseNeeded && tableHasTimestamp)) {
// Need to close the input table when returning a new table.
withResource(table) { tmpTable =>
val newColumns = (0 until tmpTable.getNumberOfColumns).map { i =>
deepTransformRebaseDateTime(tmpTable.getColumn(i), dateRebaseNeeded, timeRebaseNeeded)
}
withResource(newColumns) { newCols =>
new Table(newCols: _*)
}
}
} else {
table
}
}

private def checkTypeRecursively(input: ColumnView, f: DType => Boolean): Boolean = {
val dt = input.getType
if (dt.isTimestampType && dt != DType.TIMESTAMP_DAYS && dt != DType.TIMESTAMP_MICROSECONDS) {
      // Something is wrong here, since timestamps other than DAYS should already have
      // been converted into MICROSECONDS when reading Parquet files.
throw new IllegalStateException(s"Unexpected date/time type: $dt " +
"(expected TIMESTAMP_DAYS or TIMESTAMP_MICROSECONDS)")
}
dt match {
case DType.LIST | DType.STRUCT => (0 until input.getNumChildren).exists(i =>
withResource(input.getChildColumnView(i)) { child =>
checkTypeRecursively(child, f)
})
case t: DType => f(t)
}
}

private def deepTransformRebaseDateTime(cv: ColumnVector, dateRebaseNeeded: Boolean,
timeRebaseNeeded: Boolean): ColumnVector = {
ColumnCastUtil.deepTransform(cv) {
case (cv, _) if cv.getType.isTimestampType =>
// cv type is guaranteed to be either TIMESTAMP_DAYS or TIMESTAMP_MICROSECONDS,
// since we already checked it in `checkTypeRecursively`.
if ((cv.getType == DType.TIMESTAMP_DAYS && dateRebaseNeeded) ||
            (cv.getType == DType.TIMESTAMP_MICROSECONDS && timeRebaseNeeded)) {
DateTimeRebase.rebaseJulianToGregorian(cv)
} else {
cv.copyToColumnVector()
}
}
}
}
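
For illustration only (not part of the commit): the nested-column handling above can be summarized by a small conceptual Python sketch. Here rebase_fn stands in for the Julian-to-Gregorian conversion that DateTimeRebase.rebaseJulianToGregorian performs on the GPU; only date (TIMESTAMP_DAYS) and timestamp (TIMESTAMP_MICROSECONDS) leaves are touched, while nested lists and structs are rebuilt around the transformed children.

from datetime import date, datetime

def deep_rebase(value, rebase_dates, rebase_timestamps, rebase_fn):
    # Recurse through list (array) and dict (struct) values, rebasing only date/timestamp leaves.
    if isinstance(value, list):
        return [deep_rebase(v, rebase_dates, rebase_timestamps, rebase_fn) for v in value]
    if isinstance(value, dict):
        return {k: deep_rebase(v, rebase_dates, rebase_timestamps, rebase_fn)
                for k, v in value.items()}
    if isinstance(value, datetime):  # check datetime first, since datetime is a subclass of date
        return rebase_fn(value) if rebase_timestamps else value
    if isinstance(value, date):
        return rebase_fn(value) if rebase_dates else value
    return value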

// contains meta about all the blocks in a file
@@ -2627,17 +2687,19 @@ object MakeParquetTableProducer extends Logging {
logWarning(s"Wrote data for ${splits.mkString(", ")} to $p")
}
}
GpuParquetScan.throwIfRebaseNeeded(table, dateRebaseMode,
GpuParquetScan.throwIfRebaseNeededInExceptionMode(table, dateRebaseMode,
timestampRebaseMode)
if (readDataSchema.length < table.getNumberOfColumns) {
throw new QueryExecutionException(s"Expected ${readDataSchema.length} columns " +
s"but read ${table.getNumberOfColumns} from ${splits.mkString("; ")}")
}
}
metrics(NUM_OUTPUT_BATCHES) += 1
val ret = ParquetSchemaUtils.evolveSchemaIfNeededAndClose(table,
val evolvedSchemaTable = ParquetSchemaUtils.evolveSchemaIfNeededAndClose(table,
clippedParquetSchema, readDataSchema, isSchemaCaseSensitive, useFieldId)
new SingleGpuDataProducer(ret)
val outputTable = GpuParquetScan.rebaseDateTime(evolvedSchemaTable, dateRebaseMode,
timestampRebaseMode)
new SingleGpuDataProducer(outputTable)
}
}
}
@@ -2682,15 +2744,16 @@ case class ParquetTableReader(
}

closeOnExcept(table) { _ =>
GpuParquetScan.throwIfRebaseNeeded(table, dateRebaseMode, timestampRebaseMode)
GpuParquetScan.throwIfRebaseNeededInExceptionMode(table, dateRebaseMode, timestampRebaseMode)
if (readDataSchema.length < table.getNumberOfColumns) {
throw new QueryExecutionException(s"Expected ${readDataSchema.length} columns " +
s"but read ${table.getNumberOfColumns} from $splitsString")
}
}
metrics(NUM_OUTPUT_BATCHES) += 1
ParquetSchemaUtils.evolveSchemaIfNeededAndClose(table,
val evolvedSchemaTable = ParquetSchemaUtils.evolveSchemaIfNeededAndClose(table,
clippedParquetSchema, readDataSchema, isSchemaCaseSensitive, useFieldId)
GpuParquetScan.rebaseDateTime(evolvedSchemaTable, dateRebaseMode, timestampRebaseMode)
}

override def close(): Unit = {
