diff --git a/CHANGELOG.md b/CHANGELOG.md index d534d9398b6..788fed0a98f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,5 @@ # Change log -Generated on 2024-06-10 +Generated on 2024-06-13 ## Release 24.06 @@ -48,6 +48,8 @@ Generated on 2024-06-10 ### PRs ||| |:---|:---| +|[#11052](https://github.com/NVIDIA/spark-rapids/pull/11052)|Add spark343 shim for scala2.13 dist jar| +|[#10981](https://github.com/NVIDIA/spark-rapids/pull/10981)|Update latest changelog [skip ci]| |[#10984](https://github.com/NVIDIA/spark-rapids/pull/10984)|[DOC] Update docs for 24.06.0 release [skip ci]| |[#10974](https://github.com/NVIDIA/spark-rapids/pull/10974)|Update rapids JNI and private dependency to 24.06.0| |[#10947](https://github.com/NVIDIA/spark-rapids/pull/10947)|Prevent contains-PrefixRange optimization if not preceded by wildcards| diff --git a/datagen/src/main/scala/org/apache/spark/sql/tests/datagen/bigDataGen.scala b/datagen/src/main/scala/org/apache/spark/sql/tests/datagen/bigDataGen.scala index 91335afe4e6..14e0d4e0970 100644 --- a/datagen/src/main/scala/org/apache/spark/sql/tests/datagen/bigDataGen.scala +++ b/datagen/src/main/scala/org/apache/spark/sql/tests/datagen/bigDataGen.scala @@ -16,21 +16,22 @@ package org.apache.spark.sql.tests.datagen +import com.fasterxml.jackson.core.{JsonFactoryBuilder, JsonParser, JsonToken} +import com.fasterxml.jackson.core.json.JsonReadFeature import java.math.{BigDecimal => JavaBigDecimal} import java.sql.{Date, Timestamp} import java.time.{Duration, Instant, LocalDate, LocalDateTime} import java.util - import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.math.BigDecimal.RoundingMode import scala.util.Random -import org.apache.spark.sql.{Column, DataFrame, SparkSession} +import org.apache.spark.sql.{Column, DataFrame, Row, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Expression, XXH64} import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils} -import org.apache.spark.sql.functions.col +import org.apache.spark.sql.functions.{approx_count_distinct, avg, coalesce, col, count, lit, stddev, struct, transform, udf, when} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.random.XORShiftRandom @@ -79,22 +80,28 @@ class RowLocation(val rowNum: Long, val subRows: Array[Int] = null) { * hash. This makes the generated data correlated for all column/child columns. * @param tableNum a unique ID for the table this is a part of. * @param columnNum the location of the column in the data being generated + * @param substringNum the location of the substring column * @param correlatedKeyGroup the correlated key group this column is a part of, if any. */ -case class ColumnLocation(tableNum: Int, columnNum: Int, correlatedKeyGroup: Option[Long] = None) { - def forNextColumn(): ColumnLocation = ColumnLocation(tableNum, columnNum + 1) +case class ColumnLocation(tableNum: Int, + columnNum: Int, + substringNum: Int, + correlatedKeyGroup: Option[Long] = None) { + def forNextColumn(): ColumnLocation = ColumnLocation(tableNum, columnNum + 1, 0) + def forNextSubstring: ColumnLocation = ColumnLocation(tableNum, columnNum, substringNum + 1) /** * Create a new ColumnLocation that is specifically for a given key group */ def forCorrelatedKeyGroup(keyGroup: Long): ColumnLocation = - ColumnLocation(tableNum, columnNum, Some(keyGroup)) + ColumnLocation(tableNum, columnNum, substringNum, Some(keyGroup)) /** * Hash the location into a single long value. */ - lazy val hashLoc: Long = XXH64.hashLong(tableNum, correlatedKeyGroup.getOrElse(columnNum)) + lazy val hashLoc: Long = XXH64.hashLong(tableNum, + correlatedKeyGroup.getOrElse(XXH64.hashLong(columnNum, substringNum))) } /** @@ -115,6 +122,9 @@ case class ColumnConf(columnLoc: ColumnLocation, def forNextColumn(nullable: Boolean): ColumnConf = ColumnConf(columnLoc.forNextColumn(), nullable, numTableRows) + def forNextSubstring: ColumnConf = + ColumnConf(columnLoc.forNextSubstring, nullable = true, numTableRows) + /** * Create a new configuration based on this, but for a given correlated key group. */ @@ -303,6 +313,23 @@ case class VarLengthGeneratorFunction(minLength: Int, maxLength: Int) extends } } +case class StdDevLengthGen(mean: Double, + stdDev: Double, + mapping: LocationToSeedMapping = null) extends + LengthGeneratorFunction { + override def withLocationToSeedMapping(mapping: LocationToSeedMapping): LengthGeneratorFunction = + StdDevLengthGen(mean, stdDev, mapping) + + override def apply(rowLoc: RowLocation): Int = { + val r = DataGen.getRandomFor(rowLoc, mapping) + val g = r.nextGaussian() // g has a mean of 0 and a stddev of 1.0 + val adjusted = mean + (g * stdDev) + // If the range of seed is too small compared to the stddev and mean we will + // end up with an invalid distribution, but they asked for it. + math.max(0, math.round(adjusted).toInt) + } +} + /** * Generate nulls with a given probability. * @param prob 0.0 to 1.0 for how often nulls should appear in the output. @@ -562,11 +589,8 @@ case class DataGenExpr(child: Expression, } } -/** - * Base class for generating a column/sub-column. This holds configuration for the column, - * and handles what is needed to convert it into GeneratorFunction - */ -abstract class DataGen(var conf: ColumnConf, +abstract class CommonDataGen( + var conf: ColumnConf, defaultValueRange: Option[(Any, Any)], var seedMapping: LocationToSeedMapping = FlatDistribution(), var nullMapping: LocationToSeedMapping = FlatDistribution(), @@ -576,26 +600,25 @@ abstract class DataGen(var conf: ColumnConf, protected var valueRange: Option[(Any, Any)] = defaultValueRange /** - * Set a value range for this data gen. + * Set a value range */ - def setValueRange(min: Any, max: Any): DataGen = { + def setValueRange(min: Any, max: Any): CommonDataGen = { valueRange = Some((min, max)) this } /** - * Set a custom GeneratorFunction to use for this column. + * Set a custom GeneratorFunction */ - def setValueGen(f: GeneratorFunction): DataGen = { + def setValueGen(f: GeneratorFunction): CommonDataGen = { userProvidedValueGen = Some(f) this } /** - * Set a NullGeneratorFunction for this column. This will not be used - * if the column is not nullable. + * Set a NullGeneratorFunction */ - def setNullGen(f: NullGeneratorFunction): DataGen = { + def setNullGen(f: NullGeneratorFunction): CommonDataGen = { this.userProvidedNullGen = Some(f) this } @@ -604,12 +627,12 @@ abstract class DataGen(var conf: ColumnConf, * Set the probability of a null appearing in the output. The probability should be * 0.0 to 1.0. */ - def setNullProbability(probability: Double): DataGen = { + def setNullProbability(probability: Double): CommonDataGen = { this.userProvidedNullGen = Some(NullProbabilityGenerationFunction(probability)) this } - def setNullProbabilityRecursively(probability: Double): DataGen = { + def setNullProbabilityRecursively(probability: Double): CommonDataGen = { this.userProvidedNullGen = Some(NullProbabilityGenerationFunction(probability)) children.foreach { case (_, dataGen) => @@ -621,7 +644,7 @@ abstract class DataGen(var conf: ColumnConf, /** * Set a specific location to seed mapping for the value generation. */ - def setSeedMapping(seedMapping: LocationToSeedMapping): DataGen = { + def setSeedMapping(seedMapping: LocationToSeedMapping): CommonDataGen = { this.seedMapping = seedMapping this } @@ -629,7 +652,7 @@ abstract class DataGen(var conf: ColumnConf, /** * Set a specific location to seed mapping for the null generation. */ - def setNullMapping(nullMapping: LocationToSeedMapping): DataGen = { + def setNullMapping(nullMapping: LocationToSeedMapping): CommonDataGen = { this.nullMapping = nullMapping this } @@ -638,7 +661,7 @@ abstract class DataGen(var conf: ColumnConf, * Set a specific LengthGeneratorFunction to use. This will only be used if * the datatype needs a length. */ - def setLengthGen(lengthGen: LengthGeneratorFunction): DataGen = { + def setLengthGen(lengthGen: LengthGeneratorFunction): CommonDataGen = { this.lengthGen = lengthGen this } @@ -646,25 +669,30 @@ abstract class DataGen(var conf: ColumnConf, /** * Set the length generation to be a fixed length. */ - def setLength(len: Int): DataGen = { + def setLength(len: Int): CommonDataGen = { this.lengthGen = FixedLengthGeneratorFunction(len) this } - def setLength(minLen: Int, maxLen: Int) = { + def setLength(minLen: Int, maxLen: Int): CommonDataGen = { this.lengthGen = VarLengthGeneratorFunction(minLen, maxLen) this } + def setGaussianLength(mean: Double, stdDev: Double): CommonDataGen = { + this.lengthGen = StdDevLengthGen(mean, stdDev) + this + } + /** * Add this column to a specific correlated key group. This should not be * called directly by users. */ def setCorrelatedKeyGroup(keyGroup: Long, - minSeed: Long, maxSeed: Long, - seedMapping: LocationToSeedMapping): DataGen = { + minSeed: Long, maxSeed: Long, + seedMapping: LocationToSeedMapping): CommonDataGen = { conf = conf.forCorrelatedKeyGroup(keyGroup) - .forSeedRange(minSeed, maxSeed) + .forSeedRange(minSeed, maxSeed) this.seedMapping = seedMapping this } @@ -672,7 +700,7 @@ abstract class DataGen(var conf: ColumnConf, /** * Set a range of seed values that should be returned by the LocationToSeedMapping */ - def setSeedRange(min: Long, max: Long): DataGen = { + def setSeedRange(min: Long, max: Long): CommonDataGen = { conf = conf.forSeedRange(min, max) this } @@ -681,7 +709,7 @@ abstract class DataGen(var conf: ColumnConf, * Get the default value generator for this specific data gen. */ protected def getValGen: GeneratorFunction - def children: Seq[(String, DataGen)] + def children: Seq[(String, CommonDataGen)] /** * Get the final ready to use GeneratorFunction for the data generator. @@ -690,8 +718,8 @@ abstract class DataGen(var conf: ColumnConf, val sm = seedMapping.withColumnConf(conf) val lg = lengthGen.withLocationToSeedMapping(sm) var valGen = userProvidedValueGen.getOrElse(getValGen) - .withLocationToSeedMapping(sm) - .withLengthGeneratorFunction(lg) + .withLocationToSeedMapping(sm) + .withLengthGeneratorFunction(lg) valueRange.foreach { case (min, max) => valGen = valGen.withValueRange(min, max) @@ -700,35 +728,75 @@ abstract class DataGen(var conf: ColumnConf, val nullColConf = conf.forNulls val nm = nullMapping.withColumnConf(nullColConf) userProvidedNullGen.get - .withWrapped(valGen) - .withLocationToSeedMapping(nm) + .withWrapped(valGen) + .withLocationToSeedMapping(nm) } else { valGen } } - /** - * Get the data type for this column - */ - def dataType: DataType - /** * Is this column nullable or not. */ def nullable: Boolean = conf.nullable /** - * Get a child column for a given name, if it has one. + * Get a child for a given name, if it has one. */ - final def apply(name: String): DataGen = { + final def apply(name: String): CommonDataGen = { get(name).getOrElse{ throw new IllegalStateException(s"Could not find a child $name for $this") } } - def get(name: String): Option[DataGen] = None + def get(name: String): Option[CommonDataGen] = None +} + + +/** + * Base class for generating a column/sub-column. This holds configuration + * for the column, and handles what is needed to convert it into GeneratorFunction + */ +abstract class DataGen( + conf: ColumnConf, + defaultValueRange: Option[(Any, Any)], + seedMapping: LocationToSeedMapping = FlatDistribution(), + nullMapping: LocationToSeedMapping = FlatDistribution(), + lengthGen: LengthGeneratorFunction = FixedLengthGeneratorFunction(10)) extends + CommonDataGen(conf, defaultValueRange, seedMapping, nullMapping, lengthGen) { + + /** + * Get the data type for this column + */ + def dataType: DataType + + override def get(name: String): Option[DataGen] = None + + def getSubstringGen: Option[SubstringDataGen] = None + + def substringGen: SubstringDataGen = + getSubstringGen.getOrElse( + throw new IllegalArgumentException("substring data gen was not set")) + + def setSubstringGen(f : ColumnConf => SubstringDataGen): Unit = + setSubstringGen(Option(f(conf.forNextSubstring))) + + def setSubstringGen(subgen: Option[SubstringDataGen]): Unit = + throw new IllegalArgumentException("substring data gens can only be set for a STRING") } +/** + * Base class for generating a sub-string. This holds configuration + * for the substring, and handles what is needed to convert it into a GeneratorFunction + */ +abstract class SubstringDataGen( + conf: ColumnConf, + defaultValueRange: Option[(Any, Any)], + seedMapping: LocationToSeedMapping = FlatDistribution(), + nullMapping: LocationToSeedMapping = FlatDistribution(), + lengthGen: LengthGeneratorFunction = FixedLengthGeneratorFunction(10)) extends + CommonDataGen(conf, defaultValueRange, seedMapping, nullMapping, lengthGen) {} + /** * A special GeneratorFunction that just returns the computed seed. This is helpful for * debugging distributions or if you want long values without any abstraction in between. @@ -1494,155 +1562,866 @@ class FloatGen(conf: ColumnConf, defaultValueRange: Option[(Any, Any)]) override def children: Seq[(String, DataGen)] = Seq.empty } -trait JSONType { - def appendRandomValue(sb: StringBuilder, - index: Int, - maxStringLength: Int, - maxArrayLength: Int, - maxObjectLength: Int, - depth: Int, - maxDepth: Int, - r: Random): Unit -} +case class JsonPathElement(name: String, is_array: Boolean) +case class JsonLevel(path: Array[JsonPathElement], data_type: String, length: Int, value: String) {} + +object JsonColumnStats { + private def printHelp(): Unit = { + println("JSON Fingerprinting Tool:") + println("PARAMS: ") + println(" is a path to a Spark dataframe to read in") + println(" is a path in a Spark file system to write out fingerprint data to.") + println() + println("OPTIONS:") + println(" --json= where is the name of a top level String column") + println(" --anon= where is a SEED used to anonymize the JSON keys ") + println(" and column names.") + println(" --input_format= where is parquet or ORC. Defaults to parquet.") + println(" --overwrite to enable overwriting the fingerprint output.") + println(" --debug to enable some debug information to be printed out") + println(" --help to print out this help message") + println() + } + + def main(args: Array[String]): Unit = { + var inputPath = Option.empty[String] + var outputPath = Option.empty[String] + val jsonColumns = ArrayBuffer.empty[String] + var anonSeed = Option.empty[Long] + var debug = false + var argsDone = false + var format = "parquet" + var overwrite = false + + args.foreach { + case a if !argsDone && a.startsWith("--json=") => + jsonColumns += a.substring("--json=".length) + case a if !argsDone && a.startsWith("--anon=") => + anonSeed = Some(a.substring("--anon=".length).toLong) + case a if !argsDone && a.startsWith("--input_format=") => + format = a.substring("--input_format=".length).toLowerCase(java.util.Locale.US) + case "--overwrite" if !argsDone => + overwrite = true + case "--debug" if !argsDone => + debug = true + case "--help" if !argsDone => + printHelp() + System.exit(0) + case "--" if !argsDone => + argsDone = true + case a if !argsDone && a.startsWith("--") => // "--" was covered above already + println(s"ERROR $a is not a supported argument") + printHelp() + System.exit(-1) + case a if inputPath.isEmpty => + inputPath = Some(a) + case a if outputPath.isEmpty => + outputPath = Some(a) + case a => + println(s"ERROR only two arguments are supported. Found $a") + printHelp() + System.exit(-1) + } + if (outputPath.isEmpty) { + println("ERROR both an inputPath and an outputPath are required") + printHelp() + System.exit(-1) + } + + val spark = SparkSession.builder.getOrCreate() + spark.sparkContext.setLogLevel("WARN") + + val df = spark.read.format(format).load(inputPath.get) + jsonColumns.foreach { column => + val fp = fingerPrint(df, df(column), anonSeed) + val name = anonSeed.map(s => anonymizeString(column, s)).getOrElse(column) + val fullOutPath = s"${outputPath.get}/$name" + var writer = fp.write + if (overwrite) { + writer = writer.mode("overwrite") + } + if (debug) { + anonSeed.foreach { s => + println(s"Keys and columns will be anonymized with seed $s") + } + println(s"Writing $column fingerprint to $fullOutPath") + spark.time(writer.parquet(fullOutPath)) + println(s"Wrote ${spark.read.parquet(fullOutPath).count} rows") + spark.read.parquet(fullOutPath).show() + } else { + writer.parquet(fullOutPath) + } + } + } -object JSONType { - def selectType(depth: Int, - maxDepth: Int, - r: Random): JSONType = { - val toSelectFrom = if (depth < maxDepth) { - Seq(QuotedJSONString, JSONLong, JSONDouble, JSONArray, JSONObject) - } else { - Seq(QuotedJSONString, JSONLong, JSONDouble) - } - val index = r.nextInt(toSelectFrom.length) - toSelectFrom(index) - } -} - -object QuotedJSONString extends JSONType { - override def appendRandomValue(sb: StringBuilder, - index: Int, - maxStringLength: Int, - maxArrayLength: Int, - maxObjectLength: Int, - depth: Int, - maxDepth: Int, - r: Random): Unit = { - val strValue = r.nextString(r.nextInt(maxStringLength + 1)) - .replace("\\", "\\\\") - .replace("\"", "\\\"") - .replace("\n", "\\n") - .replace("\r", "\\r") - .replace("\b", "\\b") - .replace("\f", "\\f") - sb.append('"') - sb.append(strValue) - sb.append('"') - } -} - -object JSONLong extends JSONType { - override def appendRandomValue(sb: StringBuilder, - index: Int, - maxStringLength: Int, - maxArrayLength: Int, - maxObjectLength: Int, - depth: Int, - maxDepth: Int, - r: Random): Unit = { - sb.append(r.nextLong()) - } -} - -object JSONDouble extends JSONType { - override def appendRandomValue(sb: StringBuilder, - index: Int, - maxStringLength: Int, - maxArrayLength: Int, - maxObjectLength: Int, - depth: Int, - maxDepth: Int, - r: Random): Unit = { - sb.append(r.nextDouble() * 4096.0) - } -} - -object JSONArray extends JSONType { - override def appendRandomValue(sb: StringBuilder, - index: Int, - maxStringLength: Int, - maxArrayLength: Int, - maxObjectLength: Int, - depth: Int, - maxDepth: Int, - r: Random): Unit = { - val childType = JSONType.selectType(depth, maxDepth, r) - val length = r.nextInt(maxArrayLength + 1) - sb.append("[") + case class JsonNodeStats(count: Long, meanLen: Double, stdDevLength: Double, dc: Long) + + class JsonNode() { + private val forDataType = + mutable.HashMap[String, (JsonNodeStats, mutable.HashMap[String, JsonNode])]() + + def getChild(name: String, isArray: Boolean): JsonNode = { + val dt = if (isArray) { "ARRAY" } else { "OBJECT" } + val typed = forDataType.getOrElse(dt, + throw new IllegalArgumentException(s"$dt is not a set data type yet.")) + typed._2.getOrElse(name, + throw new IllegalArgumentException(s"$name is not a child when the type is $dt")) + } + + def contains(name: String, isArray: Boolean): Boolean = { + val dt = if (isArray) { "ARRAY" } else { "OBJECT" } + forDataType.get(dt).exists { children => + children._2.contains(name) + } + } + + def addChild(name: String, isArray: Boolean): JsonNode = { + val dt = if (isArray) { "ARRAY" } else { "OBJECT" } + val found = forDataType.getOrElse(dt, + throw new IllegalArgumentException(s"$dt was not already added as a data type")) + if (found._2.contains(name)) { + throw new IllegalArgumentException(s"$dt already has a child named $name") + } + val node = new JsonNode() + found._2.put(name, node) + node + } + + def addChoice(dt: String, stats: JsonNodeStats): Unit = { + if (forDataType.contains(dt)) { + throw new IllegalArgumentException(s"$dt was already added as a data type") + } + forDataType.put(dt, (stats, new mutable.HashMap[String, JsonNode]())) + } + + override def toString: String = { + forDataType.toString() + } + + def totalCount: Long = { + forDataType.values.map{ case (stats, _) => stats.count}.sum + } + + private def makeNoChoiceGenRecursive(dt: String, + children: mutable.HashMap[String, JsonNode], + cc: ColumnConf): (SubstringDataGen, ColumnConf) = { + var c = cc + val ret = dt match { + case "LONG" => new JSONLongGen(c) + case "DOUBLE" => new JSONDoubleGen(c) + case "BOOLEAN" => new JSONBoolGen(c) + case "NULL" => new JSONNullGen(false, c) + case "VALUE_NULL" => new JSONNullGen(true, c) + case "ERROR" => new JSONErrorGen(c) + case "STRING" => new JSONStringGen(c) + case "ARRAY" => + val child = if (children.isEmpty) { + // A corner case, we will just make it a BOOL column and it will be ignored + val tmp = new JSONBoolGen(c) + c = c.forNextSubstring + tmp + } else { + val tmp = children.values.head.makeGenRecursive(c) + c = tmp._2 + tmp._1 + } + new JSONArrayGen(child, c) + case "OBJECT" => + val childGens = if (children.isEmpty) { + Seq.empty + } else { + children.toSeq.map { + case (k, node) => + val tmp = node.makeGenRecursive(c) + c = tmp._2 + (k, tmp._1) + } + } + new JSONObjectGen(childGens, c) + case other => + throw new IllegalArgumentException(s"$other is not a leaf node type") + } + (ret, c.forNextSubstring) + } + + private def makeGenRecursive(cc: ColumnConf): (SubstringDataGen, ColumnConf) = { + var c = cc + // We are going to recursively walk the tree for all of the values. + if (forDataType.size == 1) { + // We don't need a choice at all. This makes it simpler.. + val (dt, (_, children)) = forDataType.head + makeNoChoiceGenRecursive(dt, children, c) + } else { + val totalSum = forDataType.map(f => f._2._1.count).sum.toDouble + var runningSum = 0L + val allChoices = ArrayBuffer[(Double, String, SubstringDataGen)]() + forDataType.foreach { + case (dt, (stats, children)) => + val tmp = makeNoChoiceGenRecursive(dt, children, c) + c = tmp._2 + runningSum += stats.count + allChoices.append((runningSum/totalSum, dt, tmp._1)) + } + + val ret = new JSONChoiceGen(allChoices.toSeq, c) + (ret, c.forNextSubstring) + } + } + + def makeGen(cc: ColumnConf): SubstringDataGen = { + val (ret, _) = makeGenRecursive(cc) + ret + } + + def setStatsSingle(dg: CommonDataGen, + dt: String, + stats: JsonNodeStats, + nullPct: Double): Unit = { + + val includeLength = dt != "OBJECT" && dt != "BOOLEAN" && dt != "NULL" && dt != "VALUE_NULL" + val includeNullPct = nullPct > 0.0 + if (includeLength) { + dg.setGaussianLength(stats.meanLen, stats.stdDevLength) + } + if (includeNullPct) { + dg.setNullProbability(nullPct) + } + dg.setSeedRange(1, stats.dc) + } + + def setStats(dg: CommonDataGen, + parentCount: Option[Long]): Unit = { + // We are going to recursively walk the tree... + if (forDataType.size == 1) { + // We don't need a choice at all. This makes it simpler.. + val (dt, (stats, children)) = forDataType.head + val nullPct = parentCount.map { pc => + (pc - stats.count).toDouble/pc + }.getOrElse(0.0) + setStatsSingle(dg, dt, stats, nullPct) + val myCount = if (dt == "OBJECT") { + Some(totalCount) + } else { + None + } + children.foreach { + case (name, node) => + node.setStats(dg(name), myCount) + } + } else { + // We have choices to make between different types. + // The null percent cannot be calculated for each individual choice + // but is calculated on the group as a whole instead + parentCount.foreach { pc => + val tc = totalCount + val choiceNullPct = (pc - tc).toDouble / pc + if (choiceNullPct > 0.0) { + dg.setNullProbability(choiceNullPct) + } + } + forDataType.foreach { + case (dt, (stats, children)) => + // When there is a choice the name to access it is the data type + val choiceDg = dg(dt) + setStatsSingle(choiceDg, dt, stats, 0.0) + children.foreach { + case (name, node) => + val myCount = if (dt == "OBJECT") { + // Here we only want the count for the OBJECTs + Some(stats.count) + } else { + None + } + node.setStats(choiceDg(name), myCount) + } + } + } + } + } + + private lazy val jsonFactory = new JsonFactoryBuilder() + // The two options below enabled for Hive compatibility + .enable(JsonReadFeature.ALLOW_UNESCAPED_CONTROL_CHARS) + .enable(JsonReadFeature.ALLOW_SINGLE_QUOTES) + .build() + + private def processNext(parser: JsonParser, + currentPath: ArrayBuffer[JsonPathElement], + output: ArrayBuffer[JsonLevel]): Unit = { + parser.currentToken() match { + case JsonToken.START_OBJECT => + parser.nextToken() + while (parser.currentToken() != JsonToken.END_OBJECT) { + processNext(parser, currentPath, output) + } + output.append(JsonLevel(currentPath.toArray, "OBJECT", 0, "")) + parser.nextToken() + case JsonToken.START_ARRAY => + currentPath.append(JsonPathElement("data", is_array = true)) + parser.nextToken() + var length = 0 + while (parser.currentToken() != JsonToken.END_ARRAY) { + length += 1 + processNext(parser, currentPath, output) + } + currentPath.remove(currentPath.length - 1) + output.append(JsonLevel(currentPath.toArray, "ARRAY", length, "")) + parser.nextToken() + case JsonToken.FIELD_NAME => + currentPath.append(JsonPathElement(parser.getCurrentName, is_array = false)) + parser.nextToken() + processNext(parser, currentPath, output) + currentPath.remove(currentPath.length - 1) + case JsonToken.VALUE_NUMBER_INT => + val length = parser.getValueAsString.getBytes("UTF-8").length + output.append(JsonLevel(currentPath.toArray, "LONG", length, parser.getValueAsString)) + parser.nextToken() + case JsonToken.VALUE_NUMBER_FLOAT => + val length = parser.getValueAsString.getBytes("UTF-8").length + output.append(JsonLevel(currentPath.toArray, "DOUBLE", length, parser.getValueAsString)) + parser.nextToken() + case JsonToken.VALUE_TRUE | JsonToken.VALUE_FALSE => + val length = parser.getValueAsString.getBytes("UTF-8").length + output.append(JsonLevel(currentPath.toArray, "BOOLEAN", length, parser.getValueAsString)) + parser.nextToken() + case JsonToken.VALUE_NULL | null => + output.append(JsonLevel(currentPath.toArray, "VALUE_NULL", 4, "NULL")) + parser.nextToken() + case JsonToken.VALUE_STRING => + val length = parser.getValueAsString.getBytes("UTF-8").length + output.append(JsonLevel(currentPath.toArray, "STRING", length, parser.getValueAsString)) + parser.nextToken() + case other => + throw new IllegalStateException(s"DON'T KNOW HOW TO DEAL WITH $other") + } + } + + def jsonStatsUdf(json: String): Array[JsonLevel] = { + val output = new ArrayBuffer[JsonLevel]() + try { + val currentPath = new ArrayBuffer[JsonPathElement]() + if (json == null) { + output.append(JsonLevel(Array.empty, "NULL", 0, "")) + } else { + val parser = jsonFactory.createParser(json) + try { + parser.nextToken() + processNext(parser, currentPath, output) + } finally { + parser.close() + } + } + } catch { + case _: com.fasterxml.jackson.core.JsonParseException => + output.clear() + output.append(JsonLevel(Array.empty, "ERROR", json.getBytes("UTF-8").length, json)) + } + output.toArray + } + + private lazy val extractPaths = udf(json => jsonStatsUdf(json)) + + def anonymizeString(str: String, seed: Long): String = { + val length = str.length + val data = new Array[Byte](length) + val hash = XXH64.hashLong(str.hashCode, seed) + val r = new Random() + r.setSeed(hash) (0 until length).foreach { i => - if (i > 0) { - sb.append(",") + val tmp = r.nextInt(16) + data(i) = (tmp + 'A').toByte + } + new String(data) + } + + private lazy val anonPath = udf((str, seed) => anonymizeString(str, seed)) + + def anonymizeFingerPrint(df: DataFrame, anonSeed: Long): DataFrame = { + df.withColumn("tmp", transform(col("path"), + o => { + val name = o("name") + val isArray = o("is_array") + val anon = anonPath(name, lit(anonSeed)) + val newName = when(isArray, name).otherwise(anon).alias("name") + struct(newName, isArray) + })) + .drop("path").withColumnRenamed("tmp", "path") + .orderBy("path", "dt") + .selectExpr("path", "dt","c","mean_len","stddev_len","distinct","version") + } + + def fingerPrint(df: DataFrame, column: Column, anonymize: Option[Long] = None): DataFrame = { + val ret = df.select(extractPaths(column).alias("paths")) + .selectExpr("explode_outer(paths) as p") + .selectExpr("p.path as path", "p.data_type as dt", "p.length as len", "p.value as value") + .groupBy(col("path"), col("dt")).agg( + count(lit(1)).alias("c"), + avg(col("len")).alias("mean_len"), + coalesce(stddev(col("len")), lit(0.0)).alias("stddev_len"), + approx_count_distinct(col("value")).alias("distinct")) + .orderBy("path", "dt").withColumn("version", lit("0.1")) + .selectExpr("path", "dt","c","mean_len","stddev_len","distinct","version") + + anonymize.map { anonSeed => + anonymizeFingerPrint(ret, anonSeed) + }.getOrElse(ret) + } + + def apply(aggForColumn: DataFrame, genColumn: ColumnGen): Unit = + apply(aggForColumn, genColumn.dataGen) + + private val expectedSchema = StructType.fromDDL( + "path ARRAY>," + + "dt STRING," + + "c BIGINT," + + "mean_len DOUBLE," + + "stddev_len DOUBLE," + + "distinct BIGINT," + + "version STRING") + + def apply(aggForColumn: DataFrame, gen: DataGen): Unit = { + val aggData = aggForColumn.orderBy("path", "dt").collect() + val rootNode: JsonNode = new JsonNode() + assert(aggData.length > 0) + val schema = aggData.head.schema + assert(schema.length == expectedSchema.length) + schema.fields.zip(expectedSchema.fields).foreach { + case(found, expected) => + assert(found.name == expected.name) + // TODO we can worry about the exact types later if we need to + } + assert(aggData.head.getString(6) == "0.1") + aggData.foreach { row => + val fullPath = row.getAs[mutable.WrappedArray[Row]](0) + val parsedPath = fullPath.map(r => (r.getString(0), r.getBoolean(1))).toList + val dt = row.getString(1) + val count = row.getLong(2) + val meanLen = row.getDouble(3) + val stdLen = row.getDouble(4) + val dc = row.getLong(5) + + val stats = JsonNodeStats(count, meanLen, stdLen, dc) + var currentNode = rootNode + // Find everything up to the last path element + if (parsedPath.length > 1) { + parsedPath.slice(0, parsedPath.length - 1).foreach { + case (name, isArray) => + currentNode = currentNode.getChild(name, isArray) + } + } + + if (parsedPath.nonEmpty) { + // For the last path element (that is not the root element) we might need to add it + // as a child + val (name, isArray) = parsedPath.last + if (!currentNode.contains(name, isArray)) { + currentNode.addChild(name, isArray) + } + currentNode = currentNode.getChild(name, isArray) } - childType.appendRandomValue(sb, i, maxStringLength, maxArrayLength, maxObjectLength, - depth + 1, maxDepth, r) + currentNode.addChoice(dt, stats) } - sb.append("]") + + gen.setSubstringGen(cc => rootNode.makeGen(cc)) + rootNode.setStats(gen.substringGen, None) } } -object JSONObject extends JSONType { - override def appendRandomValue(sb: StringBuilder, - index: Int, - maxStringLength: Int, - maxArrayLength: Int, - maxObjectLength: Int, - depth: Int, - maxDepth: Int, - r: Random): Unit = { - val length = r.nextInt(maxObjectLength) + 1 - sb.append("{") - (0 until length).foreach { i => - if (i > 0) { - sb.append(",") + +case class JSONStringGenFunc(lengthGen: LengthGeneratorFunction = null, + mapping: LocationToSeedMapping = null) extends GeneratorFunction { + + override def apply(rowLoc: RowLocation): Any = { + val len = lengthGen(rowLoc) + val r = DataGen.getRandomFor(rowLoc, mapping) + val buffer = new Array[Byte](len) + var at = 0 + while (at < len) { + // Value range is 32 (Space) to 126 (~) + buffer(at) = (r.nextInt(126 - 31) + 32).toByte + at += 1 + } + val strVal = new String(buffer, 0, len) + .replace("\\", "\\\\") + .replace("\"", "\\\"") + .replace("\n", "\\n") + .replace("\r", "\\r") + .replace("\b", "\\b") + .replace("\f", "\\f") + '"' + strVal + '"' + } + + override def withLengthGeneratorFunction(lengthGen: LengthGeneratorFunction): JSONStringGenFunc = + JSONStringGenFunc(lengthGen, mapping) + + override def withLocationToSeedMapping(mapping: LocationToSeedMapping): JSONStringGenFunc = + JSONStringGenFunc(lengthGen, mapping) + + override def withValueRange(min: Any, max: Any): GeneratorFunction = + throw new IllegalArgumentException("value ranges are not supported for JSON") +} + +class JSONStringGen(conf: ColumnConf, + defaultValueRange: Option[(Any, Any)] = None) + extends SubstringDataGen(conf, defaultValueRange) { + + override protected def getValGen: GeneratorFunction = JSONStringGenFunc() + + override def children: Seq[(String, SubstringDataGen)] = Seq.empty +} + +case class JSONLongGenFunc(lengthGen: LengthGeneratorFunction = null, + mapping: LocationToSeedMapping = null) extends GeneratorFunction { + + override def apply(rowLoc: RowLocation): Any = { + val len = math.max(lengthGen(rowLoc), 1) // We need at least 1 long for a valid value + val r = DataGen.getRandomFor(rowLoc, mapping) + val buffer = new Array[Byte](len) + var at = 0 + while (at < len) { + if (at == 0) { + // No leading 0's + buffer(at) = (r.nextInt(9) + '1').toByte + } else { + buffer(at) = (r.nextInt(10) + '0').toByte } - sb.append("\"key_") - sb.append(i) - sb.append("_") - sb.append(depth ) - sb.append("\":") - val childType = JSONType.selectType(depth, maxDepth, r) - childType.appendRandomValue(sb, i, maxStringLength, maxArrayLength, maxObjectLength, - depth + 1, maxDepth, r) + at += 1 } - sb.append("}") + new String(buffer, 0, len) } + + override def withLengthGeneratorFunction(lengthGen: LengthGeneratorFunction): JSONLongGenFunc = + JSONLongGenFunc(lengthGen, mapping) + + override def withLocationToSeedMapping(mapping: LocationToSeedMapping): JSONLongGenFunc = + JSONLongGenFunc(lengthGen, mapping) + + override def withValueRange(min: Any, max: Any): GeneratorFunction = + throw new IllegalArgumentException("value ranges are not supported for JSON") } -case class JSONGenFunc( - maxStringLength: Int, - maxArrayLength: Int, - maxObjectLength: Int, - maxDepth: Int, - lengthGen: LengthGeneratorFunction = null, - mapping: LocationToSeedMapping = null) extends GeneratorFunction { +class JSONLongGen(conf: ColumnConf, + defaultValueRange: Option[(Any, Any)] = None) + extends SubstringDataGen(conf, defaultValueRange) { + + override protected def getValGen: GeneratorFunction = JSONLongGenFunc() + + override def children: Seq[(String, SubstringDataGen)] = Seq.empty +} + +case class JSONDoubleGenFunc(lengthGen: LengthGeneratorFunction = null, + mapping: LocationToSeedMapping = null) extends GeneratorFunction { override def apply(rowLoc: RowLocation): Any = { + val len = math.max(lengthGen(rowLoc), 3) // We have to have at least 3 chars NUM.NUM val r = DataGen.getRandomFor(rowLoc, mapping) - val sb = new StringBuilder() - JSONObject.appendRandomValue(sb, 0, maxStringLength, maxArrayLength, maxObjectLength, - 0, maxDepth, r) - // For now I am going to have some hard coded keys - UTF8String.fromString(sb.toString()) + val beforeLen = if (len == 3) { 1 } else { r.nextInt(len - 3) + 1 } + val buffer = new Array[Byte](len) + var at = 0 + while (at < len) { + if (at == 0) { + // No leading 0's + buffer(at) = (r.nextInt(9) + '1').toByte + } else if (at == beforeLen) { + buffer(at) = '.' + } else { + buffer(at) = (r.nextInt(10) + '0').toByte + } + at += 1 + } + UTF8String.fromBytes(buffer, 0, len) } - override def withLengthGeneratorFunction(lengthGen: LengthGeneratorFunction): GeneratorFunction = - JSONGenFunc(maxStringLength, maxArrayLength, maxObjectLength, maxDepth, lengthGen, mapping) + override def withLengthGeneratorFunction(lengthGen: LengthGeneratorFunction): JSONDoubleGenFunc = + JSONDoubleGenFunc(lengthGen, mapping) - override def withLocationToSeedMapping(mapping: LocationToSeedMapping): GeneratorFunction = - JSONGenFunc(maxStringLength, maxArrayLength, maxObjectLength, maxDepth, lengthGen, mapping) + override def withLocationToSeedMapping(mapping: LocationToSeedMapping): JSONDoubleGenFunc = + JSONDoubleGenFunc(lengthGen, mapping) override def withValueRange(min: Any, max: Any): GeneratorFunction = - throw new IllegalArgumentException("value ranges are not supported for strings") + throw new IllegalArgumentException("value ranges are not supported for JSON") +} + +class JSONDoubleGen(conf: ColumnConf, + defaultValueRange: Option[(Any, Any)] = None) + extends SubstringDataGen(conf, defaultValueRange) { + + override protected def getValGen: GeneratorFunction = JSONDoubleGenFunc() + + override def children: Seq[(String, SubstringDataGen)] = Seq.empty +} + +case class JSONBoolGenFunc(lengthGen: LengthGeneratorFunction = null, + mapping: LocationToSeedMapping = null) extends GeneratorFunction { + + override def apply(rowLoc: RowLocation): Any = { + val r = DataGen.getRandomFor(rowLoc, mapping) + val ret = if (r.nextBoolean()) "true" else "false" + UTF8String.fromString(ret) + } + + override def withLengthGeneratorFunction(lengthGen: LengthGeneratorFunction): JSONBoolGenFunc = + JSONBoolGenFunc(lengthGen, mapping) + + override def withLocationToSeedMapping(mapping: LocationToSeedMapping): JSONBoolGenFunc = + JSONBoolGenFunc(lengthGen, mapping) + + override def withValueRange(min: Any, max: Any): GeneratorFunction = + throw new IllegalArgumentException("value ranges are not supported for JSON") +} + +class JSONBoolGen(conf: ColumnConf, + defaultValueRange: Option[(Any, Any)] = None) + extends SubstringDataGen(conf, defaultValueRange) { + + override protected def getValGen: GeneratorFunction = JSONBoolGenFunc() + + override def children: Seq[(String, SubstringDataGen)] = Seq.empty +} + +case class JSONNullGenFunc(nullAsString: Boolean, + lengthGen: LengthGeneratorFunction = null, + mapping: LocationToSeedMapping = null) extends GeneratorFunction { + + override def apply(rowLoc: RowLocation): Any = + if (nullAsString) { + UTF8String.fromString("null") + } else { + null + } + + + override def withLengthGeneratorFunction(lengthGen: LengthGeneratorFunction): JSONNullGenFunc = + JSONNullGenFunc(nullAsString, lengthGen, mapping) + + override def withLocationToSeedMapping(mapping: LocationToSeedMapping): JSONNullGenFunc = + JSONNullGenFunc(nullAsString, lengthGen, mapping) + + override def withValueRange(min: Any, max: Any): GeneratorFunction = + throw new IllegalArgumentException("value ranges are not supported for JSON") +} + +class JSONNullGen(nullAsString: Boolean, + conf: ColumnConf, + defaultValueRange: Option[(Any, Any)] = None) + extends SubstringDataGen(conf, defaultValueRange) { + + override protected def getValGen: GeneratorFunction = JSONNullGenFunc(nullAsString) + + override def children: Seq[(String, SubstringDataGen)] = Seq.empty +} + +case class JSONErrorGenFunc(lengthGen: LengthGeneratorFunction = null, + mapping: LocationToSeedMapping = null) extends GeneratorFunction { + + override def apply(rowLoc: RowLocation): Any = { + val len = lengthGen(rowLoc) + val r = DataGen.getRandomFor(rowLoc, mapping) + val buffer = new Array[Byte](len) + var at = 0 + while (at < len) { + // Value range is 32 (Space) to 126 (~) + // But it is almost impossible to show up as valid JSON + buffer(at) = (r.nextInt(126 - 31) + 32).toByte + at += 1 + } + UTF8String.fromBytes(buffer, 0, len) + } + + override def withLengthGeneratorFunction(lengthGen: LengthGeneratorFunction): JSONErrorGenFunc = + JSONErrorGenFunc(lengthGen, mapping) + + override def withLocationToSeedMapping(mapping: LocationToSeedMapping): JSONErrorGenFunc = + JSONErrorGenFunc(lengthGen, mapping) + + override def withValueRange(min: Any, max: Any): GeneratorFunction = + throw new IllegalArgumentException("value ranges are not supported for JSON") +} + +class JSONErrorGen(conf: ColumnConf, + defaultValueRange: Option[(Any, Any)] = None) + extends SubstringDataGen(conf, defaultValueRange) { + + override protected def getValGen: GeneratorFunction = JSONErrorGenFunc() + + override def children: Seq[(String, SubstringDataGen)] = Seq.empty +} + +case class JSONArrayGenFunc(child: GeneratorFunction, + lengthGen: LengthGeneratorFunction = null, + mapping: LocationToSeedMapping = null) extends GeneratorFunction { + + override def apply(rowLoc: RowLocation): Any = { + val len = lengthGen(rowLoc) + val data = new Array[String](len) + val childRowLoc = rowLoc.withNewChild() + var i = 0 + while (i < len) { + childRowLoc.setLastChildIndex(i) + val v = child(childRowLoc) + if (v == null) { + // A null in an array must look like "null" + data(i) = "null" + } else { + data(i) = v.toString + } + i += 1 + } + val ret = data.mkString("[", ",", "]") + UTF8String.fromString(ret) + } + + override def withLengthGeneratorFunction(lengthGen: LengthGeneratorFunction): JSONArrayGenFunc = + JSONArrayGenFunc(child, lengthGen, mapping) + + override def withLocationToSeedMapping(mapping: LocationToSeedMapping): JSONArrayGenFunc = + JSONArrayGenFunc(child, lengthGen, mapping) + + override def withValueRange(min: Any, max: Any): GeneratorFunction = + throw new IllegalArgumentException("value ranges are not supported for JSON") +} + +class JSONArrayGen(child: SubstringDataGen, + conf: ColumnConf, + defaultValueRange: Option[(Any, Any)] = None) + extends SubstringDataGen(conf, defaultValueRange) { + + override def setCorrelatedKeyGroup(keyGroup: Long, + minSeed: Long, maxSeed: Long, + seedMapping: LocationToSeedMapping): SubstringDataGen = { + super.setCorrelatedKeyGroup(keyGroup, minSeed, maxSeed, seedMapping) + child.setCorrelatedKeyGroup(keyGroup, minSeed, maxSeed, seedMapping) + this + } + + override protected def getValGen: GeneratorFunction = JSONArrayGenFunc(child.getGen) + + override def get(name: String): Option[SubstringDataGen] = { + if ("data".equalsIgnoreCase(name) || "child".equalsIgnoreCase(name)) { + Some(child) + } else { + None + } + } + + override def children: Seq[(String, SubstringDataGen)] = Seq(("data", child)) +} + +case class JSONObjectGenFunc(childGens: Array[(String, GeneratorFunction)], + lengthGen: LengthGeneratorFunction = null, + mapping: LocationToSeedMapping = null) extends GeneratorFunction { + override def apply(rowLoc: RowLocation): Any = { + // TODO randomize the order of the children??? + // TODO duplicate child values??? + // The row location does not change for a struct/object + val data = childGens.map { + case (k, gen) => + val key = k.replace("\\", "\\\\") + .replace("\"", "\\\"") + .replace("\n", "\\n") + .replace("\r", "\\r") + .replace("\b", "\\b") + .replace("\f", "\\f") + val v = gen.apply(rowLoc) + if (v == null) { + "" + } else { + '"' + key + "\":" + v + } + } + val ret = data.filterNot(_.isEmpty).mkString("{",",","}") + UTF8String.fromString(ret) + } + + override def withLocationToSeedMapping(mapping: LocationToSeedMapping): JSONObjectGenFunc = + JSONObjectGenFunc(childGens, lengthGen, mapping) + + override def withLengthGeneratorFunction(lengthGen: LengthGeneratorFunction): JSONObjectGenFunc = + JSONObjectGenFunc(childGens, lengthGen, mapping) + + override def withValueRange(min: Any, max: Any): GeneratorFunction = + throw new IllegalArgumentException("value ranges are not supported for JSON") +} + +class JSONObjectGen(val children: Seq[(String, SubstringDataGen)], + conf: ColumnConf, + defaultValueRange: Option[(Any, Any)] = None) + extends SubstringDataGen(conf, defaultValueRange) { + + override def setCorrelatedKeyGroup(keyGroup: Long, + minSeed: Long, maxSeed: Long, + seedMapping: LocationToSeedMapping): SubstringDataGen = { + super.setCorrelatedKeyGroup(keyGroup, minSeed, maxSeed, seedMapping) + children.foreach { + case (_, gen) => + gen.setCorrelatedKeyGroup(keyGroup, minSeed, maxSeed, seedMapping) + } + this + } + + override def get(name: String): Option[SubstringDataGen] = + children.collectFirst { + case (childName, dataGen) if childName.equalsIgnoreCase(name) => dataGen + } + + override protected def getValGen: GeneratorFunction = { + val childGens = children.map(c => (c._1, c._2.getGen)).toArray + JSONObjectGenFunc(childGens) + } +} + +case class JSONChoiceGenFunc(choices: List[(Double, GeneratorFunction)], + lengthGen: LengthGeneratorFunction = null, + mapping: LocationToSeedMapping = null) extends GeneratorFunction { + override def apply(rowLoc: RowLocation): Any = { + val r = DataGen.getRandomFor(rowLoc, mapping) + val l = r.nextDouble() + var index = 0 + while (choices(index)._1 < l) { + index += 1 + } + val childRowLoc = rowLoc.withNewChild() + choices(index)._2(childRowLoc) + } + + override def withLengthGeneratorFunction(lengthGen: LengthGeneratorFunction): JSONChoiceGenFunc = + JSONChoiceGenFunc(choices, lengthGen, mapping) + + override def withLocationToSeedMapping(mapping: LocationToSeedMapping): JSONChoiceGenFunc = + JSONChoiceGenFunc(choices, lengthGen, mapping) + + override def withValueRange(min: Any, max: Any): GeneratorFunction = + throw new IllegalArgumentException("value ranges are not supported for JSON") +} + +class JSONChoiceGen(val choices: Seq[(Double, String, SubstringDataGen)], + conf: ColumnConf, + defaultValueRange: Option[(Any, Any)] = None) + extends SubstringDataGen(conf, defaultValueRange) { + + override val children: Seq[(String, SubstringDataGen)] = + choices.map { case (_, name, gen) => (name, gen) } + + override def setCorrelatedKeyGroup(keyGroup: Long, + minSeed: Long, maxSeed: Long, + seedMapping: LocationToSeedMapping): SubstringDataGen = { + super.setCorrelatedKeyGroup(keyGroup, minSeed, maxSeed, seedMapping) + children.foreach { + case (_, gen) => + gen.setCorrelatedKeyGroup(keyGroup, minSeed, maxSeed, seedMapping) + } + this + } + + override def get(name: String): Option[SubstringDataGen] = + children.collectFirst { + case (childName, dataGen) if childName.equalsIgnoreCase(name) => dataGen + } + + override protected def getValGen: GeneratorFunction = { + val childGens = choices.map(c => (c._1, c._3.getGen)).toList + JSONChoiceGenFunc(childGens) + } } case class ASCIIGenFunc( @@ -1672,14 +2451,46 @@ case class ASCIIGenFunc( throw new IllegalArgumentException("value ranges are not supported for strings") } -class StringGen(conf: ColumnConf, defaultValueRange: Option[(Any, Any)]) - extends DataGen(conf, defaultValueRange) { +/** + * This is here to wrap the substring gen function so that its length/settings + * are the ones used when generating a string, and not what was set for the string. + */ +case class SubstringGenFunc( + substringGen: GeneratorFunction, + lengthGen: LengthGeneratorFunction = null, + mapping: LocationToSeedMapping = null) extends GeneratorFunction { + + override def apply(rowLoc: RowLocation): Any = { + substringGen(rowLoc) + } + + // The length and location seed mapping are just ignored for this... + override def withLengthGeneratorFunction(lengthGen: LengthGeneratorFunction): GeneratorFunction = + this + + override def withLocationToSeedMapping(mapping: LocationToSeedMapping): GeneratorFunction = + this + + override def withValueRange(min: Any, max: Any): GeneratorFunction = + throw new IllegalArgumentException("value ranges are not supported for strings") +} + +class StringGen(conf: ColumnConf, + defaultValueRange: Option[(Any, Any)], + var substringDataGen: Option[SubstringDataGen] = None) + extends DataGen(conf, defaultValueRange) { override def dataType: DataType = StringType - override protected def getValGen: GeneratorFunction = ASCIIGenFunc() + override protected def getValGen: GeneratorFunction = + substringDataGen.map(s => SubstringGenFunc(s.getGen)).getOrElse(ASCIIGenFunc()) override def children: Seq[(String, DataGen)] = Seq.empty + + override def setSubstringGen(subgen: Option[SubstringDataGen]): Unit = + substringDataGen = subgen + + override def getSubstringGen: Option[SubstringDataGen] = substringDataGen } case class StructGenFunc(childGens: Array[GeneratorFunction]) extends GeneratorFunction { @@ -1854,7 +2665,6 @@ class MapGen(key: DataGen, override def children: Seq[(String, DataGen)] = Seq(("key", key), ("value", value)) } - object ColumnGen { private def genInternal(rowNumber: Column, dataType: DataType, @@ -1869,8 +2679,8 @@ object ColumnGen { */ class ColumnGen(val dataGen: DataGen) { def setCorrelatedKeyGroup(kg: Long, - minSeed: Long, maxSeed: Long, - seedMapping: LocationToSeedMapping): ColumnGen = { + minSeed: Long, maxSeed: Long, + seedMapping: LocationToSeedMapping): ColumnGen = { dataGen.setCorrelatedKeyGroup(kg, minSeed, maxSeed, seedMapping) this } @@ -1930,6 +2740,11 @@ class ColumnGen(val dataGen: DataGen) { this } + def setGaussianLength(mean: Double, stdDev: Double): ColumnGen = { + dataGen.setGaussianLength(mean, stdDev) + this + } + final def apply(name: String): DataGen = { get(name).getOrElse { throw new IllegalArgumentException(s"$name not a child of $this") @@ -1941,8 +2756,16 @@ class ColumnGen(val dataGen: DataGen) { def gen(rowNumber: Column): Column = { ColumnGen.genInternal(rowNumber, dataGen.dataType, dataGen.nullable, dataGen.getGen) } + + def getSubstring: Option[SubstringDataGen] = dataGen.getSubstringGen + + def substringGen: SubstringDataGen = dataGen.substringGen + + def setSubstringGen(f : ColumnConf => SubstringDataGen): Unit = + dataGen.setSubstringGen(f) } + sealed trait KeyGroupType /** @@ -2192,7 +3015,7 @@ object DBGen { numRows: Long, mapping: OrderedTypeMapping): Seq[(String, ColumnGen)] = { // a bit of a hack with the column num so that we update it before each time... - var conf = ColumnConf(ColumnLocation(tableId, -1), true, numRows) + var conf = ColumnConf(ColumnLocation(tableId, -1, 0), true, numRows) st.toArray.map { sf => if (!mapping.canMap(sf.dataType, mapping)) { throw new IllegalArgumentException(s"$sf is not supported at this time") diff --git a/dist/maven-antrun/build-parallel-worlds.xml b/dist/maven-antrun/build-parallel-worlds.xml index 524b15addf9..07838616340 100644 --- a/dist/maven-antrun/build-parallel-worlds.xml +++ b/dist/maven-antrun/build-parallel-worlds.xml @@ -1,6 +1,6 @@ - diff --git a/dist/scripts/binary-dedupe.sh b/dist/scripts/binary-dedupe.sh index 183e86b1524..356b0b4dbae 100755 --- a/dist/scripts/binary-dedupe.sh +++ b/dist/scripts/binary-dedupe.sh @@ -1,6 +1,6 @@ #!/bin/bash -# Copyright (c) 2021-2023, NVIDIA CORPORATION. +# Copyright (c) 2021-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -34,10 +34,10 @@ case "$OSTYPE" in esac STEP=0 -export SPARK3XX_COMMON_TXT="$PWD/spark3xx-common.txt" -export SPARK3XX_COMMON_COPY_LIST="$PWD/spark-common-copy-list.txt" +export SPARK_SHARED_TXT="$PWD/spark-shared.txt" +export SPARK_SHARED_COPY_LIST="$PWD/spark-shared-copy-list.txt" export DELETE_DUPLICATES_TXT="$PWD/delete-duplicates.txt" -export SPARK3XX_COMMON_DIR="$PWD/spark3xx-common" +export SPARK_SHARED_DIR="$PWD/spark-shared" # This script de-duplicates .class files at the binary level. # We could also diff classes using scalap / javap outputs. @@ -47,17 +47,17 @@ export SPARK3XX_COMMON_DIR="$PWD/spark3xx-common" # The following pipeline determines identical classes across shims in this build. # - checksum all class files -# - move the varying-prefix spark3xy to the left so it can be easily skipped for uniq and sort +# - move the varying-prefix sparkxyz to the left so it can be easily skipped for uniq and sort # - sort by path, secondary sort by checksum, print one line per group # - produce uniq count for paths # - filter the paths with count=1, the class files without diverging checksums -# - put the path starting with /spark3xy back together for the final list +# - put the path starting with /sparkxyz back together for the final list echo "Retrieving class files hashing to a single value ..." echo "$((++STEP))/ SHA1 of all non-META files > tmp-sha1-files.txt" -find ./parallel-world/spark3* -name META-INF -prune -o \( -type f -print \) | \ - xargs $SHASUM > tmp-sha1-files.txt +find ./parallel-world/spark[34]* -name META-INF -prune -o -name webapps -prune -o \( -type f -print0 \) | \ + xargs --null $SHASUM > tmp-sha1-files.txt echo "$((++STEP))/ make shim column 1 > tmp-shim-sha-package-files.txt" < tmp-sha1-files.txt awk -F/ '$1=$1' | \ @@ -68,10 +68,10 @@ echo "$((++STEP))/ sort by path, sha1; output first from each group > tmp-count- sort -k3 -k2,2 -u tmp-shim-sha-package-files.txt | \ uniq -f 2 -c > tmp-count-shim-sha-package-files.txt -echo "$((++STEP))/ files with unique sha1 > $SPARK3XX_COMMON_TXT" +echo "$((++STEP))/ files with unique sha1 > $SPARK_SHARED_TXT" grep '^\s\+1 .*' tmp-count-shim-sha-package-files.txt | \ awk '{$1=""; $3=""; print $0 }' | \ - tr -s ' ' | sed 's/\ /\//g' > "$SPARK3XX_COMMON_TXT" + tr -s ' ' | sed 's/\ /\//g' > "$SPARK_SHARED_TXT" function retain_single_copy() { set -e @@ -93,10 +93,10 @@ function retain_single_copy() { package_class="${package_class_with_spaces// //}" # get the reference copy out of the way - echo "$package_class" >> "from-$shim-to-spark3xx-common.txt" + echo "$package_class" >> "from-$shim-to-spark-shared.txt" # expanding directories separately because full path # glob is broken for class file name including the "$" character - for pw in ./parallel-world/spark3* ; do + for pw in ./parallel-world/spark[34]* ; do delete_path="$pw/$package_class" [[ -f "$delete_path" ]] && echo "$delete_path" || true done >> "$DELETE_DUPLICATES_TXT" || exit 255 @@ -106,26 +106,26 @@ function retain_single_copy() { # standalone debugging # truncate incremental files : > "$DELETE_DUPLICATES_TXT" -rm -f from-spark3*-to-spark3xx-common.txt -rm -rf "$SPARK3XX_COMMON_DIR" -mkdir -p "$SPARK3XX_COMMON_DIR" +rm -f from-spark[34]*-to-spark-shared.txt +rm -rf "$SPARK_SHARED_DIR" +mkdir -p "$SPARK_SHARED_DIR" -echo "$((++STEP))/ retaining a single copy of spark3xx-common classes" +echo "$((++STEP))/ retaining a single copy of spark-shared classes" while read spark_common_class; do retain_single_copy "$spark_common_class" -done < "$SPARK3XX_COMMON_TXT" +done < "$SPARK_SHARED_TXT" -echo "$((++STEP))/ rsyncing common classes to $SPARK3XX_COMMON_DIR" -for copy_list in from-spark3*-to-spark3xx-common.txt; do +echo "$((++STEP))/ rsyncing common classes to $SPARK_SHARED_DIR" +for copy_list in from-spark[34]*-to-spark-shared.txt; do echo Initializing rsync of "$copy_list" IFS='-' <<< "$copy_list" read -ra copy_list_parts # declare -p copy_list_parts shim="${copy_list_parts[1]}" # use rsync to reduce process forking - rsync --files-from="$copy_list" ./parallel-world/"$shim" "$SPARK3XX_COMMON_DIR" + rsync --files-from="$copy_list" ./parallel-world/"$shim" "$SPARK_SHARED_DIR" done -mv "$SPARK3XX_COMMON_DIR" parallel-world/ +mv "$SPARK_SHARED_DIR" parallel-world/ # TODO further dedupe by FEATURE version lines: # spark30x-common @@ -137,9 +137,9 @@ mv "$SPARK3XX_COMMON_DIR" parallel-world/ # # At this point the duplicate classes have not been removed from version-specific jar # locations such as parallel-world/spark312. -# For each unshimmed class file look for all of its copies inside /spark3* and +# For each unshimmed class file look for all of its copies inside /spark[34]* and # and count the number of distinct checksums. There are two representative cases -# 1) The class is contributed to the unshimmed location via the unshimmed-from-each-spark3xx list. These are classes +# 1) The class is contributed to the unshimmed location via the unshimmed-from-each-spark34 list. These are classes # carrying the shim classifier in their package name such as # com.nvidia.spark.rapids.spark312.RapidsShuffleManager. They are unique by construction, # and will have zero copies in any non-spark312 shims. Although such classes are currently excluded from @@ -157,25 +157,25 @@ mv "$SPARK3XX_COMMON_DIR" parallel-world/ # Determine the list of unshimmed class files UNSHIMMED_LIST_TXT=unshimmed-result.txt echo "$((++STEP))/ creating sorted list of unshimmed classes > $UNSHIMMED_LIST_TXT" -find ./parallel-world -name '*.class' -not -path './parallel-world/spark3*' | \ +find ./parallel-world -name '*.class' -not -path './parallel-world/spark[34-]*' | \ cut -d/ -f 3- | sort > "$UNSHIMMED_LIST_TXT" function verify_same_sha_for_unshimmed() { set -e class_file="$1" - # the raw spark3xx-common.txt file list contains all single-sha1 classes + # the raw spark-shared.txt file list contains all single-sha1 classes # including the ones that are unshimmed. Instead of expensively recomputing # sha1 look up if there is an entry with the unshimmed class as a suffix class_file_quoted=$(printf '%q' "$class_file") - # TODO currently RapidsShuffleManager is "removed" from /spark3* by construction in + # TODO currently RapidsShuffleManager is "removed" from /spark* by construction in # dist pom.xml via ant. We could delegate this logic to this script # and make both simmpler - if [[ ! "$class_file_quoted" =~ (com/nvidia/spark/rapids/spark3.*/.*ShuffleManager.class|org/apache/spark/sql/rapids/shims/spark3.*/ProxyRapidsShuffleInternalManager.class) ]]; then + if [[ ! "$class_file_quoted" =~ (com/nvidia/spark/rapids/spark[34].*/.*ShuffleManager.class|org/apache/spark/sql/rapids/shims/spark[34].*/ProxyRapidsShuffleInternalManager.class) ]]; then - if ! grep -q "/spark.\+/$class_file_quoted" "$SPARK3XX_COMMON_TXT"; then + if ! grep -q "/spark.\+/$class_file_quoted" "$SPARK_SHARED_TXT"; then echo >&2 "$class_file is not bitwise-identical across shims" exit 255 fi @@ -192,7 +192,7 @@ done < "$UNSHIMMED_LIST_TXT" echo "$((++STEP))/ removing duplicates of unshimmed classes" while read unshimmed_class; do - for pw in ./parallel-world/spark3* ; do + for pw in ./parallel-world/spark[34]* ; do unshimmed_path="$pw/$unshimmed_class" [[ -f "$unshimmed_path" ]] && echo "$unshimmed_path" || true done >> "$DELETE_DUPLICATES_TXT" diff --git a/integration_tests/run_pyspark_from_build.sh b/integration_tests/run_pyspark_from_build.sh index dec93e6f22a..18c26aa26e7 100755 --- a/integration_tests/run_pyspark_from_build.sh +++ b/integration_tests/run_pyspark_from_build.sh @@ -171,11 +171,16 @@ else TEST_TYPE_PARAM="--test_type $TEST_TYPE" fi + # We found that when parallelism > 8, as it increases, the test speed will become slower and slower. So we set the default maximum parallelism to 8. + # Note that MAX_PARALLEL varies with the hardware, OS, and test case. Please overwrite it with an appropriate value if needed. + MAX_PARALLEL=${MAX_PARALLEL:-8} if [[ ${TEST_PARALLEL} -lt 2 ]]; then # With xdist 0 and 1 are the same parallelism but # 0 is more efficient TEST_PARALLEL_OPTS=() + elif [[ ${TEST_PARALLEL} -gt ${MAX_PARALLEL} ]]; then + TEST_PARALLEL_OPTS=("-n" "$MAX_PARALLEL") else TEST_PARALLEL_OPTS=("-n" "$TEST_PARALLEL") fi @@ -245,6 +250,12 @@ else DRIVER_EXTRA_JAVA_OPTIONS="-ea -Duser.timezone=$TZ -Ddelta.log.cacheSize=$deltaCacheSize" export PYSP_TEST_spark_driver_extraJavaOptions="$DRIVER_EXTRA_JAVA_OPTIONS $COVERAGE_SUBMIT_FLAGS" export PYSP_TEST_spark_executor_extraJavaOptions="-ea -Duser.timezone=$TZ" + + # Set driver memory to speed up tests such as deltalake + if [[ -n "${DRIVER_MEMORY}" ]]; then + export PYSP_TEST_spark_driver_memory="${DRIVER_MEMORY}" + fi + export PYSP_TEST_spark_ui_showConsoleProgress='false' export PYSP_TEST_spark_sql_session_timeZone=$TZ export PYSP_TEST_spark_sql_shuffle_partitions='4' diff --git a/integration_tests/src/main/python/data_gen.py b/integration_tests/src/main/python/data_gen.py index 2e6c36b77d9..fb1627af75b 100644 --- a/integration_tests/src/main/python/data_gen.py +++ b/integration_tests/src/main/python/data_gen.py @@ -159,7 +159,8 @@ def __repr__(self): return super().__repr__() + '(' + str(self._child_gen) + ')' def _cache_repr(self): - return super()._cache_repr() + '(' + self._child_gen._cache_repr() + ')' + return (super()._cache_repr() + '(' + self._child_gen._cache_repr() + + ',' + str(self._func.__code__) + ')' ) def start(self, rand): self._child_gen.start(rand) @@ -667,7 +668,10 @@ def __repr__(self): return super().__repr__() + '(' + str(self._child_gen) + ')' def _cache_repr(self): - return super()._cache_repr() + '(' + self._child_gen._cache_repr() + ')' + return (super()._cache_repr() + '(' + self._child_gen._cache_repr() + + ',' + str(self._min_length) + ',' + str(self._max_length) + ',' + + str(self.all_null) + ',' + str(self.convert_to_tuple) + ')') + def start(self, rand): self._child_gen.start(rand) @@ -701,7 +705,8 @@ def __repr__(self): return super().__repr__() + '(' + str(self._key_gen) + ',' + str(self._value_gen) + ')' def _cache_repr(self): - return super()._cache_repr() + '(' + self._key_gen._cache_repr() + ',' + self._value_gen._cache_repr() + ')' + return (super()._cache_repr() + '(' + self._key_gen._cache_repr() + ',' + self._value_gen._cache_repr() + + ',' + str(self._min_length) + ',' + str(self._max_length) + ')') def start(self, rand): self._key_gen.start(rand) @@ -769,12 +774,13 @@ def __init__(self, min_value=MIN_DAY_TIME_INTERVAL, max_value=MAX_DAY_TIME_INTER self._min_micros = (math.floor(min_value.total_seconds()) * 1000000) + min_value.microseconds self._max_micros = (math.floor(max_value.total_seconds()) * 1000000) + max_value.microseconds fields = ["day", "hour", "minute", "second"] - start_index = fields.index(start_field) - end_index = fields.index(end_field) - if start_index > end_index: + self._start_index = fields.index(start_field) + self._end_index = fields.index(end_field) + if self._start_index > self._end_index: raise RuntimeError('Start field {}, end field {}, valid fields is {}, start field index should <= end ' 'field index'.format(start_field, end_field, fields)) - super().__init__(DayTimeIntervalType(start_index, end_index), nullable=nullable, special_cases=special_cases) + super().__init__(DayTimeIntervalType(self._start_index, self._end_index), nullable=nullable, + special_cases=special_cases) def _gen_random(self, rand): micros = rand.randint(self._min_micros, self._max_micros) @@ -784,7 +790,8 @@ def _gen_random(self, rand): return timedelta(microseconds=micros) def _cache_repr(self): - return super()._cache_repr() + '(' + str(self._min_micros) + ',' + str(self._max_micros) + ')' + return (super()._cache_repr() + '(' + str(self._min_micros) + ',' + str(self._max_micros) + + ',' + str(self._start_index) + ',' + str(self._end_index) + ')') def start(self, rand): self._start(rand, lambda: self._gen_random(rand)) diff --git a/integration_tests/src/main/python/parquet_write_test.py b/integration_tests/src/main/python/parquet_write_test.py index 99a2d4241e8..38dab9e84a4 100644 --- a/integration_tests/src/main/python/parquet_write_test.py +++ b/integration_tests/src/main/python/parquet_write_test.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2023, NVIDIA CORPORATION. +# Copyright (c) 2020-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -224,6 +224,10 @@ def test_all_null_int96(spark_tmp_path): class AllNullTimestampGen(TimestampGen): def start(self, rand): self._start(rand, lambda : None) + + def _cache_repr(self): + return super()._cache_repr() + '(all_nulls)' + data_path = spark_tmp_path + '/PARQUET_DATA' confs = copy_and_update(writer_confs, {'spark.sql.parquet.outputTimestampType': 'INT96'}) assert_gpu_and_cpu_writes_are_equal_collect( diff --git a/jenkins/databricks/install_deps.py b/jenkins/databricks/install_deps.py index be5cb9bc040..8d21a4f9556 100644 --- a/jenkins/databricks/install_deps.py +++ b/jenkins/databricks/install_deps.py @@ -115,8 +115,10 @@ def define_deps(spark_version, scala_version): f'{prefix_ws_sp_mvn_hadoop}--org.json4s--json4s-jackson_{scala_version}--org.json4s__json4s-jackson_{scala_version}__*.jar'), Artifact('org.javaassist', 'javaassist', f'{prefix_ws_sp_mvn_hadoop}--org.javassist--javassist--org.javassist__javassist__*.jar'), - Artifact('com.fasterxml.jackson.core', 'jackson-core', + Artifact('com.fasterxml.jackson.core', 'jackson-databind', f'{prefix_ws_sp_mvn_hadoop}--com.fasterxml.jackson.core--jackson-databind--com.fasterxml.jackson.core__jackson-databind__*.jar'), + Artifact('com.fasterxml.jackson.core', 'jackson-core', + f'{prefix_ws_sp_mvn_hadoop}--com.fasterxml.jackson.core--jackson-core--com.fasterxml.jackson.core__jackson-core__*.jar'), Artifact('com.fasterxml.jackson.core', 'jackson-annotations', f'{prefix_ws_sp_mvn_hadoop}--com.fasterxml.jackson.core--jackson-annotations--com.fasterxml.jackson.core__jackson-annotations__*.jar'), Artifact('org.apache.spark', f'spark-avro_{scala_version}', diff --git a/jenkins/databricks/test.sh b/jenkins/databricks/test.sh index f71f69844f7..c966d5a92f7 100755 --- a/jenkins/databricks/test.sh +++ b/jenkins/databricks/test.sh @@ -66,9 +66,6 @@ TEST_MODE=${TEST_MODE:-'DEFAULT'} # --packages in distributed setups, should be fixed by # https://github.com/NVIDIA/spark-rapids/pull/5646 -# Increase driver memory as Delta Lake tests can slowdown with default 1G (possibly due to caching?) -DELTA_LAKE_CONFS="--driver-memory 2g" - # Enable event log for qualification & profiling tools testing export PYSP_TEST_spark_eventLog_enabled=true mkdir -p /tmp/spark-events @@ -105,7 +102,7 @@ if [[ "$(pwd)" == "$SOURCE_PATH" ]]; then if [[ "$TEST_MODE" == "DEFAULT" || $TEST_MODE == "CI_PART2" || "$TEST_MODE" == "DELTA_LAKE_ONLY" ]]; then ## Run Delta Lake tests - SPARK_SUBMIT_FLAGS="$SPARK_CONF $DELTA_LAKE_CONFS" TEST_PARALLEL=1 \ + DRIVER_MEMORY="4g" \ bash integration_tests/run_pyspark_from_build.sh --runtime_env="databricks" -m "delta_lake" --delta_lake --test_type=$TEST_TYPE fi diff --git a/jenkins/spark-premerge-build.sh b/jenkins/spark-premerge-build.sh index 883b3f3acfc..697722c0138 100755 --- a/jenkins/spark-premerge-build.sh +++ b/jenkins/spark-premerge-build.sh @@ -78,7 +78,7 @@ mvn_verify() { # Here run Python integration tests tagged with 'premerge_ci_1' only, that would help balance test duration and memory # consumption from two k8s pods running in parallel, which executes 'mvn_verify()' and 'ci_2()' respectively. $MVN_CMD -B $MVN_URM_MIRROR $PREMERGE_PROFILES clean verify -Dpytest.TEST_TAGS="premerge_ci_1" \ - -Dpytest.TEST_TYPE="pre-commit" -Dpytest.TEST_PARALLEL=4 -Dcuda.version=$CLASSIFIER + -Dpytest.TEST_TYPE="pre-commit" -Dcuda.version=$CLASSIFIER # The jacoco coverage should have been collected, but because of how the shade plugin # works and jacoco we need to clean some things up so jacoco will only report for the @@ -162,7 +162,6 @@ ci_2() { $MVN_CMD -U -B $MVN_URM_MIRROR clean package $MVN_BUILD_ARGS -DskipTests=true export TEST_TAGS="not premerge_ci_1" export TEST_TYPE="pre-commit" - export TEST_PARALLEL=5 # Download a Scala 2.12 build of spark prepare_spark $SPARK_VER 2.12 @@ -206,7 +205,6 @@ ci_scala213() { cd .. # Run integration tests in the project root dir to leverage test cases and resource files export TEST_TAGS="not premerge_ci_1" export TEST_TYPE="pre-commit" - export TEST_PARALLEL=5 # SPARK_HOME (and related) must be set to a Spark built with Scala 2.13 SPARK_HOME=$SPARK_HOME PYTHONPATH=$PYTHONPATH \ ./integration_tests/run_pyspark_from_build.sh diff --git a/pom.xml b/pom.xml index 38f7e8c3812..06947857521 100644 --- a/pom.xml +++ b/pom.xml @@ -886,6 +886,7 @@ 340, 341, 342, + 343, 350, 351 diff --git a/scala2.13/pom.xml b/scala2.13/pom.xml index a6d7de172fe..cbc4aecbd26 100644 --- a/scala2.13/pom.xml +++ b/scala2.13/pom.xml @@ -886,6 +886,7 @@ 340, 341, 342, + 343, 350, 351 diff --git a/scala2.13/shim-deps/databricks/pom.xml b/scala2.13/shim-deps/databricks/pom.xml index b342f381c71..a0459901079 100644 --- a/scala2.13/shim-deps/databricks/pom.xml +++ b/scala2.13/shim-deps/databricks/pom.xml @@ -105,6 +105,12 @@ ${spark.version} compile + + com.fasterxml.jackson.core + jackson-databind + ${spark.version} + compile + com.fasterxml.jackson.core jackson-annotations @@ -286,4 +292,4 @@ compile - \ No newline at end of file + diff --git a/shim-deps/databricks/pom.xml b/shim-deps/databricks/pom.xml index bef8a90d227..22842b0f7c0 100644 --- a/shim-deps/databricks/pom.xml +++ b/shim-deps/databricks/pom.xml @@ -105,6 +105,12 @@ ${spark.version} compile + + com.fasterxml.jackson.core + jackson-databind + ${spark.version} + compile + com.fasterxml.jackson.core jackson-annotations @@ -286,4 +292,4 @@ compile - \ No newline at end of file + diff --git a/sql-plugin-api/src/main/scala/com/nvidia/spark/rapids/ShimLoader.scala b/sql-plugin-api/src/main/scala/com/nvidia/spark/rapids/ShimLoader.scala index 36abc75ba87..2d7a51c4e43 100644 --- a/sql-plugin-api/src/main/scala/com/nvidia/spark/rapids/ShimLoader.scala +++ b/sql-plugin-api/src/main/scala/com/nvidia/spark/rapids/ShimLoader.scala @@ -40,19 +40,19 @@ import org.apache.spark.util.MutableURLClassLoader "parallel worlds" in the JDK's com.sun.istack.internal.tools.ParallelWorldClassLoader parlance 1. a few publicly documented classes in the conventional layout at the top 2. a large fraction of classes whose bytecode is identical under all supported Spark versions - in spark3xx-common + in spark-shared 3. a smaller fraction of classes that differ under one of the supported Spark versions com/nvidia/spark/SQLPlugin.class - spark3xx-common/com/nvidia/spark/rapids/CastExprMeta.class + spark-shared/com/nvidia/spark/rapids/CastExprMeta.class spark311/org/apache/spark/sql/rapids/GpuUnaryMinus.class spark320/org/apache/spark/sql/rapids/GpuUnaryMinus.class Each shim can see a consistent parallel world without conflicts by referencing only one conflicting directory. E.g., Spark 3.2.0 Shim will use - jar:file:/home/spark/rapids-4-spark_2.12-24.08.0.jar!/spark3xx-common/ + jar:file:/home/spark/rapids-4-spark_2.12-24.08.0.jar!/spark-shared/ jar:file:/home/spark/rapids-4-spark_2.12-24.08.0.jar!/spark320/ Spark 3.1.1 will use - jar:file:/home/spark/rapids-4-spark_2.12-24.08.0.jar!/spark3xx-common/ + jar:file:/home/spark/rapids-4-spark_2.12-24.08.0.jar!/spark-shared/ jar:file:/home/spark/rapids-4-spark_2.12-24.08.0.jar!/spark311/ Using these Jar URL's allows referencing different bytecode produced from identical sources by incompatible Scala / Spark dependencies. @@ -67,7 +67,7 @@ object ShimLoader extends Logging { new URL(rootUrlStr) } - private val shimCommonURL = new URL(s"${shimRootURL.toString}spark3xx-common/") + private val shimCommonURL = new URL(s"${shimRootURL.toString}spark-shared/") @volatile private var shimProviderClass: String = _ @volatile private var shimProvider: SparkShimServiceProvider = _ @volatile private var shimURL: URL = _ diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala index 7f0a82517c3..41c2e5e3776 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala @@ -49,8 +49,8 @@ case class GpuConcat(children: Seq[Expression]) extends GpuComplexTypeMergingExp override def columnarEval(batch: ColumnarBatch): GpuColumnVector = { val res = dataType match { - // Explicitly return null for empty concat as Spark, since cuDF doesn't support empty concat. - case dt if children.isEmpty => GpuScalar.from(null, dt) + // in Spark concat() will be considered as an empty string here + case dt if children.isEmpty => GpuScalar("", dt) // For single column concat, we pass the result of child node to avoid extra cuDF call. case _ if children.length == 1 => children.head.columnarEval(batch) case StringType => stringConcat(batch) diff --git a/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/RapidsTestSettings.scala b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/RapidsTestSettings.scala index ad93c4dd2e9..4cf155041d9 100644 --- a/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/RapidsTestSettings.scala +++ b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/RapidsTestSettings.scala @@ -72,7 +72,6 @@ class RapidsTestSettings extends BackendTestSettings { enableSuite[RapidsMathFunctionsSuite] enableSuite[RapidsRegexpExpressionsSuite] enableSuite[RapidsStringExpressionsSuite] - .exclude("concat", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10775")) .exclude("string substring_index function", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10775")) .exclude("SPARK-22498: Concat should not generate codes beyond 64KB", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10775")) .exclude("SPARK-22549: ConcatWs should not generate codes beyond 64KB", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10775"))