Fix non-nullable under nullable struct write #11781

Merged · 5 commits · Dec 17, 2024
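For context, here is a minimal sketch of the scenario this PR fixes: a non-nullable field nested under a nullable struct, which the GPU Parquet/ORC write path previously described to cuDF using only the child's own nullability. The `spark` session and the output path are illustrative assumptions, not taken from the PR.

import org.apache.spark.sql.Row
import org.apache.spark.sql.types._

// A nullable struct whose child field is non-nullable.
val schema = StructType(Seq(
  StructField("c0", StructType(Seq(
    StructField("c", LongType, nullable = false))), nullable = true)))

// Rows where the struct itself may be null even though its child cannot be.
val df = spark.createDataFrame(
  spark.sparkContext.parallelize(Seq(Row(Row(1L)), Row(null))), schema)

// With the plugin enabled, this write exercised the bug before the fix.
df.write.parquet("/tmp/nullable_struct_repro")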
24 changes: 24 additions & 0 deletions integration_tests/src/main/python/orc_write_test.py
@@ -133,6 +133,30 @@ def test_write_round_trip_corner(spark_tmp_path, orc_gen, orc_impl):
data_path,
conf={'spark.sql.orc.impl': orc_impl, 'spark.rapids.sql.format.orc.write.enabled': True})

@pytest.mark.parametrize('gen', [ByteGen(nullable=False),
ShortGen(nullable=False),
IntegerGen(nullable=False),
LongGen(nullable=False),
FloatGen(nullable=False),
DoubleGen(nullable=False),
BooleanGen(nullable=False),
StringGen(nullable=False),
StructGen([('b', LongGen(nullable=False))], nullable=False)], ids=idfn)
@pytest.mark.parametrize('orc_impl', ["native", "hive"])
@allow_non_gpu(*non_utc_allow)
def test_write_round_trip_nullable_struct(spark_tmp_path, gen, orc_impl):
gen_for_struct = StructGen([('c', gen)], nullable=True)
data_path = spark_tmp_path + '/ORC_DATA'
assert_gpu_and_cpu_writes_are_equal_collect(
lambda spark, path: unary_op_df(spark, gen_for_struct, num_slices=1).write.orc(path),
lambda spark, path: spark.read.orc(path),
data_path,
conf={'spark.sql.orc.impl': orc_impl,
'spark.rapids.sql.format.orc.write.enabled': True,
# Force-enable boolean writes despite https://github.com/NVIDIA/spark-rapids/issues/11736,
# so that this test verifies we still write them correctly once that issue is fixed
'spark.rapids.sql.format.orc.write.boolType.enabled' : True})

orc_part_write_gens = [
# Add back boolean_gen when https://github.com/rapidsai/cudf/issues/6763 is fixed
byte_gen, short_gen, int_gen, long_gen,
19 changes: 19 additions & 0 deletions integration_tests/src/main/python/parquet_write_test.py
@@ -105,6 +105,25 @@ def test_write_round_trip(spark_tmp_path, parquet_gens):
data_path,
conf=writer_confs)

@pytest.mark.parametrize('gen', [ByteGen(nullable=False),
ShortGen(nullable=False),
IntegerGen(nullable=False),
LongGen(nullable=False),
FloatGen(nullable=False),
DoubleGen(nullable=False),
BooleanGen(nullable=False),
StringGen(nullable=False),
StructGen([('b', LongGen(nullable=False))], nullable=False)], ids=idfn)
@allow_non_gpu(*non_utc_allow)
def test_write_round_trip_nullable_struct(spark_tmp_path, gen):
gen_for_struct = StructGen([('c', gen)], nullable=True)
data_path = spark_tmp_path + '/PARQUET_DATA'
assert_gpu_and_cpu_writes_are_equal_collect(
lambda spark, path: unary_op_df(spark, gen_for_struct, num_slices=1).write.parquet(path),
lambda spark, path: spark.read.parquet(path),
data_path,
conf=writer_confs)

all_nulls_string_gen = SetValuesGen(StringType(), [None])
empty_or_null_string_gen = SetValuesGen(StringType(), [None, ""])
all_empty_string_gen = SetValuesGen(StringType(), [""])
@@ -389,6 +389,7 @@ class GpuParquetWriter(
val writeContext = new ParquetWriteSupport().init(conf)
val builder = SchemaUtils
.writerOptionsFromSchema(ParquetWriterOptions.builder(), dataSchema,
nullable = false,
ParquetOutputTimestampType.INT96 == SQLConf.get.parquetOutputTimestampType,
parquetFieldIdEnabled)
.withMetadata(writeContext.getExtraMetaData)
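Here and at the other writer entry points changed below (the cached-batch serializer, the Hive Parquet writer, and the ORC writer), the new nullable = false argument seeds the recursion: top-level columns have no enclosing struct, so each root field keeps its own nullability, and the flag only turns true once the walk descends under a nullable struct, as the SchemaUtils change below shows.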
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION.
* Copyright (c) 2020-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -430,7 +430,8 @@ protected class ParquetCachedBatchSerializer extends GpuCachedBatchSerializer {
schema: StructType): ParquetWriterOptions = {
val compressionType = if (useCompression) CompressionType.SNAPPY else CompressionType.NONE
SchemaUtils
.writerOptionsFromSchema(ParquetWriterOptions.builder(), schema, writeInt96 = false)
.writerOptionsFromSchema(ParquetWriterOptions.builder(), schema, nullable = false,
writeInt96 = false)
.withCompressionType(compressionType)
.withStatisticsFrequency(StatisticsFrequency.ROWGROUP).build()
}
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021-2023, NVIDIA CORPORATION.
* Copyright (c) 2021-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -247,27 +247,29 @@ object SchemaUtils {

dataType match {
case dt: DecimalType =>
if(parquetFieldIdWriteEnabled && parquetFieldId.nonEmpty) {
if (parquetFieldIdWriteEnabled && parquetFieldId.nonEmpty) {
builder.withDecimalColumn(name, dt.precision, nullable, parquetFieldId.get)
} else {
builder.withDecimalColumn(name, dt.precision, nullable)
}
case TimestampType =>
if(parquetFieldIdWriteEnabled && parquetFieldId.nonEmpty) {
if (parquetFieldIdWriteEnabled && parquetFieldId.nonEmpty) {
builder.withTimestampColumn(name, writeInt96, nullable, parquetFieldId.get)
} else {
builder.withTimestampColumn(name, writeInt96, nullable)
}
case s: StructType =>
val structB = if(parquetFieldIdWriteEnabled && parquetFieldId.nonEmpty) {
val structB = if (parquetFieldIdWriteEnabled && parquetFieldId.nonEmpty) {
structBuilder(name, nullable, parquetFieldId.get)
} else {
structBuilder(name, nullable)
}
builder.withStructColumn(writerOptionsFromSchema(
structB,
s,
writeInt96, parquetFieldIdWriteEnabled).build())
nullable = nullable,
writeInt96,
parquetFieldIdWriteEnabled).build())
case a: ArrayType =>
builder.withListColumn(
writerOptionsFromField(
@@ -328,11 +330,14 @@ object SchemaUtils {
def writerOptionsFromSchema[T <: NestedBuilder[T, V], V <: ColumnWriterOptions](
builder: NestedBuilder[T, V],
schema: StructType,
nullable: Boolean,
writeInt96: Boolean = false,
parquetFieldIdEnabled: Boolean = false): T = {
schema.foreach(field =>
writerOptionsFromField(builder, field.dataType, field.name, field.nullable, writeInt96,
field.metadata, parquetFieldIdEnabled)
// cuDF has issues if the child of a struct is non-nullable while the struct itself
// is nullable, so work around it by telling cuDF the nullability it expects.
writerOptionsFromField(builder, field.dataType, field.name, nullable || field.nullable,
writeInt96, field.metadata, parquetFieldIdWriteEnabled = parquetFieldIdEnabled)
)
builder.asInstanceOf[T]
}
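To make the workaround concrete, here is a self-contained sketch of the rule the new nullable parameter implements; Field and effectiveNullability are hypothetical names for illustration, not part of SchemaUtils.

// Hypothetical model of the rule: a field's effective nullability is the OR of
// its own nullability and its parent's, applied recursively down the schema.
case class Field(name: String, nullable: Boolean, children: Seq[Field] = Nil)

def effectiveNullability(parentNullable: Boolean, field: Field): Field = {
  val nullable = parentNullable || field.nullable
  field.copy(
    nullable = nullable,
    children = field.children.map(effectiveNullability(nullable, _)))
}

// A non-nullable child under a nullable struct is reported to cuDF as nullable:
val s = Field("c0", nullable = true, Seq(Field("c", nullable = false)))
assert(effectiveNullability(parentNullable = false, s).children.head.nullable)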
@@ -227,6 +227,7 @@ class GpuHiveParquetWriter(override val path: String, dataSchema: StructType,
override protected val tableWriter: CudfTableWriter = {
val optionsBuilder = SchemaUtils
.writerOptionsFromSchema(ParquetWriterOptions.builder(), dataSchema,
nullable = false,
writeInt96 = true, // Hive 1.2 writes timestamps as INT96
parquetFieldIdEnabled = false)
.withCompressionType(compType)
@@ -210,7 +210,7 @@ class GpuOrcWriter(override val path: String,

override val tableWriter: TableWriter = {
val builder = SchemaUtils
.writerOptionsFromSchema(ORCWriterOptions.builder(), dataSchema)
.writerOptionsFromSchema(ORCWriterOptions.builder(), dataSchema, nullable = false)
.withCompressionType(CompressionType.valueOf(OrcConf.COMPRESS.getString(conf)))
Table.writeORCChunked(builder.build(), this)
}