diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/SparkShreddingUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/SparkShreddingUtils.scala index 41244e20c369f..f38e188ed042c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/SparkShreddingUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/SparkShreddingUtils.scala @@ -73,16 +73,21 @@ case object SparkShreddingUtils { */ def variantShreddingSchema(dataType: DataType, isTopLevel: Boolean = true): StructType = { val fields = dataType match { - case ArrayType(elementType, containsNull) => + case ArrayType(elementType, _) => + // Always set containsNull to false. One of value or typed_value must always be set for + // array elements. val arrayShreddingSchema = - ArrayType(variantShreddingSchema(elementType, false), containsNull) + ArrayType(variantShreddingSchema(elementType, false), containsNull = false) Seq( StructField(VariantValueFieldName, BinaryType, nullable = true), StructField(TypedValueFieldName, arrayShreddingSchema, nullable = true) ) case StructType(fields) => + // The field name level is always non-nullable: Variant null values are represented in the + // "value" column as "00", and missing values are represented by setting both "value" and + // "typed_value" to null. 
val objectShreddingSchema = StructType(fields.map(f => - f.copy(dataType = variantShreddingSchema(f.dataType, false)))) + f.copy(dataType = variantShreddingSchema(f.dataType, false), nullable = false))) Seq( StructField(VariantValueFieldName, BinaryType, nullable = true), StructField(TypedValueFieldName, objectShreddingSchema, nullable = true) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/VariantShreddingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/VariantShreddingSuite.scala index 4ff346b957aa0..5d5c441052558 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/VariantShreddingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/VariantShreddingSuite.scala @@ -155,7 +155,16 @@ class VariantShreddingSuite extends QueryTest with SharedSparkSession with Parqu Row(metadata(Nil), null, Array(Row(null, null)))) checkException(path, "v", "MALFORMED_VARIANT") // Shredded field must not be null. - writeRows(path, writeSchema(StructType.fromDDL("a int")), + // Construct the schema manually, because SparkShreddingUtils.variantShreddingSchema will make + // `a` non-nullable, which would prevent us from writing the file. + val schema = StructType(Seq(StructField("v", StructType(Seq( + StructField("metadata", BinaryType), + StructField("value", BinaryType), + StructField("typed_value", StructType(Seq( + StructField("a", StructType(Seq( + StructField("value", BinaryType), + StructField("typed_value", BinaryType)))))))))))) + writeRows(path, schema, Row(metadata(Seq("a")), null, Row(null))) checkException(path, "v", "MALFORMED_VARIANT") // `value` must not contain any shredded field. 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/VariantWriteShreddingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/VariantWriteShreddingSuite.scala index a62c6e4462464..d31bf109af6c7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/VariantWriteShreddingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/VariantWriteShreddingSuite.scala @@ -67,6 +67,36 @@ class VariantWriteShreddingSuite extends SparkFunSuite with ExpressionEvalHelper private val emptyMetadata: Array[Byte] = parseJson("null").getMetadata + test("variantShreddingSchema") { + // Validate the schema produced by SparkShreddingUtils.variantShreddingSchema for a few simple + // cases. + // metadata is always non-nullable. + assert(SparkShreddingUtils.variantShreddingSchema(IntegerType) == + StructType(Seq( + StructField("metadata", BinaryType, nullable = false), + StructField("value", BinaryType, nullable = true), + StructField("typed_value", IntegerType, nullable = true)))) + + val fieldA = StructType(Seq( + StructField("value", BinaryType, nullable = true), + StructField("typed_value", TimestampNTZType, nullable = true))) + val arrayType = ArrayType(StructType(Seq( + StructField("value", BinaryType, nullable = true), + StructField("typed_value", StringType, nullable = true))), containsNull = false) + val fieldB = StructType(Seq( + StructField("value", BinaryType, nullable = true), + StructField("typed_value", arrayType, nullable = true))) + val objectType = StructType(Seq( + StructField("a", fieldA, nullable = false), + StructField("b", fieldB, nullable = false))) + val structSchema = DataType.fromDDL("a timestamp_ntz, b array<string>") + assert(SparkShreddingUtils.variantShreddingSchema(structSchema) == + StructType(Seq( + StructField("metadata", BinaryType, nullable = false), + StructField("value", BinaryType, nullable = true), + StructField("typed_value", objectType, nullable = true)))) + } + test("shredding as fixed numeric types") { /* Cast integer to any wider 
numeric type. */ testWithSchema("1", IntegerType, Row(emptyMetadata, null, 1))