Skip to content

Commit

Permalink
Escape quotes and newlines when converting strings to json format in …
Browse files Browse the repository at this point in the history
…to_json (#9612)

* escape quotes in when converting strings to json format

* move withResource earlier

* signoff

Signed-off-by: Andy Grove <[email protected]>

* update compatibility guide

* Escape newlines

* address feedback

* add link to issue

---------

Signed-off-by: Andy Grove <[email protected]>
  • Loading branch information
andygrove authored Nov 14, 2023
1 parent df9fb5a commit 36baf45
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 11 deletions.
2 changes: 0 additions & 2 deletions docs/compatibility.md
Original file line number Diff line number Diff line change
Expand Up @@ -346,8 +346,6 @@ with Spark, and can be enabled by setting `spark.rapids.sql.expression.StructsTo

Known issues are:

- String escaping is not implemented, so strings containing quotes, newlines, and other special characters will
not produce valid JSON
- There is no support for timestamp types
- There can be rounding differences when formatting floating-point numbers as strings. For example, Spark may
produce `-4.1243574E26` but the GPU may produce `-4.124357351E26`.
Expand Down
10 changes: 8 additions & 2 deletions integration_tests/src/main/python/json_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,8 +614,14 @@ def test_read_case_col_name(spark_tmp_path, v1_enabled_list, col_name):
pytest.param(double_gen, marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/9350')),
pytest.param(date_gen, marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/9515')),
pytest.param(timestamp_gen, marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/9515')),
StringGen('[A-Za-z0-9]{0,10}', nullable=True),
pytest.param(StringGen(nullable=True), marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/9514')),
StringGen('[A-Za-z0-9\r\n\'"\\\\]{0,10}', nullable=True) \
.with_special_case('\u1f600') \
.with_special_case('"a"') \
.with_special_case('\\"a\\"') \
.with_special_case('\'a\'') \
.with_special_case('\\\'a\\\''),
pytest.param(StringGen('\u001a', nullable=True), marks=pytest.mark.xfail(
reason='https://github.com/NVIDIA/spark-rapids/issues/9705'))
], ids=idfn)
@pytest.mark.parametrize('ignore_null_fields', [True, False])
@pytest.mark.parametrize('pretty', [
Expand Down
24 changes: 17 additions & 7 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala
Original file line number Diff line number Diff line change
Expand Up @@ -899,7 +899,10 @@ object GpuCast {

val numRows = input.getRowCount.toInt

/** Create a new column with quotes around the supplied string column */
/**
* Create a new column with quotes around the supplied string column. Caller
* is responsible for closing `column`.
*/
def addQuotes(column: ColumnVector, rowCount: Int): ColumnVector = {
withResource(ArrayBuffer.empty[ColumnVector]) { columns =>
withResource(Scalar.fromString("\"")) { quote =>
Expand All @@ -922,7 +925,7 @@ object GpuCast {
// keys must have quotes around them in JSON mode
val strKey: ColumnVector = withResource(kvStructColumn.getChildColumnView(0)) { keyColumn =>
withResource(castToString(keyColumn, from.keyType, options)) { key =>
addQuotes(key.incRefCount(), keyColumn.getRowCount.toInt)
addQuotes(key, keyColumn.getRowCount.toInt)
}
}
// string values must have quotes around them in JSON mode, and null values need
Expand All @@ -931,7 +934,7 @@ object GpuCast {
withResource(kvStructColumn.getChildColumnView(1)) { valueColumn =>
val valueStr = if (valueColumn.getType == DType.STRING) {
withResource(castToString(valueColumn, from.valueType, options)) { valueStr =>
addQuotes(valueStr.incRefCount(), valueColumn.getRowCount.toInt)
addQuotes(valueStr, valueColumn.getRowCount.toInt)
}
} else {
castToString(valueColumn, from.valueType, options)
Expand Down Expand Up @@ -1136,7 +1139,7 @@ object GpuCast {
attrValue =>
if (needsQuoting) {
attrValues += quote.incRefCount()
attrValues += escapeJsonString(attrValue.incRefCount())
attrValues += escapeJsonString(attrValue)
attrValues += quote.incRefCount()
withResource(Scalar.fromString("")) { emptyString =>
ColumnVector.stringConcatenate(emptyString, emptyString, attrValues.toArray)
Expand Down Expand Up @@ -1199,10 +1202,17 @@ object GpuCast {
}
}

/**
* Escape quotes and newlines in a string column. Caller is responsible for closing `cv`.
*/
private def escapeJsonString(cv: ColumnVector): ColumnVector = {
// this is a placeholder for implementing string escaping
// https://github.com/NVIDIA/spark-rapids/issues/9514
cv
val chars = Seq("\r", "\n", "\\", "\"")
val escaped = chars.map(StringEscapeUtils.escapeJava)
withResource(ColumnVector.fromStrings(chars: _*)) { search =>
withResource(ColumnVector.fromStrings(escaped: _*)) { replace =>
cv.stringReplace(search, replace)
}
}
}

private[rapids] def castFloatingTypeToString(input: ColumnView): ColumnVector = {
Expand Down

0 comments on commit 36baf45

Please sign in to comment.