diff --git a/core/src/main/resources/supportedExprs.csv b/core/src/main/resources/supportedExprs.csv index c5f273891..d1090ca0d 100644 --- a/core/src/main/resources/supportedExprs.csv +++ b/core/src/main/resources/supportedExprs.csv @@ -222,9 +222,9 @@ GetArrayItem,S, ,None,project,ordinal,NA,S,S,S,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,N GetArrayItem,S, ,None,project,result,S,S,S,S,S,S,S,S,PS,S,S,S,S,NS,PS,PS,PS,NS GetArrayStructFields,S, ,None,project,input,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA GetArrayStructFields,S, ,None,project,result,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA -GetJsonObject,S,`get_json_object`,None,project,json,NA,NA,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA -GetJsonObject,S,`get_json_object`,None,project,path,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA,NA -GetJsonObject,S,`get_json_object`,None,project,result,NA,NA,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA +GetJsonObject,NS,`get_json_object`,This is disabled by default because escape sequences are not processed correctly; the input is not validated; and the output is not normalized the same as Spark,project,json,NA,NA,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA +GetJsonObject,NS,`get_json_object`,This is disabled by default because escape sequences are not processed correctly; the input is not validated; and the output is not normalized the same as Spark,project,path,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA,NA +GetJsonObject,NS,`get_json_object`,This is disabled by default because escape sequences are not processed correctly; the input is not validated; and the output is not normalized the same as Spark,project,result,NA,NA,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA GetMapValue,S, ,None,project,map,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA GetMapValue,S, ,None,project,key,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS GetMapValue,S, ,None,project,result,S,S,S,S,S,S,S,S,PS,S,S,S,S,NS,PS,PS,PS,NS diff --git a/core/src/test/scala/com/nvidia/spark/rapids/tool/planparser/SqlPlanParserSuite.scala b/core/src/test/scala/com/nvidia/spark/rapids/tool/planparser/SqlPlanParserSuite.scala index a5a4ae4c2..16a0f4388 100644 --- a/core/src/test/scala/com/nvidia/spark/rapids/tool/planparser/SqlPlanParserSuite.scala +++ b/core/src/test/scala/com/nvidia/spark/rapids/tool/planparser/SqlPlanParserSuite.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -30,7 +30,7 @@ import org.scalatest.exceptions.TestFailedException import org.apache.spark.sql.TrampolineUtil import org.apache.spark.sql.expressions.Window -import org.apache.spark.sql.functions.{ceil, col, collect_list, count, explode, flatten, floor, hex, json_tuple, round, row_number, sum, translate, xxhash64} +import org.apache.spark.sql.functions.{ceil, col, collect_list, count, explode, flatten, floor, get_json_object, hex, json_tuple, round, row_number, sum, translate, xxhash64} import org.apache.spark.sql.rapids.tool.ToolUtils import org.apache.spark.sql.rapids.tool.qualification.QualificationAppInfo import org.apache.spark.sql.rapids.tool.util.RapidsToolsConfUtil @@ -847,6 +847,35 @@ class SQLPlanParserSuite extends BaseTestSuite { } } + test("get_json_object is supported in Project") { + // get_json_object is disabled by default in the RAPIDS plugin + TrampolineUtil.withTempDir { parquetoutputLoc => + TrampolineUtil.withTempDir { eventLogDir => + val (eventLog, _) = ToolTestUtils.generateEventLog(eventLogDir, + "Expressions in Generate") { spark => + import spark.implicits._ + val jsonString = + """{"Zipcode":123,"ZipCodeType":"STANDARD", + |"City":"ABCDE","State":"YZ"}""".stripMargin + val data = Seq((1, jsonString)) + val df1 = data.toDF("id", "jValues") + df1.write.parquet(s"$parquetoutputLoc/parquetfile") + val df2 = spark.read.parquet(s"$parquetoutputLoc/parquetfile") + df2.select(col("id"), get_json_object(col("jValues"), "$.ZipCodeType").as("ZipCodeType")) + } + val pluginTypeChecker = new PluginTypeChecker() + val app = createAppFromEventlog(eventLog) + assert(app.sqlPlans.size == 2) + val parsedPlans = app.sqlPlans.map { case (sqlID, plan) => + SQLPlanParser.parseSQLPlan(app.appId, plan, sqlID, "", pluginTypeChecker, app) + } + val execInfo = getAllExecsFromPlan(parsedPlans.toSeq) + val projectExprs = execInfo.filter(_.exec == "Project") + assertSizeAndNotSupported(1, projectExprs) + } + } + } + test("Expressions supported in SortAggregateExec") { TrampolineUtil.withTempDir { eventLogDir => val aggConfs = Map("spark.sql.execution.useObjectHashAggregateExec" -> "false")