diff --git a/integration_tests/src/main/python/orc_write_test.py b/integration_tests/src/main/python/orc_write_test.py
index 7e415c79a46..103cae474a3 100644
--- a/integration_tests/src/main/python/orc_write_test.py
+++ b/integration_tests/src/main/python/orc_write_test.py
@@ -133,6 +133,30 @@ def test_write_round_trip_corner(spark_tmp_path, orc_gen, orc_impl):
             data_path,
             conf={'spark.sql.orc.impl': orc_impl, 'spark.rapids.sql.format.orc.write.enabled': True})
 
+@pytest.mark.parametrize('gen', [ByteGen(nullable=False),
+                                 ShortGen(nullable=False),
+                                 IntegerGen(nullable=False),
+                                 LongGen(nullable=False),
+                                 FloatGen(nullable=False),
+                                 DoubleGen(nullable=False),
+                                 BooleanGen(nullable=False),
+                                 StringGen(nullable=False),
+                                 StructGen([('b', LongGen(nullable=False))], nullable=False)], ids=idfn)
+@pytest.mark.parametrize('orc_impl', ["native", "hive"])
+@allow_non_gpu(*non_utc_allow)
+def test_write_round_trip_nullable_struct(spark_tmp_path, gen, orc_impl):
+    gen_for_struct = StructGen([('c', gen)], nullable=True)
+    data_path = spark_tmp_path + '/ORC_DATA'
+    assert_gpu_and_cpu_writes_are_equal_collect(
+            lambda spark, path: unary_op_df(spark, gen_for_struct, num_slices=1).write.orc(path),
+            lambda spark, path: spark.read.orc(path),
+            data_path,
+            conf={'spark.sql.orc.impl': orc_impl,
+                  'spark.rapids.sql.format.orc.write.enabled': True,
+                  # https://github.com/NVIDIA/spark-rapids/issues/11736, so verify that we still do it correctly
+                  # once this is fixed
+                  'spark.rapids.sql.format.orc.write.boolType.enabled' : True})
+
 orc_part_write_gens = [
     # Add back boolean_gen when https://github.com/rapidsai/cudf/issues/6763 is fixed
     byte_gen, short_gen, int_gen, long_gen,
diff --git a/integration_tests/src/main/python/parquet_write_test.py b/integration_tests/src/main/python/parquet_write_test.py
index 1d395d0e29a..e5719d267b4 100644
--- a/integration_tests/src/main/python/parquet_write_test.py
+++ b/integration_tests/src/main/python/parquet_write_test.py
@@ -105,6 +105,25 @@ def test_write_round_trip(spark_tmp_path, parquet_gens):
             data_path,
             conf=writer_confs)
 
