Fix non-nullable under nullable struct write (#11781)
Signed-off-by: Robert (Bobby) Evans <[email protected]>
revans2 authored Dec 17, 2024 · 1 parent 7e465d8 · commit 9465328
Showing 7 changed files with 61 additions and 10 deletions.
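In brief: cuDF has trouble writing a non-nullable column that sits under a nullable struct, because a null struct value implies its children can be missing too. This commit works around that by widening the nullability of such children when building the cuDF writer options, and adds round-trip tests for ORC and Parquet. As a hedged illustration (the names are mine, not from the commit), the problematic shape looks like this in Spark's schema API:

    import org.apache.spark.sql.types._

    // A nullable struct whose child is declared non-nullable. When the struct
    // value is null the child has no value either, so the writer must treat
    // the child as nullable as well.
    val problemShape = StructType(Seq(
      StructField("c", StructType(Seq(
        StructField("b", LongType, nullable = false)  // non-nullable child...
      )), nullable = true)                            // ...under a nullable struct
    ))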
24 changes: 24 additions & 0 deletions integration_tests/src/main/python/orc_write_test.py
@@ -133,6 +133,30 @@ def test_write_round_trip_corner(spark_tmp_path, orc_gen, orc_impl)
        data_path,
        conf={'spark.sql.orc.impl': orc_impl, 'spark.rapids.sql.format.orc.write.enabled': True})

+ @pytest.mark.parametrize('gen', [ByteGen(nullable=False),
+                                  ShortGen(nullable=False),
+                                  IntegerGen(nullable=False),
+                                  LongGen(nullable=False),
+                                  FloatGen(nullable=False),
+                                  DoubleGen(nullable=False),
+                                  BooleanGen(nullable=False),
+                                  StringGen(nullable=False),
+                                  StructGen([('b', LongGen(nullable=False))], nullable=False)], ids=idfn)
+ @pytest.mark.parametrize('orc_impl', ["native", "hive"])
+ @allow_non_gpu(*non_utc_allow)
+ def test_write_round_trip_nullable_struct(spark_tmp_path, gen, orc_impl):
+     gen_for_struct = StructGen([('c', gen)], nullable=True)
+     data_path = spark_tmp_path + '/ORC_DATA'
+     assert_gpu_and_cpu_writes_are_equal_collect(
+         lambda spark, path: unary_op_df(spark, gen_for_struct, num_slices=1).write.orc(path),
+         lambda spark, path: spark.read.orc(path),
+         data_path,
+         conf={'spark.sql.orc.impl': orc_impl,
+               'spark.rapids.sql.format.orc.write.enabled': True,
+               # Force boolean writes on despite https://github.com/NVIDIA/spark-rapids/issues/11736,
+               # so we verify they are still written correctly once that issue is fixed
+               'spark.rapids.sql.format.orc.write.boolType.enabled': True})
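Both this ORC test and the matching Parquet test below wrap each non-nullable generator in a nullable struct. That shape arises naturally in Spark; for example, its encoders produce it for a case class with a primitive field, as in this hedged standalone sketch (not part of the tests):

    import org.apache.spark.sql.SparkSession

    case class Inner(b: Long)   // primitive field => nullable = false in the schema
    case class Outer(c: Inner)  // object reference => nullable = true

    object ShapeDemo {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[*]").getOrCreate()
        import spark.implicits._
        // Prints c as struct (nullable = true) containing b: long (nullable = false)
        Seq(Outer(Inner(1L)), Outer(null)).toDS().printSchema()
        spark.stop()
      }
    }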

orc_part_write_gens = [
# Add back boolean_gen when https://github.com/rapidsai/cudf/issues/6763 is fixed
byte_gen, short_gen, int_gen, long_gen,
19 changes: 19 additions & 0 deletions integration_tests/src/main/python/parquet_write_test.py
@@ -105,6 +105,25 @@ def test_write_round_trip(spark_tmp_path, parquet_gens):
        data_path,
        conf=writer_confs)

+ @pytest.mark.parametrize('gen', [ByteGen(nullable=False),
+                                  ShortGen(nullable=False),
+                                  IntegerGen(nullable=False),
+                                  LongGen(nullable=False),
+                                  FloatGen(nullable=False),
+                                  DoubleGen(nullable=False),
+                                  BooleanGen(nullable=False),
+                                  StringGen(nullable=False),
+                                  StructGen([('b', LongGen(nullable=False))], nullable=False)], ids=idfn)
+ @allow_non_gpu(*non_utc_allow)
+ def test_write_round_trip_nullable_struct(spark_tmp_path, gen):
+     gen_for_struct = StructGen([('c', gen)], nullable=True)
+     data_path = spark_tmp_path + '/PARQUET_DATA'
+     assert_gpu_and_cpu_writes_are_equal_collect(
+         lambda spark, path: unary_op_df(spark, gen_for_struct, num_slices=1).write.parquet(path),
+         lambda spark, path: spark.read.parquet(path),
+         data_path,
+         conf=writer_confs)

all_nulls_string_gen = SetValuesGen(StringType(), [None])
empty_or_null_string_gen = SetValuesGen(StringType(), [None, ""])
all_empty_string_gen = SetValuesGen(StringType(), [""])
@@ -391,6 +391,7 @@ class GpuParquetWriter(
val writeContext = new ParquetWriteSupport().init(conf)
val builder = SchemaUtils
.writerOptionsFromSchema(ParquetWriterOptions.builder(), dataSchema,
+ nullable = false,
ParquetOutputTimestampType.INT96 == SQLConf.get.parquetOutputTimestampType,
parquetFieldIdEnabled)
.withMetadata(writeContext.getExtraMetaData)
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2020-2023, NVIDIA CORPORATION.
+ * Copyright (c) 2020-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -430,7 +430,8 @@ protected class ParquetCachedBatchSerializer extends GpuCachedBatchSerializer {
schema: StructType): ParquetWriterOptions = {
val compressionType = if (useCompression) CompressionType.SNAPPY else CompressionType.NONE
SchemaUtils
- .writerOptionsFromSchema(ParquetWriterOptions.builder(), schema, writeInt96 = false)
+ .writerOptionsFromSchema(ParquetWriterOptions.builder(), schema, nullable = false,
+     writeInt96 = false)
.withCompressionType(compressionType)
.withStatisticsFrequency(StatisticsFrequency.ROWGROUP).build()
}
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021-2023, NVIDIA CORPORATION.
+ * Copyright (c) 2021-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -247,27 +247,29 @@ object SchemaUtils {

dataType match {
case dt: DecimalType =>
- if(parquetFieldIdWriteEnabled && parquetFieldId.nonEmpty) {
+ if (parquetFieldIdWriteEnabled && parquetFieldId.nonEmpty) {
builder.withDecimalColumn(name, dt.precision, nullable, parquetFieldId.get)
} else {
builder.withDecimalColumn(name, dt.precision, nullable)
}
case TimestampType =>
- if(parquetFieldIdWriteEnabled && parquetFieldId.nonEmpty) {
+ if (parquetFieldIdWriteEnabled && parquetFieldId.nonEmpty) {
builder.withTimestampColumn(name, writeInt96, nullable, parquetFieldId.get)
} else {
builder.withTimestampColumn(name, writeInt96, nullable)
}
case s: StructType =>
- val structB = if(parquetFieldIdWriteEnabled && parquetFieldId.nonEmpty) {
+ val structB = if (parquetFieldIdWriteEnabled && parquetFieldId.nonEmpty) {
structBuilder(name, nullable, parquetFieldId.get)
} else {
structBuilder(name, nullable)
}
builder.withStructColumn(writerOptionsFromSchema(
structB,
s,
- writeInt96, parquetFieldIdWriteEnabled).build())
+ nullable = nullable,
+ writeInt96,
+ parquetFieldIdWriteEnabled).build())
case a: ArrayType =>
builder.withListColumn(
writerOptionsFromField(
@@ -328,11 +330,14 @@ object SchemaUtils {
def writerOptionsFromSchema[T <: NestedBuilder[T, V], V <: ColumnWriterOptions](
builder: NestedBuilder[T, V],
schema: StructType,
+ nullable: Boolean,
writeInt96: Boolean = false,
parquetFieldIdEnabled: Boolean = false): T = {
schema.foreach(field =>
- writerOptionsFromField(builder, field.dataType, field.name, field.nullable, writeInt96,
-   field.metadata, parquetFieldIdEnabled)
+ // CUDF has issues when the child of a struct is non-nullable but the struct
+ // itself is nullable, so we work around it and tell CUDF what it expects.
+ writerOptionsFromField(builder, field.dataType, field.name, nullable || field.nullable,
+   writeInt96, field.metadata, parquetFieldIdWriteEnabled = parquetFieldIdEnabled)
)
builder.asInstanceOf[T]
}
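The heart of the fix is above: the parent's nullability is passed down the recursion and OR-ed into each field's own nullability, so no non-nullable column can end up beneath a nullable ancestor. A minimal standalone model of that rule over Spark's schema types, rather than the actual cudf builder API (illustrative only):

    import org.apache.spark.sql.types._

    // A field is effectively nullable if it is declared nullable OR any ancestor
    // struct is nullable; this mirrors `nullable || field.nullable` above.
    def widenNullability(schema: StructType, parentNullable: Boolean): StructType =
      StructType(schema.fields.map { field =>
        val effectiveNullable = parentNullable || field.nullable
        val newType = field.dataType match {
          case s: StructType => widenNullability(s, effectiveNullable)
          case other         => other  // arrays and maps elided for brevity
        }
        field.copy(dataType = newType, nullable = effectiveNullable)
      })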
@@ -227,6 +227,7 @@ class GpuHiveParquetWriter(override val path: String, dataSchema: StructType,
override protected val tableWriter: CudfTableWriter = {
val optionsBuilder = SchemaUtils
.writerOptionsFromSchema(ParquetWriterOptions.builder(), dataSchema,
+ nullable = false,
writeInt96 = true, // Hive 1.2 writes timestamps as INT96
parquetFieldIdEnabled = false)
.withCompressionType(compType)
@@ -210,7 +210,7 @@ class GpuOrcWriter(override val path: String,

override val tableWriter: TableWriter = {
val builder = SchemaUtils
- .writerOptionsFromSchema(ORCWriterOptions.builder(), dataSchema)
+ .writerOptionsFromSchema(ORCWriterOptions.builder(), dataSchema, nullable = false)
.withCompressionType(CompressionType.valueOf(OrcConf.COMPRESS.getString(conf)))
Table.writeORCChunked(builder.build(), this)
}
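Every top-level call site in this commit (GpuParquetWriter, the cached-batch serializer, GpuHiveParquetWriter, and GpuOrcWriter above) passes nullable = false: a root-level column has no enclosing struct, so nothing forces it nullable, and the flag only widens as the recursion descends through nullable structs. In terms of the model sketched earlier:

    // Root columns have no nullable ancestor, matching `nullable = false` at the
    // call sites; c.b still comes out nullable because c itself is nullable.
    val widened = widenNullability(problemShape, parentNullable = false)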
