From 1807edf4b15149890d6c7f31035cbdb1679bbe1a Mon Sep 17 00:00:00 2001 From: Cindy Jiang Date: Wed, 7 Jun 2023 15:53:01 -0700 Subject: [PATCH 1/7] initial fix for from_json function Signed-off-by: Cindy Jiang --- .../src/main/python/json_test.py | 10 ++++++ .../spark/sql/rapids/GpuJsonToStructs.scala | 33 ++++++++++++++++--- 2 files changed, 39 insertions(+), 4 deletions(-) diff --git a/integration_tests/src/main/python/json_test.py b/integration_tests/src/main/python/json_test.py index c824789ee55..4c2dd1b475b 100644 --- a/integration_tests/src/main/python/json_test.py +++ b/integration_tests/src/main/python/json_test.py @@ -417,6 +417,16 @@ def test_from_json_struct_of_list(data_gen, schema): .select(f.from_json(f.col('a'), schema)), conf={"spark.rapids.sql.expression.JsonToStructs": True}) +@pytest.mark.parametrize('data_gen', [StringGen('', nullable=True)]) +@pytest.mark.parametrize('schema', [StructType([StructField("a", StringType())]), + StructType([StructField("a", StringType()), StructField("b", IntegerType())]) + ]) +def test_from_json_struct_empty_string_input(data_gen, schema): + assert_gpu_and_cpu_are_equal_collect( + lambda spark : unary_op_df(spark, data_gen) \ + .select(f.from_json(f.col('a'), schema)), + conf={"spark.rapids.sql.expression.JsonToStructs": True}) + @allow_non_gpu('FileSourceScanExec') @pytest.mark.skipif(is_before_spark_340(), reason='enableDateTimeParsingFallback is supported from Spark3.4.0') @pytest.mark.parametrize('filename,schema', [("dates.json", _date_schema),("dates.json", _timestamp_schema), diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuJsonToStructs.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuJsonToStructs.scala index 3aa52713aef..59111b9df65 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuJsonToStructs.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuJsonToStructs.scala @@ -33,12 +33,37 @@ case class GpuJsonToStructs( timeZoneId: Option[String] = None) extends GpuUnaryExpression with TimeZoneAwareExpression with ExpectsInputTypes with NullIntolerant { + + // Construct a dummy json string given an input schema, for example: + // schema = StructType([StructField("a", StringType), StructField("b", IntegerType)]) + // returns "{"a": "", "b": 0}" + private def constructEmptyRow(): String = { + schema match { + case struct: StructType if (struct.fields.length > 0) => { + val res = struct.fields.foldRight ("") ((field, acc) => + field.dataType match { + case IntegerType => "\"" + field.name + "\": 0, " + acc + case _ => "\"" + field.name + "\": \"\", " + acc + } + ) + "{" + res.dropRight(2) + "}" + } + case _ => "{}" + } + } - private def cleanAndConcat(input: cudf.ColumnVector): (cudf.ColumnVector, cudf.ColumnVector) ={ - withResource(cudf.Scalar.fromString("{}")) { emptyRow => - val stripped = withResource(cudf.Scalar.fromString(" ")) { space => - input.strip(space) + private def cleanAndConcat(input: cudf.ColumnVector): (cudf.ColumnVector, cudf.ColumnVector) = { + val emptyRowStr = constructEmptyRow() + withResource(cudf.Scalar.fromString(emptyRowStr)) { emptyRow => + + val stripped = if (input.getData == null) { + input.incRefCount + } else { + withResource(cudf.Scalar.fromString(" ")) { space => + input.strip(space) + } } + withResource(stripped) { stripped => val isNullOrEmptyInput = withResource(input.isNull) { isNull => val isEmpty = withResource(stripped.getCharLengths) { lengths => From f1092fb5fd0a515ac35b17a98db8b44fd90bd9fd Mon Sep 17 00:00:00 2001 From: Cindy Jiang Date: Thu, 8 Jun 2023 15:20:25 -0700 Subject: [PATCH 2/7] cleaned up from json tests and added one test for struct of struct Signed-off-by: Cindy Jiang --- .../src/main/python/json_test.py | 63 +++++++++---------- 1 file changed, 29 insertions(+), 34 deletions(-) diff --git a/integration_tests/src/main/python/json_test.py b/integration_tests/src/main/python/json_test.py index 4c2dd1b475b..c9c1a7eb06a 100644 --- a/integration_tests/src/main/python/json_test.py +++ b/integration_tests/src/main/python/json_test.py @@ -380,51 +380,46 @@ def test_from_json_map_fallback(): 'JsonToStructs', conf={"spark.rapids.sql.expression.JsonToStructs": True}) -@pytest.mark.parametrize('data_gen', [StringGen(r'{"a": "[0-9]{0,5}", "b": "[A-Z]{0,5}", "c": 1234}')]) -@pytest.mark.parametrize('schema', [StructType([StructField("a", StringType())]), - StructType([StructField("d", StringType())]), - StructType([StructField("a", StringType()), StructField("b", StringType())]), - StructType([StructField("c", IntegerType()), StructField("a", StringType())]), - StructType([StructField("a", StringType()), StructField("a", StringType())]) +@pytest.mark.parametrize('schema', ['struct', + 'struct', + 'struct', + 'struct', + 'struct', ]) -def test_from_json_struct(data_gen, schema): +def test_from_json_struct(schema): + json_string_gen = StringGen(r'{"a": "[0-9]{0,5}", "b": "[A-Z]{0,5}", "c": 1234}').with_special_pattern('', weight=50) assert_gpu_and_cpu_are_equal_collect( - lambda spark : unary_op_df(spark, data_gen) \ - .select(f.from_json(f.col('a'), schema)), + lambda spark : unary_op_df(spark, json_string_gen) \ + .select(f.from_json('a', schema)), conf={"spark.rapids.sql.expression.JsonToStructs": True}) -@pytest.mark.parametrize('data_gen', [StringGen(r'{"teacher": "Alice", "student": {"name": "Bob", "age": 20}}')]) -@pytest.mark.parametrize('schema', [StructType([StructField("teacher", StringType())]), - StructType([StructField("student", StructType([StructField("name", StringType()), \ - StructField("age", IntegerType())]))])]) -def test_from_json_struct_of_struct(data_gen, schema): +@pytest.mark.parametrize('schema', ['struct', + 'struct>', + 'struct>']) +def test_from_json_struct_of_struct(schema): + json_string_gen = StringGen(r'{"teacher": "Alice", "student": {"name": "Bob", "age": 20}}').with_special_pattern('', weight=50) assert_gpu_and_cpu_are_equal_collect( - lambda spark : unary_op_df(spark, data_gen) \ - .select(f.from_json(f.col('a'), schema)), + lambda spark : unary_op_df(spark, json_string_gen) \ + .select(f.from_json('a', schema)), conf={"spark.rapids.sql.expression.JsonToStructs": True}) -@pytest.mark.parametrize('data_gen', [StringGen(r'{"teacher": "Alice", "student": \[{"name": "Bob", "class": "junior"},' \ - r'{"name": "Charlie", "class": "freshman"}\]}')]) -@pytest.mark.parametrize('schema', [StructType([StructField("teacher", StringType())]), - StructType([StructField("student", ArrayType(StructType([StructField("name", StringType()), \ - StructField("class", StringType())])))]), - StructType([StructField("teacher", StringType()), \ - StructField("student", ArrayType(StructType([StructField("name", StringType()), \ - StructField("class", StringType())])))])]) -def test_from_json_struct_of_list(data_gen, schema): +@pytest.mark.parametrize('schema', ['struct', + 'struct>>', + 'struct>>']) +def test_from_json_struct_of_list(schema): + json_string_gen = StringGen(r'{"teacher": "Alice", "student": \[{"name": "Bob", "class": "junior"},' \ + r'{"name": "Charlie", "class": "freshman"}\]}').with_special_pattern('', weight=50) assert_gpu_and_cpu_are_equal_collect( - lambda spark : unary_op_df(spark, data_gen) \ - .select(f.from_json(f.col('a'), schema)), + lambda spark : unary_op_df(spark, json_string_gen) \ + .select(f.from_json('a', schema)), conf={"spark.rapids.sql.expression.JsonToStructs": True}) -@pytest.mark.parametrize('data_gen', [StringGen('', nullable=True)]) -@pytest.mark.parametrize('schema', [StructType([StructField("a", StringType())]), - StructType([StructField("a", StringType()), StructField("b", IntegerType())]) - ]) -def test_from_json_struct_empty_string_input(data_gen, schema): +@pytest.mark.parametrize('schema', ['struct', 'struct']) +def test_from_json_struct_all_empty_string_input(schema): + json_string_gen = StringGen('') assert_gpu_and_cpu_are_equal_collect( - lambda spark : unary_op_df(spark, data_gen) \ - .select(f.from_json(f.col('a'), schema)), + lambda spark : unary_op_df(spark, json_string_gen) \ + .select(f.from_json('a', schema)), conf={"spark.rapids.sql.expression.JsonToStructs": True}) @allow_non_gpu('FileSourceScanExec') From 2f7e2ad4642dba89f0fb8bf8909a9cdce9337a52 Mon Sep 17 00:00:00 2001 From: Cindy Jiang Date: Thu, 8 Jun 2023 16:13:45 -0700 Subject: [PATCH 3/7] addressed review comments Signed-off-by: Cindy Jiang --- .../spark/sql/rapids/GpuJsonToStructs.scala | 28 ++++++++++--------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuJsonToStructs.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuJsonToStructs.scala index 59111b9df65..265c48ae660 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuJsonToStructs.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuJsonToStructs.scala @@ -34,26 +34,29 @@ case class GpuJsonToStructs( extends GpuUnaryExpression with TimeZoneAwareExpression with ExpectsInputTypes with NullIntolerant { - // Construct a dummy json string given an input schema, for example: - // schema = StructType([StructField("a", StringType), StructField("b", IntegerType)]) - // returns "{"a": "", "b": 0}" - private def constructEmptyRow(): String = { + private def constructEmptyRow(schema: DataType): String = { schema match { case struct: StructType if (struct.fields.length > 0) => { - val res = struct.fields.foldRight ("") ((field, acc) => + val jsonFields = struct.fields.foldRight (Array.empty[String]) ((field, acc) => field.dataType match { - case IntegerType => "\"" + field.name + "\": 0, " + acc - case _ => "\"" + field.name + "\": \"\", " + acc + case IntegerType => s""""${field.name}": 0""" +: acc + case StringType => s""""${field.name}": """"" +: acc + case s: StructType => s""""${field.name}": ${constructEmptyRow(s)}""" +: acc + case a: ArrayType => s""""${field.name}": ${constructEmptyRow(a)}""" +: acc + case t => throw new IllegalArgumentException("GpuJsonToStructs currently" + + s"does not support input schema with type $t.") } ) - "{" + res.dropRight(2) + "}" + jsonFields.mkString("{", ", ", "}") } + case array: ArrayType => s"[${constructEmptyRow(array.elementType)}]" case _ => "{}" } } + + lazy val emptyRowStr = constructEmptyRow(schema) private def cleanAndConcat(input: cudf.ColumnVector): (cudf.ColumnVector, cudf.ColumnVector) = { - val emptyRowStr = constructEmptyRow() withResource(cudf.Scalar.fromString(emptyRowStr)) { emptyRow => val stripped = if (input.getData == null) { @@ -111,14 +114,14 @@ case class GpuJsonToStructs( // Output = [(null, StringType), ("b", StringType), ("a", IntegerType)] private def processFieldNames(names: Seq[(String, DataType)]): Seq[(String, DataType)] = { val zero = (Set.empty[String], Seq.empty[(String, DataType)]) - val (_, res) = names.foldRight(zero) { case ((name, dtype), (existingNames, acc)) => + val (_, resultFields) = names.foldRight (zero) { case ((name, dtype), (existingNames, acc)) => if (existingNames(name)) { (existingNames, (null, dtype) +: acc) } else { (existingNames + name, (name, dtype) +: acc) } } - res + resultFields } // Given a cudf column, return its Spark type @@ -176,8 +179,7 @@ case class GpuJsonToStructs( } // process duplicated field names in input struct schema - val fieldNames = processFieldNames(struct.fields.map { field => - (field.name, field.dataType)}) + val fieldNames = processFieldNames(struct.fields.map (f => (f.name, f.dataType))) withResource(rawTable) { rawTable => // Step 5: verify that the data looks correct From 5674788e9c41f63b759e31e7b4f974d50919de6c Mon Sep 17 00:00:00 2001 From: Cindy Jiang Date: Thu, 8 Jun 2023 16:58:25 -0700 Subject: [PATCH 4/7] add integration tests and workaround for lstrip and rstrip call sites Signed-off-by: Cindy Jiang --- integration_tests/src/main/python/string_test.py | 12 ++++++------ .../scala/com/nvidia/spark/rapids/GpuOrcScan.scala | 6 +++++- .../apache/spark/sql/rapids/stringFunctions.scala | 12 ++++++++++-- 3 files changed, 21 insertions(+), 9 deletions(-) diff --git a/integration_tests/src/main/python/string_test.py b/integration_tests/src/main/python/string_test.py index 5089e997859..88d1c382b70 100644 --- a/integration_tests/src/main/python/string_test.py +++ b/integration_tests/src/main/python/string_test.py @@ -158,20 +158,20 @@ def test_trim(): 'TRIM(BOTH NULL FROM a)', 'TRIM("" FROM a)')) -def test_ltrim(): - gen = mk_str_gen('[Ab \ud720]{0,3}A.{0,3}Z[ Ab]{0,3}') +@pytest.mark.parametrize('data_gen', [mk_str_gen('[Ab \ud720]{0,3}A.{0,3}Z[ Ab]{0,3}'), StringGen('')]) +def test_ltrim(data_gen): assert_gpu_and_cpu_are_equal_collect( - lambda spark: unary_op_df(spark, gen).selectExpr( + lambda spark: unary_op_df(spark, data_gen).selectExpr( 'LTRIM(a)', 'LTRIM("Ab", a)', 'TRIM(LEADING "A\ud720" FROM a)', 'TRIM(LEADING NULL FROM a)', 'TRIM(LEADING "" FROM a)')) -def test_rtrim(): - gen = mk_str_gen('[Ab \ud720]{0,3}A.{0,3}Z[ Ab]{0,3}') +@pytest.mark.parametrize('data_gen', [mk_str_gen('[Ab \ud720]{0,3}A.{0,3}Z[ Ab]{0,3}'), StringGen('')]) +def test_rtrim(data_gen): assert_gpu_and_cpu_are_equal_collect( - lambda spark: unary_op_df(spark, gen).selectExpr( + lambda spark: unary_op_df(spark, data_gen).selectExpr( 'RTRIM(a)', 'RTRIM("Ab", a)', 'TRIM(TRAILING "A\ud720" FROM a)', diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScan.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScan.scala index 163a238da54..a8b0dc107c4 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScan.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScan.scala @@ -383,7 +383,11 @@ object GpuOrcScan { case (DType.STRING, DType.STRING) if originalFromDt.isInstanceOf[CharType] => // Trim trailing whitespace off of output strings, to match CPU output. - col.rstrip() + if (col.getData == null) { + col.copyToColumnVector + } else { + col.rstrip() + } // TODO more types, tracked in https://github.com/NVIDIA/spark-rapids/issues/5895 case (f, t) => diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala index 55550a3d860..5d3f422d218 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala @@ -283,7 +283,11 @@ case class GpuStringTrimLeft(column: Expression, trimParameters: Option[Expressi val trimMethod = "gpuTrimLeft" override def strippedColumnVector(column: GpuColumnVector, t: Scalar): GpuColumnVector = - GpuColumnVector.from(column.getBase.lstrip(t), dataType) + if (column.getBase.getData == null) { + GpuColumnVector.from(column.getBase.incRefCount, dataType) + } else { + GpuColumnVector.from(column.getBase.lstrip(t), dataType) + } } case class GpuStringTrimRight(column: Expression, trimParameters: Option[Expression] = None) @@ -303,7 +307,11 @@ case class GpuStringTrimRight(column: Expression, trimParameters: Option[Express val trimMethod = "gpuTrimRight" override def strippedColumnVector(column:GpuColumnVector, t:Scalar): GpuColumnVector = - GpuColumnVector.from(column.getBase.rstrip(t), dataType) + if (column.getBase.getData == null) { + GpuColumnVector.from(column.getBase.incRefCount, dataType) + } else { + GpuColumnVector.from(column.getBase.rstrip(t), dataType) + } } case class GpuConcatWs(children: Seq[Expression]) From b07674d4d91e1593e40562450eee25fe9f98adb5 Mon Sep 17 00:00:00 2001 From: Cindy Jiang Date: Thu, 8 Jun 2023 17:29:39 -0700 Subject: [PATCH 5/7] add integration tests and fix for GpuStringTrim Signed-off-by: Cindy Jiang --- integration_tests/src/main/python/string_test.py | 6 +++--- .../scala/org/apache/spark/sql/rapids/stringFunctions.scala | 6 +++++- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/integration_tests/src/main/python/string_test.py b/integration_tests/src/main/python/string_test.py index 88d1c382b70..f39f10130c5 100644 --- a/integration_tests/src/main/python/string_test.py +++ b/integration_tests/src/main/python/string_test.py @@ -148,10 +148,10 @@ def test_contains(): f.col('a').contains(None) )) -def test_trim(): - gen = mk_str_gen('[Ab \ud720]{0,3}A.{0,3}Z[ Ab]{0,3}') +@pytest.mark.parametrize('data_gen', [mk_str_gen('[Ab \ud720]{0,3}A.{0,3}Z[ Ab]{0,3}'), StringGen('')]) +def test_trim(data_gen): assert_gpu_and_cpu_are_equal_collect( - lambda spark: unary_op_df(spark, gen).selectExpr( + lambda spark: unary_op_df(spark, data_gen).selectExpr( 'TRIM(a)', 'TRIM("Ab" FROM a)', 'TRIM("A\ud720" FROM a)', diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala index 5d3f422d218..cf306c5ee02 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala @@ -263,7 +263,11 @@ case class GpuStringTrim(column: Expression, trimParameters: Option[Expression] val trimMethod = "gpuTrim" override def strippedColumnVector(column: GpuColumnVector, t: Scalar): GpuColumnVector = - GpuColumnVector.from(column.getBase.strip(t), dataType) + if (column.getBase.getData == null) { + GpuColumnVector.from(column.getBase.incRefCount, dataType) + } else { + GpuColumnVector.from(column.getBase.strip(t), dataType) + } } case class GpuStringTrimLeft(column: Expression, trimParameters: Option[Expression] = None) From f6583c5adad986ba2a9ceff4f24e340460565ba9 Mon Sep 17 00:00:00 2001 From: Cindy Jiang Date: Thu, 8 Jun 2023 17:44:18 -0700 Subject: [PATCH 6/7] revert a fix for rstrip Signed-off-by: Cindy Jiang --- .../src/main/scala/com/nvidia/spark/rapids/GpuOrcScan.scala | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScan.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScan.scala index a8b0dc107c4..163a238da54 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScan.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScan.scala @@ -383,11 +383,7 @@ object GpuOrcScan { case (DType.STRING, DType.STRING) if originalFromDt.isInstanceOf[CharType] => // Trim trailing whitespace off of output strings, to match CPU output. - if (col.getData == null) { - col.copyToColumnVector - } else { - col.rstrip() - } + col.rstrip() // TODO more types, tracked in https://github.com/NVIDIA/spark-rapids/issues/5895 case (f, t) => From 9a003c358d7a2eddd9a316d4b729f2dac36812aa Mon Sep 17 00:00:00 2001 From: Cindy Jiang Date: Thu, 8 Jun 2023 18:44:53 -0700 Subject: [PATCH 7/7] addressed review comments and updated tests Signed-off-by: Cindy Jiang --- integration_tests/src/main/python/json_test.py | 14 ++++++++------ .../apache/spark/sql/rapids/GpuJsonToStructs.scala | 12 ++++++------ 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/integration_tests/src/main/python/json_test.py b/integration_tests/src/main/python/json_test.py index c9c1a7eb06a..db9e0dfc76f 100644 --- a/integration_tests/src/main/python/json_test.py +++ b/integration_tests/src/main/python/json_test.py @@ -387,7 +387,7 @@ def test_from_json_map_fallback(): 'struct', ]) def test_from_json_struct(schema): - json_string_gen = StringGen(r'{"a": "[0-9]{0,5}", "b": "[A-Z]{0,5}", "c": 1234}').with_special_pattern('', weight=50) + json_string_gen = StringGen(r'{"a": "[0-9]{0,5}", "b": "[A-Z]{0,5}", "c": 1\d\d\d}').with_special_pattern('', weight=50) assert_gpu_and_cpu_are_equal_collect( lambda spark : unary_op_df(spark, json_string_gen) \ .select(f.from_json('a', schema)), @@ -397,18 +397,20 @@ def test_from_json_struct(schema): 'struct>', 'struct>']) def test_from_json_struct_of_struct(schema): - json_string_gen = StringGen(r'{"teacher": "Alice", "student": {"name": "Bob", "age": 20}}').with_special_pattern('', weight=50) + json_string_gen = StringGen(r'{"teacher": "[A-Z]{1}[a-z]{2,5}",' \ + r'"student": {"name": "[A-Z]{1}[a-z]{2,5}", "age": 1\d}}').with_special_pattern('', weight=50) assert_gpu_and_cpu_are_equal_collect( lambda spark : unary_op_df(spark, json_string_gen) \ .select(f.from_json('a', schema)), conf={"spark.rapids.sql.expression.JsonToStructs": True}) @pytest.mark.parametrize('schema', ['struct', - 'struct>>', - 'struct>>']) + 'struct>>', + 'struct>>']) def test_from_json_struct_of_list(schema): - json_string_gen = StringGen(r'{"teacher": "Alice", "student": \[{"name": "Bob", "class": "junior"},' \ - r'{"name": "Charlie", "class": "freshman"}\]}').with_special_pattern('', weight=50) + json_string_gen = StringGen(r'{"teacher": "[A-Z]{1}[a-z]{2,5}",' \ + r'"student": \[{"name": "[A-Z]{1}[a-z]{2,5}", "class": "junior"},' \ + r'{"name": "[A-Z]{1}[a-z]{2,5}", "class": "freshman"}\]}').with_special_pattern('', weight=50) assert_gpu_and_cpu_are_equal_collect( lambda spark : unary_op_df(spark, json_string_gen) \ .select(f.from_json('a', schema)), diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuJsonToStructs.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuJsonToStructs.scala index 265c48ae660..af9cfc8f717 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuJsonToStructs.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuJsonToStructs.scala @@ -37,16 +37,16 @@ case class GpuJsonToStructs( private def constructEmptyRow(schema: DataType): String = { schema match { case struct: StructType if (struct.fields.length > 0) => { - val jsonFields = struct.fields.foldRight (Array.empty[String]) ((field, acc) => + val jsonFields: Array[String] = struct.fields.map { field => field.dataType match { - case IntegerType => s""""${field.name}": 0""" +: acc - case StringType => s""""${field.name}": """"" +: acc - case s: StructType => s""""${field.name}": ${constructEmptyRow(s)}""" +: acc - case a: ArrayType => s""""${field.name}": ${constructEmptyRow(a)}""" +: acc + case IntegerType => s""""${field.name}": 0""" + case StringType => s""""${field.name}": """"" + case s: StructType => s""""${field.name}": ${constructEmptyRow(s)}""" + case a: ArrayType => s""""${field.name}": ${constructEmptyRow(a)}""" case t => throw new IllegalArgumentException("GpuJsonToStructs currently" + s"does not support input schema with type $t.") } - ) + } jsonFields.mkString("{", ", ", "}") } case array: ArrayType => s"[${constructEmptyRow(array.elementType)}]"