diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuJsonReadCommon.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuJsonReadCommon.scala index 017d9722257..1e4f5579be5 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuJsonReadCommon.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuJsonReadCommon.scala @@ -193,18 +193,22 @@ object GpuJsonReadCommon { val allowUnquotedControlChars = options.buildJsonFactory() .isEnabled(JsonParser.Feature.ALLOW_UNQUOTED_CONTROL_CHARS) + baseCudfJsonOptionsBuilder() + .withNormalizeSingleQuotes(options.allowSingleQuotes) + .withLeadingZeros(options.allowNumericLeadingZeros) + .withNonNumericNumbers(options.allowNonNumericNumbers) + .withUnquotedControlChars(allowUnquotedControlChars) + .build() + } + + def baseCudfJsonOptionsBuilder(): ai.rapids.cudf.JSONOptions.Builder = { ai.rapids.cudf.JSONOptions.builder() .withRecoverWithNull(true) .withMixedTypesAsStrings(true) .withNormalizeWhitespace(true) .withKeepQuotes(true) - .withNormalizeSingleQuotes(options.allowSingleQuotes) .withStrictValidation(true) - .withLeadingZeros(options.allowNumericLeadingZeros) - .withNonNumericNumbers(options.allowNonNumericNumbers) - .withUnquotedControlChars(allowUnquotedControlChars) .withCudfPruneSchema(true) .withExperimental(true) - .build() } } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/JsonScanRetrySuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/JsonScanRetrySuite.scala index 47546f25513..1384a90f3cc 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/JsonScanRetrySuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/JsonScanRetrySuite.scala @@ -16,10 +16,10 @@ package com.nvidia.spark.rapids -import ai.rapids.cudf.JSONOptions import com.nvidia.spark.rapids.jni.RmmSpark import org.apache.spark.sql.catalyst.json.rapids.JsonPartitionReader +import org.apache.spark.sql.rapids.GpuJsonReadCommon import org.apache.spark.sql.types._ class JsonScanRetrySuite extends RmmSparkRetrySuiteBase { @@ -29,7 +29,7 @@ class JsonScanRetrySuite extends RmmSparkRetrySuiteBase { val cudfSchema = GpuColumnVector.from(StructType(Seq(StructField("a", IntegerType), StructField("b", IntegerType)))) - val opts = JSONOptions.builder().withLines(true).build() + val opts = GpuJsonReadCommon.baseCudfJsonOptionsBuilder().withLines(true).build() RmmSpark.forceRetryOOM(RmmSpark.getCurrentThreadId, 1, RmmSpark.OomInjectionType.GPU.ordinal, 0) val table = JsonPartitionReader.readToTable(bufferer, cudfSchema, NoopMetric,