+@pytest.mark.parametrize('gen', [ByteGen(nullable=False),
+                                 ShortGen(nullable=False),
+                                 IntegerGen(nullable=False),
+                                 LongGen(nullable=False),
+                                 FloatGen(nullable=False),
+                                 DoubleGen(nullable=False),
+                                 BooleanGen(nullable=False),
+                                 StringGen(nullable=False),
+                                 StructGen([('b', LongGen(nullable=False))], nullable=False)], ids=idfn)
+@allow_non_gpu(*non_utc_allow)
+def test_write_round_trip_nullable_struct(spark_tmp_path, gen):
+    gen_for_struct = StructGen([('c', gen)], nullable=True)
+    data_path = spark_tmp_path + '/PARQUET_DATA'
+    assert_gpu_and_cpu_writes_are_equal_collect(
+            lambda spark, path: unary_op_df(spark, gen_for_struct, num_slices=1).write.parquet(path),
+            lambda spark, path: spark.read.parquet(path),
+            data_path,
+            conf=writer_confs)
+
 all_nulls_string_gen = SetValuesGen(StringType(), [None])
 empty_or_null_string_gen = SetValuesGen(StringType(), [None, ""])
 all_empty_string_gen = SetValuesGen(StringType(), [""])
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetFileFormat.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetFileFormat.scala
index e5aa52c727d..2d6cb903b75 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetFileFormat.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetFileFormat.scala
@@ -391,6 +391,7 @@ class GpuParquetWriter(
     val writeContext = new ParquetWriteSupport().init(conf)
     val builder = SchemaUtils
       .writerOptionsFromSchema(ParquetWriterOptions.builder(), dataSchema,
+        nullable = false,
         ParquetOutputTimestampType.INT96 == SQLConf.get.parquetOutputTimestampType,
         parquetFieldIdEnabled)
       .withMetadata(writeContext.getExtraMetaData)
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ParquetCachedBatchSerializer.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ParquetCachedBatchSerializer.scala
index d88f21922ce..861905f45f7 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ParquetCachedBatchSerializer.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ParquetCachedBatchSerializer.scala
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2020-2023, NVIDIA CORPORATION.
+ * Copyright (c) 2020-2024, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -430,7 +430,8 @@ protected class ParquetCachedBatchSerializer extends GpuCachedBatchSerializer {
       schema: StructType): ParquetWriterOptions = {
     val compressionType = if (useCompression) CompressionType.SNAPPY else CompressionType.NONE
     SchemaUtils
-      .writerOptionsFromSchema(ParquetWriterOptions.builder(), schema, writeInt96 = false)
+      .writerOptionsFromSchema(ParquetWriterOptions.builder(), schema, nullable = false,
+        writeInt96 = false)
       .withCompressionType(compressionType)
       .withStatisticsFrequency(StatisticsFrequency.ROWGROUP).build()
   }
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SchemaUtils.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SchemaUtils.scala
index 22047f22e68..cc36fc7c848 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SchemaUtils.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SchemaUtils.scala
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021-2023, NVIDIA CORPORATION.
+ * Copyright (c) 2021-2024, NVIDIA CORPORATION.
  *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -247,19 +247,19 @@ object SchemaUtils {
     dataType match {
       case dt: DecimalType =>
-        if(parquetFieldIdWriteEnabled && parquetFieldId.nonEmpty) {
+        if (parquetFieldIdWriteEnabled && parquetFieldId.nonEmpty) {
           builder.withDecimalColumn(name, dt.precision, nullable, parquetFieldId.get)
         } else {
           builder.withDecimalColumn(name, dt.precision, nullable)
         }
       case TimestampType =>
-        if(parquetFieldIdWriteEnabled && parquetFieldId.nonEmpty) {
+        if (parquetFieldIdWriteEnabled && parquetFieldId.nonEmpty) {
           builder.withTimestampColumn(name, writeInt96, nullable, parquetFieldId.get)
         } else {
           builder.withTimestampColumn(name, writeInt96, nullable)
         }
       case s: StructType =>
-        val structB = if(parquetFieldIdWriteEnabled && parquetFieldId.nonEmpty) {
+        val structB = if (parquetFieldIdWriteEnabled && parquetFieldId.nonEmpty) {
           structBuilder(name, nullable, parquetFieldId.get)
         } else {
           structBuilder(name, nullable)
@@ -267,7 +267,9 @@ object SchemaUtils {
         builder.withStructColumn(writerOptionsFromSchema(
           structB,
           s,
-          writeInt96, parquetFieldIdWriteEnabled).build())
+          nullable = nullable,
+          writeInt96,
+          parquetFieldIdWriteEnabled).build())
       case a: ArrayType =>
         builder.withListColumn(
           writerOptionsFromField(
@@ -328,11 +330,14 @@ object SchemaUtils {
   def writerOptionsFromSchema[T <: NestedBuilder[T, V], V <: ColumnWriterOptions](
       builder: NestedBuilder[T, V],
       schema: StructType,
+      nullable: Boolean,
       writeInt96: Boolean = false,
       parquetFieldIdEnabled: Boolean = false): T = {
     schema.foreach(field =>
-      writerOptionsFromField(builder, field.dataType, field.name, field.nullable, writeInt96,
-        field.metadata, parquetFieldIdEnabled)
+      // CUDF has issues if the child of a struct is not-nullable, but the struct itself is
+      // So we have to work around it and tell CUDF what it expects.
+      writerOptionsFromField(builder, field.dataType, field.name, nullable || field.nullable,
+        writeInt96, field.metadata, parquetFieldIdWriteEnabled = parquetFieldIdEnabled)
     )
     builder.asInstanceOf[T]
   }
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/hive/rapids/GpuHiveFileFormat.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/hive/rapids/GpuHiveFileFormat.scala
index d39050a0c32..9c6882ca4a3 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/hive/rapids/GpuHiveFileFormat.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/hive/rapids/GpuHiveFileFormat.scala
@@ -227,6 +227,7 @@ class GpuHiveParquetWriter(override val path: String, dataSchema: StructType,
   override protected val tableWriter: CudfTableWriter = {
     val optionsBuilder = SchemaUtils
       .writerOptionsFromSchema(ParquetWriterOptions.builder(), dataSchema,
+        nullable = false,
         writeInt96 = true, // Hive 1.2 write timestamp as INT96
         parquetFieldIdEnabled = false)
       .withCompressionType(compType)
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuOrcFileFormat.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuOrcFileFormat.scala
index 1d4bc66a1da..6e9d30296ff 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuOrcFileFormat.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuOrcFileFormat.scala
@@ -210,7 +210,7 @@ class GpuOrcWriter(override val path: String,
 
   override val tableWriter: TableWriter = {
     val builder = SchemaUtils
-      .writerOptionsFromSchema(ORCWriterOptions.builder(), dataSchema)
+      .writerOptionsFromSchema(ORCWriterOptions.builder(), dataSchema, nullable = false)
       .withCompressionType(CompressionType.valueOf(OrcConf.COMPRESS.getString(conf)))
     Table.writeORCChunked(builder.build(), this)
   }