diff --git a/docs/supported_ops.md b/docs/supported_ops.md
index 953597c8016d..b201d6952829 100644
--- a/docs/supported_ops.md
+++ b/docs/supported_ops.md
@@ -17195,7 +17195,7 @@ are limited.
S |
NS |
NS |
-NS |
+PS UTC is only supported TZ for child TIMESTAMP; unsupported child types BINARY, CALENDAR, UDT |
|
PS UTC is only supported TZ for child TIMESTAMP; unsupported child types BINARY, CALENDAR, UDT |
NS |
@@ -17216,7 +17216,7 @@ are limited.
S |
NS |
NS |
-NS |
+PS UTC is only supported TZ for child TIMESTAMP; unsupported child types BINARY, CALENDAR, UDT |
|
PS UTC is only supported TZ for child TIMESTAMP; unsupported child types BINARY, CALENDAR, UDT |
NS |
@@ -17238,7 +17238,7 @@ are limited.
S |
NS |
NS |
-NS |
+PS UTC is only supported TZ for child TIMESTAMP; unsupported child types BINARY, CALENDAR, UDT |
|
PS UTC is only supported TZ for child TIMESTAMP; unsupported child types BINARY, CALENDAR, UDT |
NS |
@@ -17259,7 +17259,7 @@ are limited.
S |
NS |
NS |
-NS |
+PS UTC is only supported TZ for child TIMESTAMP; unsupported child types BINARY, CALENDAR, UDT |
|
PS UTC is only supported TZ for child TIMESTAMP; unsupported child types BINARY, CALENDAR, UDT |
NS |
@@ -17328,7 +17328,7 @@ are limited.
S |
NS |
NS |
-NS |
+PS UTC is only supported TZ for child TIMESTAMP; unsupported child types BINARY, CALENDAR, UDT |
|
PS UTC is only supported TZ for child TIMESTAMP; unsupported child types BINARY, CALENDAR, UDT |
NS |
@@ -17349,7 +17349,7 @@ are limited.
S |
NS |
NS |
-NS |
+PS UTC is only supported TZ for child TIMESTAMP; unsupported child types BINARY, CALENDAR, UDT |
|
PS UTC is only supported TZ for child TIMESTAMP; unsupported child types BINARY, CALENDAR, UDT |
NS |
@@ -17371,7 +17371,7 @@ are limited.
S |
NS |
NS |
-NS |
+PS UTC is only supported TZ for child TIMESTAMP; unsupported child types BINARY, CALENDAR, UDT |
|
PS UTC is only supported TZ for child TIMESTAMP; unsupported child types BINARY, CALENDAR, UDT |
NS |
@@ -17392,7 +17392,7 @@ are limited.
S |
NS |
NS |
-NS |
+PS UTC is only supported TZ for child TIMESTAMP; unsupported child types BINARY, CALENDAR, UDT |
|
PS UTC is only supported TZ for child TIMESTAMP; unsupported child types BINARY, CALENDAR, UDT |
NS |
diff --git a/integration_tests/DATA_GEN.md b/integration_tests/DATA_GEN.md
new file mode 100644
index 000000000000..3db42716df2b
--- /dev/null
+++ b/integration_tests/DATA_GEN.md
@@ -0,0 +1,444 @@
+# Big Data Generation
+
+In order to do scale testing we need a way to generate lots of data in a
+deterministic way that gives us control over the number of unique values
+in a column, the skew of the values in a column, and the correlation of
+data between tables for joins. To accomplish this we wrote
+`org.apache.spark.sql.tests.datagen`.
+
+## Setup Environment
+
+To get started with big data generation the first thing you need to do is
+to include the appropriate jar on the classpath for your version of Apache Spark.
+Note that this does not run on the GPU, but it does use
+parts of the same shim framework that the RAPIDS Accelerator uses, so it is currently in
+the integration tests jar for the RAPIDS Accelerator. The jar is specific to the
+version of Spark you are using and is not pushed to Maven Central. Because of this
+you will have to build it from source yourself.
+
+```shell
+cd integration_tests
+mvn clean package -Drat.skip -DskipTests -Dbuildver=$SPARK_VERSION
+```
+
+Where `$SPARK_VERSION` is a compressed version number, like 330 for Spark 3.3.0.
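+
+For example, to build against Spark 3.3.0 you would run:
+
+```shell
+mvn clean package -Drat.skip -DskipTests -Dbuildver=330
+```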
+
+If you are building with a JDK version other than 8, you will also need to add the
+corresponding profile flag `-P`.
+
+After this the jar should be at
+`target/rapids-4-spark-integration-tests_2.12-$PLUGIN_VERSION-spark$SPARK_VERSION.jar`.
+For example, a Spark 3.3.0 jar for the 23.08.0 release would be
+`target/rapids-4-spark-integration-tests_2.12-23.08.0-spark330.jar`.
+
+To get a Spark shell with this jar on the classpath you can run
+```shell
+spark-shell --jars target/rapids-4-spark-integration-tests_2.12-23.08.0-spark330.jar
+```
+
+After that you should be good to go.
+
+## Generate Some Data
+
+The first thing to do is to import the classes
+
+```scala
+import org.apache.spark.sql.tests.datagen._
+```
+
+After this the main entry point is `DBGen`. `DBGen` provides a way to generate
+multiple tables, but we are going to start off with just one. To do this we can
+call `addTable` on it.
+
+```scala
+val dataTable = DBGen().addTable("data", "a string, b byte", 5)
+dataTable.toDF(spark).show()
++----------+----+
+| a| b|
++----------+----+
+|t=qIHOf:O)| 47|
+|yXT-j", 3)
+dataTable("a").setLength(1)
+dataTable("b").setLength(2)
+dataTable("b")("data").setLength(3)
+dataTable.toDF(spark).show(false)
++---+----------+
+|a |b |
++---+----------+
+|t |[X]6, /= minRow) {
+ null
+ } else {
+ gen(rowLoc)
+ }
+ }
+}
+
+...
+
+dataTable("a").setNullGen(MyNullGen(1024, 9999999L))
+```
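+
+A custom null generator only needs to wrap another generator, keep the seed mapping it
+is given, and decide when to return a null instead of delegating. As a rough sketch (the
+`EveryNthNullGen` name and its every-nth rule are only illustrative, not part of the
+library):
+
+```scala
+case class EveryNthNullGen(n: Long,
+    gen: GeneratorFunction = null,
+    mapping: LocationToSeedMapping = null) extends NullGeneratorFunction {
+
+  override def withWrapped(gen: GeneratorFunction): EveryNthNullGen =
+    EveryNthNullGen(n, gen, mapping)
+
+  override def withLocationToSeedMapping(mapping: LocationToSeedMapping): EveryNthNullGen =
+    EveryNthNullGen(n, gen, mapping)
+
+  // Return null when the mapped seed lands on a multiple of n, otherwise
+  // delegate to the wrapped value generator.
+  override def apply(rowLoc: RowLocation): Any = {
+    if (mapping(rowLoc) % n == 0) {
+      null
+    } else {
+      gen(rowLoc)
+    }
+  }
+}
+
+dataTable("a").setNullGen(EveryNthNullGen(10L))
+```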
+
+Similarly, if you need to generate JSON formatted strings that follow a given
+pattern, you can do that. If you want a distribution where one very specific seed
+shows up 99% of the time and the rest of the time it falls back to the regular
+`FlatDistribution`, you can do that too. The framework is designed to be very
+flexible.
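+
+As a rough sketch of that skewed mapping idea (the `MostlyOneSeed` name and the exact
+99% split are illustrative, not part of the library), a custom `LocationToSeedMapping`
+can pin most rows to a single seed and fall back to a regular `FlatDistribution` for the
+rest:
+
+```scala
+case class MostlyOneSeed(hotSeed: Long,
+    flat: LocationToSeedMapping = FlatDistribution()) extends LocationToSeedMapping {
+
+  override def withColumnConf(colConf: ColumnConf): MostlyOneSeed =
+    MostlyOneSeed(hotSeed, flat.withColumnConf(colConf))
+
+  // About 99% of the rows map to the hot seed, the rest fall back to FlatDistribution.
+  override def apply(rowLoc: RowLocation): Long = {
+    val r = DataGen.getRandomFor(flat(rowLoc))
+    if (r.nextDouble() < 0.99) {
+      hotSeed
+    } else {
+      flat(rowLoc)
+    }
+  }
+}
+
+dataTable("a").setSeedMapping(MostlyOneSeed(42L))
+```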
\ No newline at end of file
diff --git a/integration_tests/pom.xml b/integration_tests/pom.xml
index 8b742a020afe..9cea03a0fd15 100644
--- a/integration_tests/pom.xml
+++ b/integration_tests/pom.xml
@@ -228,12 +228,24 @@
            <version>${spark.version}</version>
            <scope>provided</scope>
        </dependency>
+        <dependency>
+            <groupId>org.apache.spark</groupId>
+            <artifactId>spark-unsafe_${scala.binary.version}</artifactId>
+            <version>${spark.version}</version>
+            <scope>provided</scope>
+        </dependency>
        <dependency>
            <groupId>org.scala-lang</groupId>
            <artifactId>scala-reflect</artifactId>
            <version>${scala.version}</version>
            <scope>provided</scope>
        </dependency>
+        <dependency>
+            <groupId>com.esotericsoftware.kryo</groupId>
+            <artifactId>kryo-shaded-db</artifactId>
+            <version>${spark.version}</version>
+            <scope>provided</scope>
+        </dependency>
        <dependency>
            <groupId>org.apache.arrow</groupId>
            <artifactId>arrow-format</artifactId>
diff --git a/integration_tests/src/main/python/hash_aggregate_test.py b/integration_tests/src/main/python/hash_aggregate_test.py
index fd7ff4a4b494..ac96f02d01d2 100644
--- a/integration_tests/src/main/python/hash_aggregate_test.py
+++ b/integration_tests/src/main/python/hash_aggregate_test.py
@@ -688,7 +688,7 @@ def test_hash_groupby_collect_set(data_gen):
@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', _gen_data_for_collect_set_op, ids=idfn)
-@pytest.mark.xfail(condition=is_before_spark_330(), reason='https://github.com/NVIDIA/spark-rapids/issues/8716')
+@pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/8716')
def test_hash_groupby_collect_set_on_nested_type(data_gen):
assert_gpu_and_cpu_are_equal_collect(
lambda spark: gen_df(spark, data_gen, length=100)
@@ -731,7 +731,7 @@ def test_hash_reduction_collect_set(data_gen):
@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', _gen_data_for_collect_set_op, ids=idfn)
-@pytest.mark.xfail(condition=is_before_spark_330(), reason='https://github.com/NVIDIA/spark-rapids/issues/8716')
+@pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/8716')
def test_hash_reduction_collect_set_on_nested_type(data_gen):
assert_gpu_and_cpu_are_equal_collect(
lambda spark: gen_df(spark, data_gen, length=100)
@@ -1869,7 +1869,7 @@ def test_std_variance_partial_replace_fallback(data_gen,
conf=local_conf)
#
-# test min/max aggregations for structs
+# Test min/max aggregations, both in group-by (with integer keys) and in reduction, over simple and nested type values.
#
gens_for_max_min = [byte_gen, short_gen, int_gen, long_gen,
float_gen, double_gen,
@@ -1877,16 +1877,11 @@ def test_std_variance_partial_replace_fallback(data_gen,
date_gen, timestamp_gen,
DecimalGen(precision=12, scale=2),
DecimalGen(precision=36, scale=5),
- null_gen] + array_gens_sample # + struct_gens_sample
-# Nested structs have issues, https://github.com/NVIDIA/spark-rapids/issues/8702
+ null_gen] + array_gens_sample + struct_gens_sample
@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', gens_for_max_min, ids=idfn)
-def test_min_max_for_struct(data_gen):
- df_gen = [
- ('a', StructGen([
- ('aa', data_gen),
- ('ab', data_gen)])),
- ('b', RepeatSeqGen(IntegerGen(), length=20))]
+def test_min_max_in_groupby_and_reduction(data_gen):
+ df_gen = [('a', data_gen), ('b', RepeatSeqGen(IntegerGen(), length=20))]
# test max
assert_gpu_and_cpu_are_equal_sql(
diff --git a/integration_tests/src/main/python/parquet_write_test.py b/integration_tests/src/main/python/parquet_write_test.py
index 8c426aa1a921..159e77d13306 100644
--- a/integration_tests/src/main/python/parquet_write_test.py
+++ b/integration_tests/src/main/python/parquet_write_test.py
@@ -730,20 +730,44 @@ def write_partitions(spark, table_path):
conf={}
)
-# Test to avoid regression on a known bug in Spark. For details please visit https://github.com/NVIDIA/spark-rapids/issues/8693
-def test_hive_timestamp_value(spark_tmp_table_factory, spark_tmp_path):
+def hive_timestamp_value(spark_tmp_table_factory, spark_tmp_path, ts_rebase, func):
+ conf={'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': ts_rebase,
+ 'spark.sql.legacy.parquet.int96RebaseModeInWrite': ts_rebase}
def create_table(spark, path):
tmp_table = spark_tmp_table_factory.get()
spark.sql(f"CREATE TABLE {tmp_table} STORED AS PARQUET " +
- f""" LOCATION '{path}' AS SELECT CAST('2015-01-01 00:00:00' AS TIMESTAMP) as t; """)
+ f""" LOCATION '{path}' AS SELECT CAST('2015-01-01 00:00:00' AS TIMESTAMP) as t; """)
def read_table(spark, path):
return spark.read.parquet(path)
data_path = spark_tmp_path + '/PARQUET_DATA'
- assert_gpu_and_cpu_writes_are_equal_collect(create_table, read_table, data_path)
- assert_gpu_and_cpu_are_equal_collect(lambda spark: spark.read.parquet(data_path + '/CPU'))
+
+ func(create_table, read_table, data_path, conf)
+
+# Test to avoid regression on a known bug in Spark. For details please visit https://github.com/NVIDIA/spark-rapids/issues/8693
+def test_hive_timestamp_value(spark_tmp_table_factory, spark_tmp_path):
+
+ def func_test(create_table, read_table, data_path, conf):
+ assert_gpu_and_cpu_writes_are_equal_collect(create_table, read_table, data_path, conf=conf)
+ assert_gpu_and_cpu_are_equal_collect(lambda spark: spark.read.parquet(data_path + '/CPU'))
+
+ hive_timestamp_value(spark_tmp_table_factory, spark_tmp_path, 'CORRECTED', func_test)
+
+# Test to avoid regression on a known bug in Spark. For details please visit https://github.com/NVIDIA/spark-rapids/issues/8693
+@allow_non_gpu('DataWritingCommandExec', 'WriteFilesExec')
+def test_hive_timestamp_value_fallback(spark_tmp_table_factory, spark_tmp_path):
+
+ def func_test(create_table, read_table, data_path, conf):
+ assert_gpu_fallback_write(
+ create_table,
+ read_table,
+ data_path,
+ ['DataWritingCommandExec'],
+ conf)
+
+ hive_timestamp_value(spark_tmp_table_factory, spark_tmp_path, 'LEGACY', func_test)
@ignore_order
@pytest.mark.skipif(is_before_spark_340(), reason="`spark.sql.optimizer.plannedWrite.enabled` is only supported in Spark 340+")
diff --git a/integration_tests/src/main/scala/org/apache/spark/sql/tests/datagen/bigDataGen.scala b/integration_tests/src/main/scala/org/apache/spark/sql/tests/datagen/bigDataGen.scala
new file mode 100644
index 000000000000..27b47e7900c2
--- /dev/null
+++ b/integration_tests/src/main/scala/org/apache/spark/sql/tests/datagen/bigDataGen.scala
@@ -0,0 +1,1974 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.tests.datagen
+
+import java.math.{BigDecimal => JavaBigDecimal}
+import java.sql.{Date, Timestamp}
+import java.time.{Instant, LocalDate}
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.math.BigDecimal.RoundingMode
+import scala.util.Random
+
+import org.apache.spark.sql.{Column, DataFrame, SparkSession}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Expression, XXH64}
+import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils}
+import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.util.random.XORShiftRandom
+
+/**
+ * Holds a representation of the current row location.
+ * @param rowNum the top level row number (starting at 0 but this might change)
+ * @param subRows a path to the child rows in an array starting at 0 for each array.
+ */
+class RowLocation(val rowNum: Long, val subRows: Array[Int] = null) {
+ def withNewChild(): RowLocation = {
+ val newSubRows = if (subRows == null) {
+ new Array[Int](1)
+ } else {
+ val tmp = new Array[Int](subRows.length + 1)
+ subRows.copyToArray(tmp, 0, subRows.length)
+ tmp
+ }
+ new RowLocation(rowNum, newSubRows)
+ }
+
+ def setLastChildIndex(index: Int): Unit =
+ subRows(subRows.length - 1) = index
+
+ /**
+ * Hash the location into a single long value.
+ */
+ def hashLoc(seed: Long): Long = {
+ var tmp = XXH64.hashLong(rowNum, seed)
+ if (subRows != null) {
+ var i = 0
+ while (i < subRows.length) {
+ tmp = XXH64.hashLong(subRows(i), tmp)
+ i += 1
+ }
+ }
+ tmp
+ }
+}
+
+/**
+ * Holds a representation of the current column's location. This is computed by walking
+ * the tree. The walk is depth first, and each node is assigned an ID in the order it is
+ * visited, starting at 0. If this is part of a correlated key group then the
+ * correlatedKeyGroup will be set and the columnNum will be ignored when computing the
+ * hash. This makes the generated data correlated for the column and all of its child
+ * columns.
+ * @param tableNum a unique ID for the table this is a part of.
+ * @param columnNum the location of the column in the data being generated
+ * @param correlatedKeyGroup the correlated key group this column is a part of, if any.
+ */
+case class ColumnLocation(tableNum: Int, columnNum: Int, correlatedKeyGroup: Option[Long] = None) {
+ def forNextColumn(): ColumnLocation = ColumnLocation(tableNum, columnNum + 1)
+
+
+ /**
+ * Create a new ColumnLocation that is specifically for a given key group
+ */
+ def forCorrelatedKeyGroup(keyGroup: Long): ColumnLocation =
+ ColumnLocation(tableNum, columnNum, Some(keyGroup))
+
+ /**
+ * Hash the location into a single long value.
+ */
+ lazy val hashLoc: Long = XXH64.hashLong(tableNum, correlatedKeyGroup.getOrElse(columnNum))
+}
+
+/**
+ * Holds configuration for a given column, or sub-column.
+ * @param columnLoc the location of the column
+ * @param nullable are nulls supported in the output or not.
+ * @param numTableRows the number of rows that this table is going to have, not necessarily the
+ *                     number of values that the column is going to have.
+ * @param minSeed the minimum seed value allowed to be returned
+ * @param maxSeed the maximum seed value allowed to be returned
+ */
+case class ColumnConf(columnLoc: ColumnLocation,
+ nullable: Boolean,
+ numTableRows: Long,
+ minSeed: Long = Long.MinValue,
+ maxSeed: Long = Long.MaxValue) {
+
+ def forNextColumn(nullable: Boolean): ColumnConf =
+ ColumnConf(columnLoc.forNextColumn(), nullable, numTableRows)
+
+ /**
+ * Create a new configuration based on this, but for a given correlated key group.
+ */
+ def forCorrelatedKeyGroup(correlatedKeyGroup: Long): ColumnConf = {
+ ColumnConf(columnLoc.forCorrelatedKeyGroup(correlatedKeyGroup),
+ nullable,
+ numTableRows,
+ minSeed,
+ maxSeed)
+ }
+
+ /**
+ * Create a new configuration based on this, but for a given seed range.
+ */
+ def forSeedRange(min: Long, max: Long): ColumnConf =
+ ColumnConf(columnLoc, nullable, numTableRows, min, max)
+
+ /**
+ * Create a new configuration based on this, but for the null generator.
+ */
+ def forNulls: ColumnConf = {
+ assert(nullable, "Should not get a conf for nulls from a non-nullable column conf")
+ ColumnConf(columnLoc, nullable, numTableRows)
+ }
+}
+
+/**
+ * Provides a mapping between a location + configuration and a seed. This provides a random
+ * looking mapping between a location and a seed that will be used to deterministically
+ * generate data. This is also where we can change the distribution and cardinality of data.
+ * A min/max seed allows for changes to the cardinality of values generated and the exact
+ * mapping done can inject skew or various types of data distributions.
+ */
+trait LocationToSeedMapping extends Serializable {
+ /**
+ * Set the config for the column that should be used along with a min/max type seed range.
+ * This will be called before apply is ever called.
+ */
+ def withColumnConf(colConf: ColumnConf): LocationToSeedMapping
+
+ /**
+ * Given a row location + the previously set config produce a seed to be used.
+ * @param rowLoc the row location
+ * @return the seed to use
+ */
+ def apply(rowLoc: RowLocation): Long
+}
+
+object LocationToSeedMapping {
+ /**
+ * Return a function that should remap the full range of long seed
+ * (Long.MinValue to Long.MaxValue) to a new range as described by the conf in a way that
+ * does not really impact the distribution. This is not 100% correct
+ * because if the min/max is not the full range the mapping cannot always guarantee that
+   * each option gets exactly the same probability of being produced.
+ * But it should be good enough.
+ */
+ def remapRangeFunc(colConf: ColumnConf): Long => Long = {
+ val minSeed = colConf.minSeed
+ val maxSeed = colConf.maxSeed
+ if (minSeed == Long.MinValue && maxSeed == Long.MaxValue) {
+ n => n
+ } else {
+ // We generate numbers between minSeed and maxSeed + 1, and then take the floor so that
+ // values on the ends have the same range as those in the middle
+ val scaleFactor = (BigDecimal(maxSeed) + 1 - minSeed) /
+ (BigDecimal(Long.MaxValue) - Long.MinValue)
+ val minBig = BigDecimal(Long.MinValue)
+ n => {
+ (((n - minBig) * scaleFactor) + minSeed).setScale(0, RoundingMode.FLOOR).toLong
+ }
+ }
+ }
+}
+
+/**
+ * The idea is that we have an address (`RowLocation`, `ColumnLocation`)
+ * that will uniquely identify the location of a piece of data to be generated.
+ * A `GeneratorFunction` should distinctly map this address + configs to a
+ * generated value. The column location is computed from the schema + configs
+ * and is folded into the `LocationToSeedMapping` (via `withColumnConf`) along with
+ * other configs before the function is used to generate any data. The row location
+ * is created as needed and could be reused in between calls to `apply`. Implementations
+ * should keep no state in between calls to `apply`.
+ *
+ * But we also want to have control over various aspects of how the data
+ * is generated. We want
+ * 1. The data to appear random (pseudo random)
+ * 2. The number of unique values (cardinality)
+ * 3. The distribution of those values (skew/etc)
+ *
+ * To accomplish this we also provide a location to seed mapping through the
+ * `withLocationToSeedMapping` API. This too is guaranteed to be called before
+ * `apply` is called.
+ */
+trait GeneratorFunction extends Serializable {
+ /**
+ * Set a location mapping function. This will be called before apply is ever called.
+ */
+ def withLocationToSeedMapping(mapping: LocationToSeedMapping): GeneratorFunction
+
+ /**
+ * Set a new value range. This will be called if a value range is set by the user before
+ * apply is called. If this generator does not support a value range, then an exception
+ * should be thrown.
+ */
+ def withValueRange(min: Any, max: Any): GeneratorFunction
+
+ /**
+ * Set a LengthGeneratorFunction for this generator. Not all types need a length so this
+ * should only be overridden if it is needed. This will be called before apply is called.
+ */
+ def withLengthGeneratorFunction(lengthGen: LengthGeneratorFunction): GeneratorFunction = {
+ this
+ }
+
+ /**
+ * Actually generate the data. The data should be based off of the mapping and the row location
+ * passed in. The output data needs to correspond to something that Spark will be able
+ * to understand properly as the datatype this generator is for. Because we use an expression
+ * to call this, it needs to be one of the "internal" types that Spark expects.
+ */
+ def apply(rowLoc: RowLocation): Any
+}
+
+/**
+ * We want a way to be able to insert nulls as needed. We don't treat nulls
+ * the same as other values and they are generally off by default. This
+ * becomes a wrapper around another GeneratorFunction and decides
+ * if a null should be returned before calling into the other function or not.
+ * This will only be put into places where the schema allows it to be, and
+ * where the user has configured it to be.
+ */
+trait NullGeneratorFunction extends GeneratorFunction {
+ /**
+ * Create a new NullGeneratorFunction based on this one, but wrapping the passed in
+ * generator. This will be called before apply is called.
+ */
+ def withWrapped(gen: GeneratorFunction): NullGeneratorFunction
+
+ override def withValueRange(min: Any, max: Any): GeneratorFunction =
+ throw new IllegalArgumentException("Null Generators should not have a value range set for them")
+
+ /**
+   * Create a new NullGeneratorFunction based on this one, but with the given mapping.
+ * This is guaranteed to be called before apply is called.
+ */
+ override def withLocationToSeedMapping(mapping: LocationToSeedMapping): NullGeneratorFunction
+}
+
+/**
+ * Used to generate the length of a String, Array, Map, or any other variable length data.
+ * This is treated as a separate concern from the data generation because it is hard
+ * not to mess up the cardinality with a naive approach to generating a length. If you just
+ * use something like Random.nextInt(maxLen) then you can get skew because each length
+ * has a different maximum possible cardinality. For example, if you want
+ * lengths between 0 and 1, then half of the rows generated will be length 0 and half
+ * will be length 1, so half of the data will be a single value.
+ */
+trait LengthGeneratorFunction {
+ def withLocationToSeedMapping(mapping: LocationToSeedMapping): LengthGeneratorFunction
+ def apply(rowLoc: RowLocation): Int
+}
+
+/**
+ * Just generate all of the data with a single length. This is the simplest way to avoid
+ * skew because of the different possible cardinality for the different lengths.
+ */
+case class FixedLengthGeneratorFunction(length: Int) extends LengthGeneratorFunction {
+ override def withLocationToSeedMapping(mapping: LocationToSeedMapping): LengthGeneratorFunction =
+ this
+
+ override def apply(rowLoc: RowLocation): Int = length
+}
+
+/**
+ * Generate nulls with a given probability.
+ * @param prob 0.0 to 1.0 for how often nulls should appear in the output.
+ */
+case class NullProbabilityGenerationFunction(prob: Double,
+ gen: GeneratorFunction = null,
+ mapping: LocationToSeedMapping = null) extends NullGeneratorFunction {
+
+ override def withWrapped(gen: GeneratorFunction): NullProbabilityGenerationFunction =
+ NullProbabilityGenerationFunction(prob, gen, mapping)
+
+ override def withLocationToSeedMapping(
+ mapping: LocationToSeedMapping): NullProbabilityGenerationFunction =
+ NullProbabilityGenerationFunction(prob, gen, mapping)
+
+ override def apply(rowLoc: RowLocation): Any = {
+ val r = DataGen.getRandomFor(rowLoc, mapping)
+ if (r.nextDouble() <= prob) {
+ null
+ } else {
+ gen(rowLoc)
+ }
+ }
+}
+
+/**
+ * A LocationToSeedMapping that generates seeds with an approximately equal chance for
+ * all values.
+ */
+case class FlatDistribution(colLocSeed: Long = 0L,
+ remapRangeFunc: Long => Long = n => n) extends LocationToSeedMapping {
+
+ override def withColumnConf(colConf: ColumnConf): FlatDistribution = {
+ val colLocSeed = colConf.columnLoc.hashLoc
+ val remapRangeFunc = LocationToSeedMapping.remapRangeFunc(colConf)
+ FlatDistribution(colLocSeed, remapRangeFunc)
+ }
+
+ override def apply(rowLoc: RowLocation): Long =
+ remapRangeFunc(rowLoc.hashLoc(colLocSeed))
+}
+
+/**
+ * A LocationToSeedMapping that generates a unique seed per row. The order in which the values are
+ * generated is somewhat random looking. This should *NOT* be applied to any
+ * columns that are under an array because it uses rowNum and assumes that this maps correctly
+ * to the number of rows being generated.
+ */
+case class DistinctDistribution(numTableRows: Long = Long.MaxValue,
+ columnLocSeed: Long = 0L,
+ minSeed: Long = Long.MinValue,
+ maxSeed: Long = Long.MaxValue) extends LocationToSeedMapping {
+
+ override def withColumnConf(colConf: ColumnConf): DistinctDistribution = {
+ val numTableRows = colConf.numTableRows
+ val maskLen = 64 - java.lang.Long.numberOfLeadingZeros(numTableRows)
+    val maxMask = (1L << maskLen) - 1
+ val columnLocSeed = colConf.columnLoc.hashLoc & maxMask
+ val minSeed = colConf.minSeed
+ val maxSeed = colConf.maxSeed
+ DistinctDistribution(numTableRows, columnLocSeed, minSeed, maxSeed)
+ }
+
+ override def apply(rowLoc: RowLocation): Long = {
+ assert(rowLoc.subRows == null,
+ "DistinctDistribution cannot be applied to columns under an Array")
+ val range = BigInt(maxSeed) - minSeed
+ val modded = BigInt(rowLoc.rowNum).mod(range)
+ val rowSeed = (modded + minSeed).toLong
+ val ret = rowSeed ^ columnLocSeed
+ if (ret > numTableRows) {
+ rowSeed
+ } else {
+ ret
+ }
+ }
+}
+
+object DataGen {
+ val rLocal = new ThreadLocal[Random] {
+ override def initialValue(): Random = new XORShiftRandom()
+ }
+
+ /**
+ * Get a Random instance that is based off of a given seed. This should not be kept in
+   * between calls to apply, and should not be used after calling a child method's apply.
+ */
+ def getRandomFor(locSeed: Long): Random = {
+ val r = rLocal.get()
+ r.setSeed(locSeed)
+ r
+ }
+
+ /**
+ * Get a Random instance that is based off of the given location and mapping.
+ */
+ def getRandomFor(rowLoc: RowLocation, mapping: LocationToSeedMapping): Random =
+ getRandomFor(mapping(rowLoc))
+
+ /**
+ * Get a long value that is mapped from the given location and mapping.
+ */
+ def nextLong(rowLoc: RowLocation, mapping: LocationToSeedMapping): Long =
+ getRandomFor(rowLoc, mapping).nextLong()
+
+ /**
+ * Produce a remapping function that remaps a full range of long values to a new range evenly.
+ */
+ def remapRangeFunc(typeMinVal: Long, typeMaxVal: Long): Long => Long = {
+ if (typeMinVal == Long.MinValue && typeMaxVal == Long.MaxValue) {
+ n => n
+ } else {
+ // We generate numbers between typeMinVal and typeMaxVal + 1, and then take the floor so that
+ // all numbers get an even chance.
+ val scaleFactor = (BigDecimal(typeMaxVal) + 1 - typeMinVal) /
+ (BigDecimal(Long.MaxValue) - Long.MinValue)
+ val minBig = BigDecimal(Long.MinValue)
+ n => {
+ (((n - minBig) * scaleFactor) + typeMinVal).setScale(0, RoundingMode.FLOOR).toLong
+ }
+ }
+ }
+}
+
+/**
+ * An expression that actually does the data generation based on an input row number.
+ */
+case class DataGenExpr(child: Expression,
+ override val dataType: DataType,
+ canHaveNulls: Boolean,
+ f: GeneratorFunction) extends DataGenExprBase {
+
+ override def nullable: Boolean = canHaveNulls
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(LongType)
+
+ override def eval(input: InternalRow): Any = {
+ val rowLoc = new RowLocation(child.eval(input).asInstanceOf[Long])
+ f(rowLoc)
+ }
+}
+
+/**
+ * Base class for generating a column/sub-column. This holds configuration for the column,
+ * and handles what is needed to convert it into a GeneratorFunction.
+ */
+abstract class DataGen(var conf: ColumnConf,
+ defaultValueRange: Option[(Any, Any)],
+ var seedMapping: LocationToSeedMapping = FlatDistribution(),
+ var nullMapping: LocationToSeedMapping = FlatDistribution(),
+ var lengthGen: LengthGeneratorFunction = FixedLengthGeneratorFunction(10)) {
+ protected var userProvidedValueGen: Option[GeneratorFunction] = None
+ protected var userProvidedNullGen: Option[NullGeneratorFunction] = None
+ protected var valueRange: Option[(Any, Any)] = defaultValueRange
+
+ /**
+ * Set a value range for this data gen.
+ */
+ def setValueRange(min: Any, max: Any): DataGen = {
+ valueRange = Some((min, max))
+ this
+ }
+
+ /**
+ * Set a custom GeneratorFunction to use for this column.
+ */
+ def setValueGen(f: GeneratorFunction): DataGen = {
+ userProvidedValueGen = Some(f)
+ this
+ }
+
+ /**
+ * Set a NullGeneratorFunction for this column. This will not be used
+ * if the column is not nullable.
+ */
+ def setNullGen(f: NullGeneratorFunction): DataGen = {
+ this.userProvidedNullGen = Some(f)
+ this
+ }
+
+ /**
+ * Set the probability of a null appearing in the output. The probability should be
+ * 0.0 to 1.0.
+ */
+ def setNullProbability(probability: Double): DataGen = {
+ this.userProvidedNullGen = Some(NullProbabilityGenerationFunction(probability))
+ this
+ }
+
+ /**
+ * Set a specific location to seed mapping for the value generation.
+ */
+ def setSeedMapping(seedMapping: LocationToSeedMapping): DataGen = {
+ this.seedMapping = seedMapping
+ this
+ }
+
+ /**
+ * Set a specific location to seed mapping for the null generation.
+ */
+ def setNullMapping(nullMapping: LocationToSeedMapping): DataGen = {
+ this.nullMapping = nullMapping
+ this
+ }
+
+ /**
+ * Set a specific LengthGeneratorFunction to use. This will only be used if
+ * the datatype needs a length.
+ */
+ def setLengthGen(lengthGen: LengthGeneratorFunction): DataGen = {
+ this.lengthGen = lengthGen
+ this
+ }
+
+ /**
+ * Set the length generation to be a fixed length.
+ */
+ def setLength(len: Int): DataGen = {
+ this.lengthGen = FixedLengthGeneratorFunction(len)
+ this
+ }
+
+ /**
+ * Add this column to a specific correlated key group. This should not be
+ * called directly by users.
+ */
+ def setCorrelatedKeyGroup(keyGroup: Long,
+ minSeed: Long, maxSeed: Long,
+ seedMapping: LocationToSeedMapping): DataGen = {
+ conf = conf.forCorrelatedKeyGroup(keyGroup)
+ .forSeedRange(minSeed, maxSeed)
+ this.seedMapping = seedMapping
+ this
+ }
+
+ /**
+ * Set a range of seed values that should be returned by the LocationToSeedMapping
+ */
+ def setSeedRange(min: Long, max: Long): DataGen = {
+ conf = conf.forSeedRange(min, max)
+ this
+ }
+
+ /**
+ * Get the default value generator for this specific data gen.
+ */
+ protected def getValGen: GeneratorFunction
+
+ /**
+ * Get the final ready to use GeneratorFunction for the data generator.
+ */
+ def getGen: GeneratorFunction = {
+ val sm = seedMapping.withColumnConf(conf)
+ val lg = lengthGen.withLocationToSeedMapping(sm)
+ var valGen = userProvidedValueGen.getOrElse(getValGen)
+ .withLocationToSeedMapping(sm)
+ .withLengthGeneratorFunction(lg)
+ valueRange.foreach {
+ case (min, max) =>
+ valGen = valGen.withValueRange(min, max)
+ }
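+    // If the column is nullable and a null generator was requested, wrap the value
+    // generator so the null generator, driven by its own seed mapping, decides when to
+    // emit a null instead of a value.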
+ if (nullable && userProvidedNullGen.isDefined) {
+ val nullColConf = conf.forNulls
+ val nm = nullMapping.withColumnConf(nullColConf)
+ userProvidedNullGen.get
+ .withWrapped(valGen)
+ .withLocationToSeedMapping(nm)
+ } else {
+ valGen
+ }
+ }
+
+ /**
+ * Get the data type for this column
+ */
+ def dataType: DataType
+
+ /**
+ * Is this column nullable or not.
+ */
+ def nullable: Boolean = conf.nullable
+
+ /**
+ * Get a child column for a given name, if it has one.
+ */
+ final def apply(name: String): DataGen = {
+ get(name).getOrElse{
+ throw new IllegalStateException(s"Could not find a child $name for $this")
+ }
+ }
+
+ def get(name: String): Option[DataGen] = None
+}
+
+/**
+ * A Value generator for a single value
+ * @param v the value to return
+ */
+case class SingleValGenFunc(v: Any) extends GeneratorFunction {
+ override def withLocationToSeedMapping(mapping: LocationToSeedMapping): SingleValGenFunc = {
+ this
+ }
+ override def apply(rowLoc: RowLocation): Any = v
+
+ override def withValueRange(min: Any, max: Any): GeneratorFunction = {
+ if (min != max && min != v) {
+ throw new IllegalArgumentException(s"The only value supported for this range is $v")
+ }
+ this
+ }
+}
+
+/**
+ * Pick a value out of the list of possible values and return that.
+ */
+case class EnumValGenFunc(values: Array[_ <: Any],
+ mapping: LocationToSeedMapping = null) extends GeneratorFunction {
+ require(values != null && values.length > 0)
+
+ override def withLocationToSeedMapping(mapping: LocationToSeedMapping): EnumValGenFunc = {
+ EnumValGenFunc(values, mapping)
+ }
+
+ override def withValueRange(min: Any, max: Any): GeneratorFunction =
+    throw new IllegalArgumentException(s"value ranges are not supported for EnumValGenFunc")
+
+ override def apply(rowLoc: RowLocation): Any = {
+ val r = DataGen.getRandomFor(rowLoc, mapping)
+ values(r.nextInt(values.length))
+ }
+}
+
+/**
+ * A value generator for booleans.
+ */
+case class BooleanGenFunc(mapping: LocationToSeedMapping = null) extends GeneratorFunction {
+ override def withLocationToSeedMapping(mapping: LocationToSeedMapping): BooleanGenFunc =
+ BooleanGenFunc(mapping)
+
+ override def apply(rowLoc: RowLocation): Any =
+ DataGen.getRandomFor(rowLoc, mapping).nextBoolean()
+
+ override def withValueRange(min: Any, max: Any): GeneratorFunction = {
+ val minb = BooleanGen.asBoolean(min)
+ val maxb = BooleanGen.asBoolean(max)
+ if (minb == maxb) {
+ SingleValGenFunc(minb)
+ } else {
+ this
+ }
+ }
+}
+
+object BooleanGen {
+ def asBoolean(a: Any): Boolean = a match {
+ case b: Boolean => b
+ case other =>
+ throw new IllegalArgumentException(s"a boolean value range only supports " +
+ s"boolean values found $other")
+ }
+}
+
+class BooleanGen(conf: ColumnConf,
+ defaultValueRange: Option[(Any, Any)])
+ extends DataGen(conf, defaultValueRange) {
+
+ override def dataType: DataType = BooleanType
+
+ override protected def getValGen: GeneratorFunction = BooleanGenFunc()
+}
+
+/**
+ * A value generator for Bytes
+ */
+case class ByteGenFunc(mapping: LocationToSeedMapping = null,
+ min: Byte = Byte.MinValue,
+ max: Byte = Byte.MaxValue) extends GeneratorFunction {
+
+ private lazy val valueRemapping: Long => Byte = ByteGen.remapRangeFunc(min, max)
+
+ override def apply(rowLoc: RowLocation): Any = valueRemapping(DataGen.nextLong(rowLoc, mapping))
+
+ override def withLocationToSeedMapping(mapping: LocationToSeedMapping): ByteGenFunc =
+ ByteGenFunc(mapping, min, max)
+
+ override def withValueRange(min: Any, max: Any): GeneratorFunction = {
+ var bmin = ByteGen.asByte(min)
+ var bmax = ByteGen.asByte(max)
+ if (bmin > bmax) {
+ val tmp = bmin
+ bmin = bmax
+ bmax = tmp
+ }
+
+ if (bmin == bmax) {
+ SingleValGenFunc(bmin)
+ } else {
+ ByteGenFunc(mapping, bmin, bmax)
+ }
+ }
+}
+
+object ByteGen {
+ def asByte(a: Any): Byte = a match {
+ case n: Byte => n
+ case n: Short if n == n.toByte => n.toByte
+ case n: Int if n == n.toByte => n.toByte
+ case n: Long if n == n.toByte => n.toByte
+ case other =>
+ throw new IllegalArgumentException(s"a byte value range only supports " +
+ s"byte values found $other")
+ }
+
+ def remapRangeFunc(min: Byte, max: Byte): Long => Byte = {
+ val wrapped = DataGen.remapRangeFunc(min, max)
+ n => wrapped(n).toByte
+ }
+}
+
+class ByteGen(conf: ColumnConf,
+ defaultValueRange: Option[(Any, Any)])
+ extends DataGen(conf, defaultValueRange) {
+ override def getValGen: GeneratorFunction = ByteGenFunc()
+ override def dataType: DataType = ByteType
+}
+
+/**
+ * A value generator for Shorts
+ */
+case class ShortGenFunc(mapping: LocationToSeedMapping = null,
+ min: Short = Short.MinValue,
+ max: Short = Short.MaxValue) extends GeneratorFunction {
+
+ private lazy val valueRemapping: Long => Short =
+ ShortGen.remapRangeFunc(min, max)
+
+ override def apply(rowLoc: RowLocation): Any = valueRemapping(DataGen.nextLong(rowLoc, mapping))
+
+ override def withLocationToSeedMapping(mapping: LocationToSeedMapping): ShortGenFunc =
+ ShortGenFunc(mapping, min, max)
+
+ override def withValueRange(min: Any, max: Any): GeneratorFunction = {
+ var smin = ShortGen.asShort(min)
+ var smax = ShortGen.asShort(max)
+ if (smin > smax) {
+ val tmp = smin
+ smin = smax
+ smax = tmp
+ }
+
+ if (smin == smax) {
+ SingleValGenFunc(smin)
+ } else {
+ ShortGenFunc(mapping, smin, smax)
+ }
+ }
+}
+
+object ShortGen {
+ def asShort(a: Any): Short = a match {
+ case n: Byte => n.toShort
+ case n: Short => n
+ case n: Int if n == n.toShort => n.toShort
+ case n: Long if n == n.toShort => n.toShort
+ case other =>
+ throw new IllegalArgumentException(s"a short value range only supports " +
+ s"short values found $other")
+ }
+
+ def remapRangeFunc(min: Short, max: Short): Long => Short = {
+ val wrapped = DataGen.remapRangeFunc(min, max)
+ n => wrapped(n).toShort
+ }
+}
+
+class ShortGen(conf: ColumnConf,
+ defaultValueRange: Option[(Any, Any)])
+ extends DataGen(conf, defaultValueRange) {
+ override def getValGen: GeneratorFunction = ShortGenFunc()
+
+ override def dataType: DataType = ShortType
+}
+
+/**
+ * A value generator for Ints
+ */
+case class IntGenFunc(mapping: LocationToSeedMapping = null,
+ min: Int = Int.MinValue,
+ max: Int = Int.MaxValue) extends GeneratorFunction {
+
+ private lazy val valueRemapping: Long => Int = IntGen.remapRangeFunc(min, max)
+
+ override def apply(rowLoc: RowLocation): Any = valueRemapping(DataGen.nextLong(rowLoc, mapping))
+
+ override def withLocationToSeedMapping(mapping: LocationToSeedMapping): IntGenFunc =
+ IntGenFunc(mapping, min, max)
+
+ override def withValueRange(min: Any, max: Any): GeneratorFunction = {
+ var imin = IntGen.asInt(min)
+ var imax = IntGen.asInt(max)
+ if (imin > imax) {
+ val tmp = imin
+ imin = imax
+ imax = tmp
+ }
+
+ if (imin == imax) {
+ SingleValGenFunc(imin)
+ } else {
+ IntGenFunc(mapping, imin, imax)
+ }
+ }
+}
+
+object IntGen {
+ def asInt(a: Any): Int = a match {
+ case n: Byte => n.toInt
+ case n: Short => n.toInt
+ case n: Int => n
+ case n: Long if n == n.toInt => n.toInt
+ case other =>
+ throw new IllegalArgumentException(s"a int value range only supports " +
+ s"int values found $other")
+ }
+
+ def remapRangeFunc(min: Int, max: Int): Long => Int = {
+ val wrapped = DataGen.remapRangeFunc(min, max)
+ n => wrapped(n).toInt
+ }
+}
+
+class IntGen(conf: ColumnConf,
+ defaultValueRange: Option[(Any, Any)])
+ extends DataGen(conf, defaultValueRange) {
+ override def getValGen: GeneratorFunction = IntGenFunc()
+
+ override def dataType: DataType = IntegerType
+}
+
+/**
+ * A value generator for Longs
+ */
+case class LongGenFunc(mapping: LocationToSeedMapping = null,
+ min: Long = Long.MinValue,
+ max: Long = Long.MaxValue) extends GeneratorFunction {
+
+ protected lazy val valueRemapping: Long => Long = LongGen.remapRangeFunc(min, max)
+
+ override def apply(rowLoc: RowLocation): Any = valueRemapping(DataGen.nextLong(rowLoc, mapping))
+
+ override def withLocationToSeedMapping(mapping: LocationToSeedMapping): LongGenFunc =
+ LongGenFunc(mapping, min, max)
+
+ override def withValueRange(min: Any, max: Any): GeneratorFunction = {
+ var lmin = LongGen.asLong(min)
+ var lmax = LongGen.asLong(max)
+ if (lmin > lmax) {
+ val tmp = lmin
+ lmin = lmax
+ lmax = tmp
+ }
+
+ if (lmin == lmax) {
+ SingleValGenFunc(lmin)
+ } else {
+ LongGenFunc(mapping, lmin, lmax)
+ }
+ }
+}
+
+object LongGen {
+ def asLong(a: Any): Long = a match {
+ case n: Byte => n.toLong
+ case n: Short => n.toLong
+ case n: Int => n.toLong
+ case n: Long => n
+ case other =>
+ throw new IllegalArgumentException(s"a long value range only supports " +
+ s"long values found $other")
+ }
+
+ def remapRangeFunc(min: Long, max: Long): Long => Long =
+ DataGen.remapRangeFunc(min, max)
+}
+
+class LongGen(conf: ColumnConf,
+ defaultValueRange: Option[(Any, Any)])
+ extends DataGen(conf, defaultValueRange) {
+ override def getValGen: GeneratorFunction = LongGenFunc()
+
+ override def dataType: DataType = LongType
+}
+
+case class Decimal32GenFunc(
+ precision: Int,
+ scale: Int,
+ unscaledMin: Int,
+ unscaledMax: Int,
+ mapping: LocationToSeedMapping = null) extends GeneratorFunction {
+
+ private lazy val valueRemapping: Long => Int = IntGen.remapRangeFunc(unscaledMin, unscaledMax)
+
+ override def apply(rowLoc: RowLocation): Any =
+ Decimal(valueRemapping(DataGen.nextLong(rowLoc, mapping)), precision, scale)
+
+ override def withLocationToSeedMapping(mapping: LocationToSeedMapping): GeneratorFunction =
+ Decimal32GenFunc(precision, scale, unscaledMin, unscaledMax, mapping)
+
+ override def withValueRange(min: Any, max: Any): GeneratorFunction = {
+ var imin = DecimalGen.asUnscaledInt(min, precision, scale)
+ var imax = DecimalGen.asUnscaledInt(max, precision, scale)
+ if (imin > imax) {
+ val tmp = imin
+ imin = imax
+ imax = tmp
+ }
+
+ if (imin == imax) {
+ SingleValGenFunc(Decimal(imin, precision, scale))
+ } else {
+ Decimal32GenFunc(precision, scale, imin, imax, mapping)
+ }
+ }
+}
+
+case class Decimal64GenFunc(
+ precision: Int,
+ scale: Int,
+ unscaledMin: Long,
+ unscaledMax: Long,
+ mapping: LocationToSeedMapping = null) extends GeneratorFunction {
+
+ private lazy val valueRemapping: Long => Long = LongGen.remapRangeFunc(unscaledMin, unscaledMax)
+
+ override def apply(rowLoc: RowLocation): Any =
+ Decimal(valueRemapping(DataGen.nextLong(rowLoc, mapping)), precision, scale)
+
+ override def withLocationToSeedMapping(mapping: LocationToSeedMapping): GeneratorFunction =
+ Decimal64GenFunc(precision, scale, unscaledMin, unscaledMax, mapping)
+
+ override def withValueRange(min: Any, max: Any): GeneratorFunction = {
+ var lmin = DecimalGen.asUnscaledLong(min, precision, scale)
+ var lmax = DecimalGen.asUnscaledLong(max, precision, scale)
+ if (lmin > lmax) {
+ val tmp = lmin
+ lmin = lmax
+ lmax = tmp
+ }
+
+ if (lmin == lmax) {
+ SingleValGenFunc(Decimal(lmin, precision, scale))
+ } else {
+ Decimal64GenFunc(precision, scale, lmin, lmax, mapping)
+ }
+ }
+}
+
+case class DecimalGenFunc(
+ precision: Int,
+ scale: Int,
+ unscaledMin: BigInt,
+ unscaledMax: BigInt,
+ mapping: LocationToSeedMapping = null) extends GeneratorFunction {
+
+ private lazy val valueRemapping: Long => BigInt =
+ DecimalGen.remapRangeFunc(unscaledMin, unscaledMax)
+
+ override def apply(rowLoc: RowLocation): Any = {
+ val bi = valueRemapping(DataGen.nextLong(rowLoc, mapping)).bigInteger
+ Decimal(new JavaBigDecimal(bi, scale))
+ }
+
+ override def withLocationToSeedMapping(mapping: LocationToSeedMapping): GeneratorFunction =
+ DecimalGenFunc(precision, scale, unscaledMin, unscaledMax, mapping)
+
+ override def withValueRange(min: Any, max: Any): GeneratorFunction = {
+ var lmin = DecimalGen.asUnscaled(min, precision, scale)
+ var lmax = DecimalGen.asUnscaled(max, precision, scale)
+ if (lmin > lmax) {
+ val tmp = lmin
+ lmin = lmax
+ lmax = tmp
+ }
+
+ if (lmin == lmax) {
+ SingleValGenFunc(Decimal(new JavaBigDecimal(lmin.bigInteger, scale)))
+ } else {
+ DecimalGenFunc(precision, scale, lmin, lmax, mapping)
+ }
+ }
+}
+
+object DecimalGen {
+ def toUnscaledInt(d: JavaBigDecimal, precision: Int, scale: Int): Int = {
+ val tmp = d.setScale(scale)
+ require(tmp.precision() <= precision, "The value is not in the supported precision range")
+ tmp.unscaledValue().intValueExact()
+ }
+
+ def asUnscaledInt(a: Any, precision: Int, scale: Int): Int = a match {
+ case d: BigDecimal =>
+ val tmp = d.setScale(scale)
+ toUnscaledInt(tmp.bigDecimal, precision, scale)
+ case d: JavaBigDecimal =>
+ toUnscaledInt(d, precision, scale)
+ case d: Decimal =>
+ toUnscaledInt(d.toJavaBigDecimal, precision, scale)
+ case n: Byte =>
+ toUnscaledInt(new JavaBigDecimal(n), precision, scale)
+ case n: Short =>
+ toUnscaledInt(new JavaBigDecimal(n), precision, scale)
+ case n: Int =>
+ toUnscaledInt(new JavaBigDecimal(n), precision, scale)
+ case n: Long =>
+ toUnscaledInt(new JavaBigDecimal(n), precision, scale)
+ case n: Float =>
+ toUnscaledInt(new JavaBigDecimal(n), precision, scale)
+ case n: Double =>
+ toUnscaledInt(new JavaBigDecimal(n), precision, scale)
+ case other =>
+ throw new IllegalArgumentException(s"a decimal 32 value range only supports " +
+ s"decimal values found $other")
+ }
+
+ def toUnscaledLong(d: JavaBigDecimal, precision: Int, scale: Int): Long = {
+ val tmp = d.setScale(scale)
+ require(tmp.precision() <= precision, "The value is not in the supported precision range")
+ tmp.unscaledValue().longValueExact()
+ }
+
+ def asUnscaledLong(a: Any, precision: Int, scale: Int): Long = a match {
+ case d: BigDecimal =>
+ val tmp = d.setScale(scale)
+ toUnscaledLong(tmp.bigDecimal, precision, scale)
+ case d: JavaBigDecimal =>
+ toUnscaledLong(d, precision, scale)
+ case d: Decimal =>
+ toUnscaledLong(d.toJavaBigDecimal, precision, scale)
+ case n: Byte =>
+ toUnscaledLong(new JavaBigDecimal(n), precision, scale)
+ case n: Short =>
+ toUnscaledLong(new JavaBigDecimal(n), precision, scale)
+ case n: Int =>
+ toUnscaledLong(new JavaBigDecimal(n), precision, scale)
+ case n: Long =>
+ toUnscaledLong(new JavaBigDecimal(n), precision, scale)
+ case n: Float =>
+ toUnscaledLong(new JavaBigDecimal(n), precision, scale)
+ case n: Double =>
+ toUnscaledLong(new JavaBigDecimal(n), precision, scale)
+ case other =>
+ throw new IllegalArgumentException(s"a decimal 64 value range only supports " +
+ s"decimal values found $other")
+ }
+
+ def toUnscaled(d: JavaBigDecimal, precision: Int, scale: Int): BigInt = {
+ val tmp = d.setScale(scale)
+ require(tmp.precision() <= precision, "The value is not in the supported precision range")
+ tmp.unscaledValue()
+ }
+
+ def asUnscaled(a: Any, precision: Int, scale: Int): BigInt = a match {
+ case d: BigDecimal =>
+ val tmp = d.setScale(scale)
+ toUnscaled(tmp.bigDecimal, precision, scale)
+ case d: JavaBigDecimal =>
+ toUnscaled(d, precision, scale)
+ case d: Decimal =>
+ toUnscaled(d.toJavaBigDecimal, precision, scale)
+ case n: Byte =>
+ toUnscaled(new JavaBigDecimal(n), precision, scale)
+ case n: Short =>
+ toUnscaled(new JavaBigDecimal(n), precision, scale)
+ case n: Int =>
+ toUnscaled(new JavaBigDecimal(n), precision, scale)
+ case n: Long =>
+ toUnscaled(new JavaBigDecimal(n), precision, scale)
+ case n: Float =>
+ toUnscaled(new JavaBigDecimal(n), precision, scale)
+ case n: Double =>
+ toUnscaled(new JavaBigDecimal(n), precision, scale)
+ case other =>
+ throw new IllegalArgumentException(s"a decimal value range only supports " +
+ s"decimal values found $other")
+ }
+
+ def genMaxUnscaledInt(precision: Int): Int = {
+ require(precision >= 0 && precision <= Decimal.MAX_INT_DIGITS)
+ genMaxUnscaled(precision).toInt
+ }
+
+ def genMaxUnscaledLong(precision: Int): Long = {
+ require(precision > Decimal.MAX_INT_DIGITS && precision <= Decimal.MAX_LONG_DIGITS)
+ genMaxUnscaled(precision).toLong
+ }
+
+ def genMaxUnscaled(precision: Int): BigInt =
+ BigInt(10).pow(precision) - 1
+
+ def remapRangeFunc(minVal: BigInt, maxVal: BigInt): Long => BigInt = {
+ // We generate numbers between minVal and maxVal + 1, and then take the floor so that
+ // all numbers get an even chance.
+ val minValBD = BigDecimal(minVal)
+ val scaleFactor = (BigDecimal(maxVal) + 1 - minValBD) /
+ (BigDecimal(Long.MaxValue) - Long.MinValue)
+ val minLong = BigDecimal(Long.MinValue)
+ n => {
+ (((n - minLong) * scaleFactor) + minValBD).setScale(0, RoundingMode.FLOOR).toBigInt()
+ }
+ }
+}
+
+class DecimalGen(dt: DecimalType,
+ conf: ColumnConf,
+ defaultValueRange: Option[(Any, Any)])
+ extends DataGen(conf, defaultValueRange) {
+
+ override def dataType: DataType = dt
+
+ override protected def getValGen: GeneratorFunction =
+ if (dt.precision <= Decimal.MAX_INT_DIGITS) {
+ val max = DecimalGen.genMaxUnscaledInt(dt.precision)
+      Decimal32GenFunc(dt.precision, dt.scale, -max, max)
+ } else if (dt.precision <= Decimal.MAX_LONG_DIGITS) {
+ val max = DecimalGen.genMaxUnscaledLong(dt.precision)
+      Decimal64GenFunc(dt.precision, dt.scale, -max, max)
+ } else {
+ val max = DecimalGen.genMaxUnscaled(dt.precision)
+ DecimalGenFunc(dt.precision, dt.scale, -max, max)
+ }
+}
+
+/**
+ * A value generator for Timestamps
+ */
+case class TimestampGenFunc(mapping: LocationToSeedMapping = null,
+ min: Long = Long.MinValue,
+ max: Long = Long.MaxValue) extends GeneratorFunction {
+
+ private lazy val valueRemapping: Long => Long = LongGen.remapRangeFunc(min, max)
+
+ override def apply(rowLoc: RowLocation): Any = valueRemapping(DataGen.nextLong(rowLoc, mapping))
+
+ override def withLocationToSeedMapping(mapping: LocationToSeedMapping): TimestampGenFunc =
+ TimestampGenFunc(mapping, min, max)
+
+ override def withValueRange(min: Any, max: Any): GeneratorFunction = {
+ var lmin = TimestampGen.asLong(min)
+ var lmax = TimestampGen.asLong(max)
+ if (lmin > lmax) {
+ val tmp = lmin
+ lmin = lmax
+ lmax = tmp
+ }
+
+ if (lmin == lmax) {
+ SingleValGenFunc(lmin)
+ } else {
+ TimestampGenFunc(mapping, lmin, lmax)
+ }
+ }
+}
+
+object TimestampGen {
+ def asLong(a: Any): Long = a match {
+ case n: Byte => n.toLong
+ case n: Short => n.toLong
+ case n: Int => n.toLong
+ case n: Long => n
+ case i: Instant => DateTimeUtils.instantToMicros(i)
+ case t: Timestamp => DateTimeUtils.fromJavaTimestamp(t)
+ case other =>
+ throw new IllegalArgumentException(s"a timestamp value range only supports " +
+ s"timestamp or integral values found $other")
+ }
+}
+
+class TimestampGen(conf: ColumnConf,
+ defaultValueRange: Option[(Any, Any)])
+ extends DataGen(conf, defaultValueRange) {
+ override protected def getValGen: GeneratorFunction = TimestampGenFunc()
+
+ override def dataType: DataType = TimestampType
+}
+
+/**
+ * A value generator for Dates
+ */
+case class DateGenFunc(mapping: LocationToSeedMapping = null,
+ min: Int = Int.MinValue,
+ max: Int = Int.MaxValue) extends GeneratorFunction {
+
+ private lazy val valueRemapping: Long => Int = IntGen.remapRangeFunc(min, max)
+
+ override def apply(rowLoc: RowLocation): Any = valueRemapping(DataGen.nextLong(rowLoc, mapping))
+
+ override def withLocationToSeedMapping(mapping: LocationToSeedMapping): DateGenFunc =
+ DateGenFunc(mapping, min, max)
+
+ override def withValueRange(min: Any, max: Any): GeneratorFunction = {
+ var imin = DateGen.asInt(min)
+ var imax = DateGen.asInt(max)
+ if (imin > imax) {
+ val tmp = imin
+ imin = imax
+ imax = tmp
+ }
+
+ if (imin == imax) {
+ SingleValGenFunc(imin)
+ } else {
+ DateGenFunc(mapping, imin, imax)
+ }
+ }
+}
+
+object DateGen {
+ def asInt(a: Any): Int = a match {
+ case n: Byte => n.toInt
+ case n: Short => n.toInt
+ case n: Int => n
+ case n: Long if n <= Int.MaxValue && n >= Int.MinValue => n.toInt
+ case ld: LocalDate => ld.toEpochDay.toInt
+ case d: Date => DateTimeUtils.fromJavaDate(d)
+ case other =>
+ throw new IllegalArgumentException(s"a date value range only supports " +
+ s"date or int values found $other")
+ }
+}
+
+class DateGen(conf: ColumnConf,
+ defaultValueRange: Option[(Any, Any)])
+ extends DataGen(conf, defaultValueRange) {
+ override protected def getValGen: GeneratorFunction = DateGenFunc()
+
+ override def dataType: DataType = DateType
+}
+
+/**
+ * A value generator for Doubles
+ */
+case class DoubleGenFunc(mapping: LocationToSeedMapping = null) extends GeneratorFunction {
+ override def apply(rowLoc: RowLocation): Any =
+ java.lang.Double.longBitsToDouble(DataGen.nextLong(rowLoc, mapping))
+
+ override def withLocationToSeedMapping(mapping: LocationToSeedMapping): DoubleGenFunc =
+ DoubleGenFunc(mapping)
+
+ override def withValueRange(min: Any, max: Any): GeneratorFunction =
+ throw new IllegalStateException("value ranges are not supported for Double yet")
+}
+
+class DoubleGen(conf: ColumnConf, defaultValueRange: Option[(Any, Any)])
+ extends DataGen(conf, defaultValueRange) {
+
+ override def dataType: DataType = DoubleType
+
+ override protected def getValGen: GeneratorFunction = DoubleGenFunc()
+}
+
+/**
+ * A value generator for Floats
+ */
+case class FloatGenFunc(mapping: LocationToSeedMapping = null) extends GeneratorFunction {
+ override def apply(rowLoc: RowLocation): Any =
+ java.lang.Float.intBitsToFloat(DataGen.getRandomFor(mapping(rowLoc)).nextInt())
+
+ override def withLocationToSeedMapping(mapping: LocationToSeedMapping): FloatGenFunc =
+ FloatGenFunc(mapping)
+
+ override def withValueRange(min: Any, max: Any): GeneratorFunction =
+ throw new IllegalStateException("value ranges are not supported for Float yet")
+}
+
+class FloatGen(conf: ColumnConf, defaultValueRange: Option[(Any, Any)])
+ extends DataGen(conf, defaultValueRange) {
+
+ override def dataType: DataType = FloatType
+
+ override protected def getValGen: GeneratorFunction = FloatGenFunc()
+}
+
+case class ASCIIGenFunc(
+ lengthGen: LengthGeneratorFunction = null,
+ mapping: LocationToSeedMapping = null) extends GeneratorFunction {
+
+ override def apply(rowLoc: RowLocation): Any = {
+ val len = lengthGen(rowLoc)
+ val r = DataGen.getRandomFor(rowLoc, mapping)
+ val buffer = new Array[Byte](len)
+ var at = 0
+ while (at < len) {
+ // Value range is 32 (Space) to 126 (~)
+ buffer(at) = (r.nextInt(126 - 31) + 32).toByte
+ at += 1
+ }
+ UTF8String.fromBytes(buffer, 0, len)
+ }
+
+ override def withLengthGeneratorFunction(lengthGen: LengthGeneratorFunction): GeneratorFunction =
+ ASCIIGenFunc(lengthGen, mapping)
+
+ override def withLocationToSeedMapping(mapping: LocationToSeedMapping): GeneratorFunction =
+ ASCIIGenFunc(lengthGen, mapping)
+
+ override def withValueRange(min: Any, max: Any): GeneratorFunction =
+ throw new IllegalArgumentException("value ranges are not supported for strings")
+}
+
+class StringGen(conf: ColumnConf, defaultValueRange: Option[(Any, Any)])
+ extends DataGen(conf, defaultValueRange) {
+
+ override def dataType: DataType = StringType
+
+ override protected def getValGen: GeneratorFunction = ASCIIGenFunc()
+}
+
+case class StructGenFunc(childGens: Array[GeneratorFunction]) extends GeneratorFunction {
+ override def apply(rowLoc: RowLocation): Any = {
+ // The row location does not change for a struct
+ val data = childGens.map(_.apply(rowLoc))
+ InternalRow(data: _*)
+ }
+
+ override def withLocationToSeedMapping(mapping: LocationToSeedMapping): GeneratorFunction =
+ this
+
+ override def withValueRange(min: Any, max: Any): GeneratorFunction =
+ throw new IllegalArgumentException("value ranges are not supported by structs")
+}
+
+class StructGen(val children: Seq[(String, DataGen)],
+ conf: ColumnConf,
+ defaultValueRange: Option[(Any, Any)])
+ extends DataGen(conf, defaultValueRange) {
+
+ private lazy val dt = {
+ val childrenFields = children.map {
+ case (name, gen) =>
+ StructField(name, gen.dataType)
+ }
+ StructType(childrenFields)
+ }
+ override def dataType: DataType = dt
+
+ override def setCorrelatedKeyGroup(keyGroup: Long,
+ minSeed: Long, maxSeed: Long,
+ seedMapping: LocationToSeedMapping): DataGen = {
+ super.setCorrelatedKeyGroup(keyGroup, minSeed, maxSeed, seedMapping)
+ children.foreach {
+ case (_, gen) =>
+ gen.setCorrelatedKeyGroup(keyGroup, minSeed, maxSeed, seedMapping)
+ }
+ this
+ }
+
+ override def get(name: String): Option[DataGen] =
+ children.collectFirst {
+ case (childName, dataGen) if childName.equalsIgnoreCase(name) => dataGen
+ }
+
+ override protected def getValGen: GeneratorFunction = {
+ val childGens = children.map(c => c._2.getGen).toArray
+ StructGenFunc(childGens)
+ }
+}
+
+case class ArrayGenFunc(
+ child: GeneratorFunction,
+ lengthGen: LengthGeneratorFunction = null,
+ mapping: LocationToSeedMapping = null) extends GeneratorFunction {
+
+ override def apply(rowLoc: RowLocation): Any = {
+ val len = lengthGen(rowLoc)
+ val data = new Array[Any](len)
+ val childRowLoc = rowLoc.withNewChild()
+ var i = 0
+ while (i < len) {
+ childRowLoc.setLastChildIndex(i)
+ data(i) = child(childRowLoc)
+ i += 1
+ }
+ ArrayData.toArrayData(data)
+ }
+
+ override def withLengthGeneratorFunction(lengthGen: LengthGeneratorFunction): GeneratorFunction =
+ ArrayGenFunc(child, lengthGen, mapping)
+
+ override def withLocationToSeedMapping(mapping: LocationToSeedMapping): GeneratorFunction =
+ ArrayGenFunc(child, lengthGen, mapping)
+
+ override def withValueRange(min: Any, max: Any): GeneratorFunction =
+ throw new IllegalArgumentException("value ranges are not supported for arrays")
+}
+
+class ArrayGen(child: DataGen,
+ conf: ColumnConf,
+ defaultValueRange: Option[(Any, Any)])
+ extends DataGen(conf, defaultValueRange) {
+
+ override def setCorrelatedKeyGroup(keyGroup: Long,
+ minSeed: Long, maxSeed: Long,
+ seedMapping: LocationToSeedMapping): DataGen = {
+ super.setCorrelatedKeyGroup(keyGroup, minSeed, maxSeed, seedMapping)
+ child.setCorrelatedKeyGroup(keyGroup, minSeed, maxSeed, seedMapping)
+ this
+ }
+
+ override protected def getValGen: GeneratorFunction = ArrayGenFunc(child.getGen)
+
+ override def dataType: DataType = ArrayType(child.dataType, containsNull = child.nullable)
+
+ override def get(name: String): Option[DataGen] = {
+ if ("data".equalsIgnoreCase(name) || "child".equalsIgnoreCase(name)) {
+ Some(child)
+ } else {
+ None
+ }
+ }
+}
+
+object ColumnGen {
+ private def genInternal(rowNumber: Column,
+ dataType: DataType,
+ nullable: Boolean,
+ gen: GeneratorFunction): Column = {
+ Column(DataGenExpr(rowNumber.expr, dataType, nullable, gen))
+ }
+}
+
+/**
+ * Generates a top level column in a data set.
+ */
+class ColumnGen(val dataGen: DataGen) {
+ def setCorrelatedKeyGroup(kg: Long,
+ minSeed: Long, maxSeed: Long,
+ seedMapping: LocationToSeedMapping): ColumnGen = {
+ dataGen.setCorrelatedKeyGroup(kg, minSeed, maxSeed, seedMapping)
+ this
+ }
+
+ def setSeedRange(min: Long, max: Long): ColumnGen = {
+ dataGen.setSeedRange(min, max)
+ this
+ }
+
+ def setSeedMapping(seedMapping: LocationToSeedMapping): ColumnGen = {
+ dataGen.setSeedMapping(seedMapping)
+ this
+ }
+
+ def setNullSeedMapping(seedMapping: LocationToSeedMapping): ColumnGen = {
+ dataGen.setNullMapping(seedMapping)
+ this
+ }
+
+ def setValueRange(min: Any, max: Any): ColumnGen = {
+ dataGen.setValueRange(min, max)
+ this
+ }
+
+ def setNullProbability(probability: Double): ColumnGen = {
+ dataGen.setNullProbability(probability)
+ this
+ }
+
+ def setNullGen(f: NullGeneratorFunction): ColumnGen = {
+ dataGen.setNullGen(f)
+ this
+ }
+
+ def setValueGen(f: GeneratorFunction): ColumnGen = {
+ dataGen.setValueGen(f)
+ this
+ }
+
+ def setLengthGen(lengthGen: LengthGeneratorFunction): ColumnGen = {
+ dataGen.setLengthGen(lengthGen)
+ this
+ }
+
+ def setLength(len: Int): ColumnGen = {
+ dataGen.setLength(len)
+ this
+ }
+
+ final def apply(name: String): DataGen = {
+ get(name).getOrElse {
+ throw new IllegalArgumentException(s"$name not a child of $this")
+ }
+ }
+
+ def get(name: String): Option[DataGen] = dataGen.get(name)
+
+ def gen(rowNumber: Column): Column = {
+ ColumnGen.genInternal(rowNumber, dataGen.dataType, dataGen.nullable, dataGen.getGen)
+ }
+}
+
+sealed trait KeyGroupType
+
+/**
+ * A key group where all of the columns and sub-columns use the key group id instead of
+ * the column number when calculating the seed for the desired value. This makes all of
+ * the generator functions compute the exact same seed and produce correlated results.
+ * For this to work the columns that will be joined need to be configured with the same
+ * generators and the same value ranges (if any), but the seed distribution and the seed
+ * range do not have to match.
+ * @param id the id of the key group, which replaces the column number.
+ * @param minSeed the min seed value for this key group
+ * @param maxSeed the max seed value for this key group
+ */
+case class CorrelatedKeyGroup(id: Long, minSeed: Long, maxSeed: Long) extends KeyGroupType
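+
+// Illustrative sketch (not part of this patch): correlating join keys across two tables.
+// The dbgen instance, table and column names, and keyMapping are placeholders for
+// whatever LocationToSeedMapping is in use.
+//
+//   val keyMapping: LocationToSeedMapping = ??? // pick a seed distribution
+//   dbgen("fact_table").configureKeyGroup(Seq("cust_id"),
+//     CorrelatedKeyGroup(1, 0, 999999), keyMapping)
+//   dbgen("dim_table").configureKeyGroup(Seq("cust_id"),
+//     CorrelatedKeyGroup(1, 0, 99999), keyMapping)
+//
+// Both columns share key group id 1, so they generate overlapping key values; giving the
+// dimension table a smaller seed range controls how many fact rows find a match.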
+
+/**
+ * A key group where the seed range is calculated based on the number of columns to
+ * get approximately the desired number of unique combinations. This is not exact,
+ * and it does not work with nested types or when the cardinality of the columns
+ * involved is not large enough to support the desired key range. Much of this could
+ * be addressed in the future.
+ *
+ * If you want partially overlapping groups you can increase the number of
+ * combinations, but again this is approximate, so it might not do exactly what you want.
+ * @param startSeed the seed that all combinatorial groups should start with
+ * @param combinations the desired number of unique combinations.
+ */
+case class CombinatorialKeyGroup(startSeed: Long, combinations: Long) extends KeyGroupType
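+
+// Worked example (illustrative, not part of this patch): for a two column group configured
+// with CombinatorialKeyGroup(startSeed = 0, combinations = 100), configureKeyGroup gives
+// the first column 100^(1/2) = 10 seed choices (seed range [0, 9]), leaving 100 / 10 = 10
+// combinations for the second column, which gets the same seed range, for roughly
+// 10 * 10 = 100 unique key combinations. When the target does not split evenly the result
+// is only approximate.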
+
+/**
+ * Used to generate a table in a database
+ */
+class TableGen(val columns: Seq[(String, ColumnGen)], numRows: Long) {
+ /**
+ * A key group allows you to set up a multi-column key for things like a join where you
+ * want the keys to be generated as a group.
+ * @param names the names of the columns in the group. A column can only be a part of a
+ * single group.
+ * @param groupType the type of key group to configure
+ * @param mapping the distribution/mapping for this key group
+ * @return this for chaining
+ */
+ def configureKeyGroup(
+ names: Seq[String],
+ groupType: KeyGroupType,
+ mapping: LocationToSeedMapping): TableGen = {
+ groupType match {
+ case CorrelatedKeyGroup(id, minSeed, maxSeed) =>
+ names.foreach { name =>
+ val col = apply(name)
+ col.setCorrelatedKeyGroup(id, minSeed, maxSeed, mapping)
+ }
+ case CombinatorialKeyGroup(startSeed, combinations) =>
+ var numberOfCombinationsRemaining = combinations
+ var numberOfColumnsRemaining = names.length
+ names.foreach { name =>
+ val choices = math.pow(numberOfCombinationsRemaining.toDouble,
+ 1.0 / numberOfColumnsRemaining).toLong
+ if (choices <= 0) {
+ throw new IllegalArgumentException("Could not find a way to split up " +
+ s"${names.length} columns to produce $combinations unique combinations of values")
+ }
+ val column = apply(name)
+ column.setSeedRange(startSeed, startSeed + choices - 1)
+ numberOfColumnsRemaining -= 1
+ numberOfCombinationsRemaining = numberOfCombinationsRemaining / choices
+ }
+ }
+ this
+ }
+
+ /**
+ * Convert this table into a `DataFrame` that can be
+ * written out or used directly. Writing it out to parquet
+ * or ORC first is usually better, because otherwise the generation happens
+ * inline as the query runs, and it is not GPU accelerated, so performance
+ * can suffer.
+ * @param spark the session to use.
+ * @param numParts the number of parts to use (if > 0)
+ */
+ def toDF(spark: SparkSession, numParts: Int = 0): DataFrame = {
+ val id = col("id")
+ val allGens = columns.map {
+ case (name, childGen) =>
+ childGen.gen(id).alias(name)
+ }
+
+ val range = if (numParts > 0) {
+ spark.range(0, numRows, 1, numParts)
+ } else {
+ spark.range(numRows)
+ }
+ range.select(allGens: _*)
+ }
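+
+ // Minimal usage sketch (not part of this patch): materialize this table and write it
+ // out instead of generating inline; the dbgen instance, table name, and path are placeholders.
+ //
+ //   val df = dbgen("my_table").toDF(spark, numParts = 16)
+ //   df.write.mode("overwrite").parquet("/data/scale/my_table")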
+
+ /**
+ * Get a ColumnGen for a named column.
+ * @param name the name of the column to look for
+ * @return the corresponding column gen
+ */
+ def apply(name: String): ColumnGen = {
+ get(name).getOrElse {
+ throw new IllegalArgumentException(s"$name not found: ${columns.map(_._1).mkString(" ")}")
+ }
+ }
+
+ def get(name: String): Option[ColumnGen] =
+ columns.collectFirst {
+ case (childName, colGen) if childName.equalsIgnoreCase(name) => colGen
+ }
+}
+
+/**
+ * Provides a way to map a DataType to a specific DataGen instance.
+ */
+trait TypeMapping {
+ /**
+ * Does this TypeMapping support the given data type?
+ * @return true if it does, else false
+ */
+ def canMap(dt: DataType, subTypeMapping: TypeMapping): Boolean
+
+ /**
+ * If canMap returned true, then do the mapping
+ * @return the DataGen along with the last ColumnConf that was used by a column
+ */
+ def map(dt: DataType,
+ conf: ColumnConf,
+ defaultRanges: mutable.HashMap[DataType, (Any, Any)],
+ subTypeMapping: TypeMapping): (DataGen, ColumnConf)
+}
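+
+// Illustrative sketch (not part of this patch): a custom TypeMapping that forces every
+// IntegerType column into a small value range and leaves all other types to lower
+// priority mappings. The object name is hypothetical.
+//
+//   object SmallIntMapping extends TypeMapping {
+//     override def canMap(dt: DataType, subTypeMapping: TypeMapping): Boolean =
+//       dt == IntegerType
+//
+//     override def map(dt: DataType,
+//         conf: ColumnConf,
+//         defaultRanges: mutable.HashMap[DataType, (Any, Any)],
+//         subTypeMapping: TypeMapping): (DataGen, ColumnConf) =
+//       (new IntGen(conf, Some((0, 100))), conf)
+//   }
+//
+// Registering it with DBGen.addTypeMapping puts it ahead of DefaultTypeMapping, so it
+// wins for IntegerType columns in tables added afterwards.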
+
+object DefaultTypeMapping extends TypeMapping {
+ override def canMap(dataType: DataType, subTypeMapping: TypeMapping): Boolean = dataType match {
+ case BooleanType => true
+ case ByteType => true
+ case ShortType => true
+ case IntegerType => true
+ case LongType => true
+ case _: DecimalType => true
+ case FloatType => true
+ case DoubleType => true
+ case StringType => true
+ case TimestampType => true
+ case DateType => true
+ case st: StructType =>
+ st.forall(child => subTypeMapping.canMap(child.dataType, subTypeMapping))
+ case at: ArrayType =>
+ subTypeMapping.canMap(at.elementType, subTypeMapping)
+ case _ => false
+ }
+
+ override def map(dataType: DataType,
+ conf: ColumnConf,
+ defaultRanges: mutable.HashMap[DataType, (Any, Any)],
+ subTypeMapping: TypeMapping): (DataGen, ColumnConf) = dataType match {
+ case BooleanType =>
+ (new BooleanGen(conf, defaultRanges.get(dataType)), conf)
+ case ByteType =>
+ (new ByteGen(conf, defaultRanges.get(dataType)), conf)
+ case ShortType =>
+ (new ShortGen(conf, defaultRanges.get(dataType)), conf)
+ case IntegerType =>
+ (new IntGen(conf, defaultRanges.get(dataType)), conf)
+ case LongType =>
+ (new LongGen(conf, defaultRanges.get(dataType)), conf)
+ case dt: DecimalType =>
+ (new DecimalGen(dt, conf, defaultRanges.get(dataType)), conf)
+ case FloatType =>
+ (new FloatGen(conf, defaultRanges.get(dataType)), conf)
+ case DoubleType =>
+ (new DoubleGen(conf, defaultRanges.get(dataType)), conf)
+ case StringType =>
+ (new StringGen(conf, defaultRanges.get(dataType)), conf)
+ case TimestampType =>
+ (new TimestampGen(conf, defaultRanges.get(dataType)), conf)
+ case DateType =>
+ (new DateGen(conf, defaultRanges.get(dataType)), conf)
+ case st: StructType =>
+ var tmpConf = conf
+ val fields = st.map { sf =>
+ tmpConf = tmpConf.forNextColumn(sf.nullable)
+ val genNCol = subTypeMapping.map(sf.dataType, tmpConf, defaultRanges, subTypeMapping)
+ tmpConf = genNCol._2
+ (sf.name, genNCol._1)
+ }
+ (new StructGen(fields, conf, defaultRanges.get(dataType)), tmpConf)
+ case at: ArrayType =>
+ val childConf = conf.forNextColumn(at.containsNull)
+ val child = subTypeMapping.map(at.elementType, childConf, defaultRanges, subTypeMapping)
+ (new ArrayGen(child._1, conf, defaultRanges.get(dataType)), child._2)
+ case other =>
+ throw new IllegalArgumentException(s"$other is not a supported type yet")
+ }
+}
+
+case class OrderedTypeMapping(ordered: Array[TypeMapping]) extends TypeMapping {
+ override def canMap(dt: DataType, subTypeMapping: TypeMapping): Boolean = {
+ ordered.foreach { mapping =>
+ if (mapping.canMap(dt, subTypeMapping)) {
+ return true
+ }
+ }
+ false
+ }
+
+ override def map(dt: DataType,
+ conf: ColumnConf,
+ defaultRanges: mutable.HashMap[DataType, (Any, Any)],
+ subTypeMapping: TypeMapping): (DataGen, ColumnConf) = {
+ ordered.foreach { mapping =>
+ if (mapping.canMap(dt, subTypeMapping)) {
+ return mapping.map(dt, conf, defaultRanges, subTypeMapping)
+ }
+ }
+ // This should not be reachable
+ throw new IllegalStateException(s"$dt is not currently supported")
+ }
+}
+
+object DBGen {
+ def empty: DBGen = new DBGen()
+ def apply(): DBGen = new DBGen()
+
+ private def dtToTopLevelGen(
+ st: StructType,
+ tableId: Int,
+ defaultRanges: mutable.HashMap[DataType, (Any, Any)],
+ numRows: Long,
+ mapping: OrderedTypeMapping): Seq[(String, ColumnGen)] = {
+ // Start the column number at -1 so that forNextColumn advances it to the correct
+ // value before each column is generated.
+ var conf = ColumnConf(ColumnLocation(tableId, -1), true, numRows)
+ st.toArray.map { sf =>
+ if (!mapping.canMap(sf.dataType, mapping)) {
+ throw new IllegalArgumentException(s"$sf is not supported at this time")
+ }
+ conf = conf.forNextColumn(sf.nullable)
+ val tmp = mapping.map(sf.dataType, conf, defaultRanges, mapping)
+ conf = tmp._2
+ (sf.name, new ColumnGen(tmp._1))
+ }
+ }
+}
+
+/**
+ * Set up the schema for different tables and the relationship between various keys/columns
+ * in the tables.
+ */
+class DBGen {
+ private var tableId = 0
+ private val tables = mutable.HashMap.empty[String, TableGen]
+ private val defaultRanges = mutable.HashMap.empty[DataType, (Any, Any)]
+ private val mappings = ArrayBuffer[TypeMapping](DefaultTypeMapping)
+
+ /**
+ * Set a default value range for all generators of a given type. Note that this only impacts
+ * tables that have not been added yet. Some generators don't support value ranges;
+ * setting this for a type that does not support them will result in an error when the
+ * corresponding generator is created. Some generators can also be configured to ignore
+ * this, for example if you pass in your own function for data generation.
+ */
+ def setDefaultValueRange(dt: DataType, min: Any, max: Any): DBGen = {
+ defaultRanges.put(dt, (min, max))
+ this
+ }
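+
+ // Usage sketch (illustrative, not part of this patch): keep every IntegerType column
+ // in tables added after this call within a small range, assuming the integer generator
+ // accepts plain Int bounds for its value range.
+ //
+ //   dbgen.setDefaultValueRange(IntegerType, 0, 100)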
+
+ /**
+ * Add a new table from pre-built column generators
+ * @param name the name of the table (must be unique)
+ * @param columns the generators that will be used for this table
+ * @param numRows the number of rows for this table
+ * @return the TableGen that was added
+ */
+ private def addTable(name: String,
+ columns: Seq[(String, ColumnGen)],
+ numRows: Long): TableGen = {
+ val lowerName = name.toLowerCase
+ if (lowerName.contains(".")) {
+ // Not sure if there are other forbidden characters, but we can check on that later
+ throw new IllegalArgumentException("Name cannot contain '.' character")
+ }
+ if (tables.contains(lowerName)) {
+ throw new IllegalArgumentException("Cannot add duplicate tables (even if case is different)")
+ }
+ val ret = new TableGen(columns, numRows)
+ tables.put(lowerName, ret)
+ ret
+ }
+
+ /**
+ * Add a new table with a given type
+ * @param name the name of the table (must be unique)
+ * @param st the type for this table.
+ * @param numRows the number of rows for this table
+ * @return the TableGen that was just added
+ */
+ def addTable(name: String,
+ st: StructType,
+ numRows: Long): TableGen = {
+ val localTableId = tableId
+ tableId += 1
+ val mapping = OrderedTypeMapping(mappings.toArray)
+ val sg = DBGen.dtToTopLevelGen(st, localTableId, defaultRanges, numRows, mapping)
+ addTable(name, sg, numRows)
+ }
+
+ /**
+ * Add a new table with a type defined by the DDL
+ * @param name the name of the table (must be unique)
+ * @param ddl the DDL that describes the type for the table.
+ * @param numRows the number of rows for this table
+ * @return the TableGen that was just created
+ */
+ def addTable(name: String,
+ ddl: String,
+ numRows: Long): TableGen = {
+ val localTableId = tableId
+ tableId += 1
+ val mapping = OrderedTypeMapping(mappings.toArray)
+ val sg = DBGen.dtToTopLevelGen(
+ DataType.fromDDL(ddl).asInstanceOf[StructType], localTableId, defaultRanges, numRows, mapping)
+ addTable(name, sg, numRows)
+ }
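+
+ // Usage sketch (illustrative, not part of this patch): the table name, DDL, and row
+ // count below are placeholders.
+ //
+ //   val customers = dbgen.addTable("customers",
+ //     "cust_id long, name string, orders array<struct<order_id: long, amount: decimal(10,2)>>",
+ //     10000000L)
+ //   customers("cust_id").setSeedRange(0, 999999)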
+
+ /**
+ * Get a TableGen by name
+ * @param name the name of the table to look for
+ * @return the corresponding TableGen
+ */
+ def apply(name: String): TableGen = tables(name.toLowerCase)
+
+ def get(name: String): Option[TableGen] = tables.get(name.toLowerCase)
+
+ /**
+ * Names of the tables.
+ */
+ def tableNames: Seq[String] = tables.keys.toSeq
+
+ /**
+ * Get an immutable map out of this.
+ */
+ def toMap: Map[String, TableGen] = tables.toMap
+
+ /**
+ * Convert all of the tables into dataframes
+ * @param spark the session to use for the conversion
+ * @param numParts the number of parts (tasks) to use. A value <= 0 falls back to
+ * Spark's default parallelism.
+ * @return a Map of the name of the table to the dataframe.
+ */
+ def toDF(spark: SparkSession, numParts: Int = 0): Map[String, DataFrame] =
+ toMap.map {
+ case (name, gen) => (name, gen.toDF(spark, numParts))
+ }
+
+ /**
+ * Take all of the tables and create or replace temp views with them using each table's name.
+ * The data will be generated inline as the query runs.
+ * @param spark the session to use
+ * @param numParts the number of parts to use, if > 0
+ */
+ def createOrReplaceTempViews(spark: SparkSession, numParts: Int = 0): Unit =
+ toDF(spark, numParts).foreach {
+ case (name, df) => df.createOrReplaceTempView(name)
+ }
+
+ /**
+ * Write all of the tables out as parquet under path/table_name, optionally overwriting
+ * anything that is already there.
+ * @param spark the session to use
+ * @param path the base path to write the data under
+ * @param numParts the number of parts to use, if > 0
+ * @param overwrite if true will overwrite existing data
+ */
+ def writeParquet(spark: SparkSession,
+ path: String,
+ numParts: Int = 0,
+ overwrite: Boolean = false): Unit = {
+ toDF(spark, numParts).foreach {
+ case (name, df) =>
+ val subPath = path + "/" + name
+ var writer = df.write
+ if (overwrite) {
+ writer = writer.mode("overwrite")
+ }
+ writer.parquet(subPath)
+ }
+ }
+
+ /**
+ * Create or replace temp views for all of the tables assuming that they were already written
+ * out as parquet under path/table_name.
+ * @param spark the session to use
+ * @param path the base path to read the data from
+ */
+ def createOrReplaceTempViewsFromParquet(spark: SparkSession, path: String): Unit = {
+ tables.foreach {
+ case (name, _) =>
+ val subPath = path + "/" + name
+ spark.read.parquet(subPath).createOrReplaceTempView(name)
+ }
+ }
+
+ /**
+ * Write all of the tables out as parquet under path/table_name and then create or replace temp
+ * views for each of the tables using each table's name.
+ *
+ * @param spark the session to use
+ * @param path the base path to write the data under
+ * @param numParts the number of parts to use, if > 0
+ * @param overwrite if true will overwrite existing data
+ */
+ def writeParquetAndReplaceTempViews(spark: SparkSession,
+ path: String,
+ numParts: Int = 0,
+ overwrite: Boolean = false): Unit = {
+ writeParquet(spark, path, numParts, overwrite)
+ createOrReplaceTempViewsFromParquet(spark, path)
+ }
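+
+ // Usage sketch (illustrative, not part of this patch): generate once, persist as
+ // parquet, and query the registered views; assumes a table named "customers" was added.
+ //
+ //   dbgen.writeParquetAndReplaceTempViews(spark, "/data/scale_test",
+ //     numParts = 200, overwrite = true)
+ //   spark.sql("SELECT COUNT(*) FROM customers").show()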
+
+ /**
+ * Write all of the tables out as orc under path/table_name, optionally overwriting
+ * anything that is already there.
+ *
+ * @param spark the session to use
+ * @param path the base path to write the data under
+ * @param numParts the number of parts to use, if > 0
+ * @param overwrite if true will overwrite existing data
+ */
+ def writeOrc(spark: SparkSession,
+ path: String,
+ numParts: Int = 0,
+ overwrite: Boolean = false): Unit = {
+ toDF(spark, numParts).foreach {
+ case (name, df) =>
+ val subPath = path + "/" + name
+ var writer = df.write
+ if (overwrite) {
+ writer = writer.mode("overwrite")
+ }
+ writer.orc(subPath)
+ }
+ }
+
+ /**
+ * Create or replace temp views for all of the tables assuming that they were already written
+ * out as orc under path/table_name.
+ * @param spark the session to use
+ * @param path the base path to read the data from
+ */
+ def createOrReplaceTempViewsFromOrc(spark: SparkSession, path: String): Unit = {
+ tables.foreach {
+ case (name, _) =>
+ val subPath = path + "/" + name
+ spark.read.orc(subPath).createOrReplaceTempView(name)
+ }
+ }
+
+ /**
+ * Write all of the tables out as orc under path/table_name and then create or replace temp
+ * views for each of the tables using each table's name.
+ *
+ * @param spark the session to use
+ * @param path the base path to write the data under
+ * @param numParts the number of parts to use, if > 0
+ * @param overwrite if true will overwrite existing data
+ */
+ def writeOrcAndReplaceTempViews(spark: SparkSession,
+ path: String,
+ numParts: Int = 0,
+ overwrite: Boolean = false): Unit = {
+ writeOrc(spark, path, numParts, overwrite)
+ createOrReplaceTempViewsFromOrc(spark, path)
+ }
+
+ /**
+ * Add a new user-controlled type mapping. This allows
+ * the user to completely override the handling for any or all types.
+ * @param mapping the new mapping to add with highest priority.
+ */
+ def addTypeMapping(mapping: TypeMapping): Unit = {
+ // Insert this mapping in front of the others
+ mappings.insert(0, mapping)
+ }
+}
diff --git a/integration_tests/src/main/spark311/scala/org/apache/spark/sql/tests/datagen/DataGenExprBase.scala b/integration_tests/src/main/spark311/scala/org/apache/spark/sql/tests/datagen/DataGenExprBase.scala
new file mode 100644
index 000000000000..d50008f7fb72
--- /dev/null
+++ b/integration_tests/src/main/spark311/scala/org/apache/spark/sql/tests/datagen/DataGenExprBase.scala
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*** spark-rapids-shim-json-lines
+{"spark": "311"}
+{"spark": "312"}
+{"spark": "313"}
+spark-rapids-shim-json-lines ***/
+package org.apache.spark.sql.tests.datagen
+
+import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, UnaryExpression}
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+
+trait DataGenExprBase extends UnaryExpression with ExpectsInputTypes with CodegenFallback
diff --git a/integration_tests/src/main/spark320/scala/org/apache/spark/sql/tests/datagen/DataGenExprBase.scala b/integration_tests/src/main/spark320/scala/org/apache/spark/sql/tests/datagen/DataGenExprBase.scala
new file mode 100644
index 000000000000..3d2e03c50f1e
--- /dev/null
+++ b/integration_tests/src/main/spark320/scala/org/apache/spark/sql/tests/datagen/DataGenExprBase.scala
@@ -0,0 +1,43 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*** spark-rapids-shim-json-lines
+{"spark": "320"}
+{"spark": "321"}
+{"spark": "321cdh"}
+{"spark": "321db"}
+{"spark": "322"}
+{"spark": "323"}
+{"spark": "324"}
+{"spark": "330"}
+{"spark": "330cdh"}
+{"spark": "330db"}
+{"spark": "331"}
+{"spark": "332"}
+{"spark": "332db"}
+{"spark": "333"}
+{"spark": "340"}
+{"spark": "341"}
+spark-rapids-shim-json-lines ***/
+package org.apache.spark.sql.tests.datagen
+
+import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, UnaryExpression}
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+
+trait DataGenExprBase extends UnaryExpression with ExpectsInputTypes with CodegenFallback {
+ override def withNewChildInternal(newChild: Expression): Expression =
+ legacyWithNewChildren(Seq(newChild))
+}
diff --git a/pom.xml b/pom.xml
index d12c535e44cb..c5ee3efecade 100644
--- a/pom.xml
+++ b/pom.xml
@@ -524,7 +524,6 @@
11
11
11
- true
@@ -534,7 +533,6 @@
17
17
17
- true
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
index a724afb45b84..5705e38f8c93 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
@@ -1128,7 +1128,7 @@ object GpuOverrides extends Logging {
}),
expr[SecondsToTimestamp](
"Converts the number of seconds from unix epoch to a timestamp",
- ExprChecks.unaryProject(TypeSig.TIMESTAMP, TypeSig.TIMESTAMP,
+ ExprChecks.unaryProject(TypeSig.TIMESTAMP, TypeSig.TIMESTAMP,
TypeSig.gpuNumeric, TypeSig.cpuNumeric),
(a, conf, p, r) => new UnaryExprMeta[SecondsToTimestamp](a, conf, p, r) {
override def convertToGpu(child: Expression): GpuExpression =
@@ -1136,7 +1136,7 @@ object GpuOverrides extends Logging {
}),
expr[MillisToTimestamp](
"Converts the number of milliseconds from unix epoch to a timestamp",
- ExprChecks.unaryProject(TypeSig.TIMESTAMP, TypeSig.TIMESTAMP,
+ ExprChecks.unaryProject(TypeSig.TIMESTAMP, TypeSig.TIMESTAMP,
TypeSig.integral, TypeSig.integral),
(a, conf, p, r) => new UnaryExprMeta[MillisToTimestamp](a, conf, p, r) {
override def convertToGpu(child: Expression): GpuExpression =
@@ -1144,7 +1144,7 @@ object GpuOverrides extends Logging {
}),
expr[MicrosToTimestamp](
"Converts the number of microseconds from unix epoch to a timestamp",
- ExprChecks.unaryProject(TypeSig.TIMESTAMP, TypeSig.TIMESTAMP,
+ ExprChecks.unaryProject(TypeSig.TIMESTAMP, TypeSig.TIMESTAMP,
TypeSig.integral, TypeSig.integral),
(a, conf, p, r) => new UnaryExprMeta[MicrosToTimestamp](a, conf, p, r) {
override def convertToGpu(child: Expression): GpuExpression =
@@ -2108,14 +2108,12 @@ object GpuOverrides extends Logging {
"Max aggregate operator",
ExprChecksImpl(
ExprChecks.reductionAndGroupByAgg(
- (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.STRUCT)
- .nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL +
- TypeSig.STRUCT + TypeSig.ARRAY),
+ (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.STRUCT +
+ TypeSig.ARRAY).nested(),
TypeSig.orderable,
Seq(ParamCheck("input",
- (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.STRUCT)
- .nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL +
- TypeSig.STRUCT + TypeSig.ARRAY),
+ (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.STRUCT +
+ TypeSig.ARRAY).nested(),
TypeSig.orderable))).asInstanceOf[ExprChecksImpl].contexts
++
ExprChecks.windowOnly(
@@ -2135,14 +2133,12 @@ object GpuOverrides extends Logging {
"Min aggregate operator",
ExprChecksImpl(
ExprChecks.reductionAndGroupByAgg(
- (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.STRUCT)
- .nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL +
- TypeSig.STRUCT + TypeSig.ARRAY),
+ (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.STRUCT +
+ TypeSig.ARRAY).nested(),
TypeSig.orderable,
Seq(ParamCheck("input",
- (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.STRUCT)
- .nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL +
- TypeSig.STRUCT + TypeSig.ARRAY),
+ (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.STRUCT +
+ TypeSig.ARRAY).nested(),
TypeSig.orderable))).asInstanceOf[ExprChecksImpl].contexts
++
ExprChecks.windowOnly(
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala
index d1dbcd3a0863..73b280aadd5c 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala
@@ -60,10 +60,12 @@ case class GpuParseUrl(children: Seq[Expression],
override def checkInputDataTypes(): TypeCheckResult = {
if (children.size > 3 || children.size < 2) {
- RapidsErrorUtils.parseUrlWrongNumArgs(children.size)
- } else {
- super[ExpectsInputTypes].checkInputDataTypes()
+ RapidsErrorUtils.parseUrlWrongNumArgs(children.size) match {
+ case Some(res) => return res
+ case _ => // the shim has already thrown the appropriate error
+ }
}
+ super[ExpectsInputTypes].checkInputDataTypes()
}
private def getPattern(key: UTF8String): RegexProgram = {
diff --git a/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala b/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala
index 9259f0d3e90c..5d74bbac86aa 100644
--- a/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala
+++ b/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala
@@ -84,11 +84,11 @@ object RapidsErrorUtils {
throw new AnalysisException(s"$tableIdentifier already exists.")
}
- def parseUrlWrongNumArgs(actual: Int): TypeCheckResult = {
- TypeCheckResult.TypeCheckFailure(s"parse_url function requires two or three arguments")
+ def parseUrlWrongNumArgs(actual: Int): Option[TypeCheckResult] = {
+ Some(TypeCheckResult.TypeCheckFailure(s"parse_url function requires two or three arguments"))
}
- def invalidUrlException(url: UFT8String, e: Throwable): Throwable = {
+ def invalidUrlException(url: UTF8String, e: Throwable): Throwable = {
new IllegalArgumentException(s"Find an invaild url string ${url.toString}", e)
}
}
diff --git a/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala b/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala
index 6a569d98248f..115d4e93ba80 100644
--- a/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala
+++ b/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala
@@ -90,8 +90,8 @@ object RapidsErrorUtils {
QueryCompilationErrors.tableIdentifierExistsError(tableIdentifier)
}
- def parseUrlWrongNumArgs(actual: Int): TypeCheckResult = {
- TypeCheckResult.TypeCheckFailure(s"parse_url function requires two or three arguments")
+ def parseUrlWrongNumArgs(actual: Int): Option[TypeCheckResult] = {
+ Some(TypeCheckResult.TypeCheckFailure(s"parse_url function requires two or three arguments"))
}
def invalidUrlException(url: UTF8String, e: URISyntaxException): Throwable = {
diff --git a/sql-plugin/src/main/spark321db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala b/sql-plugin/src/main/spark321db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala
index b5997e4e6dbb..b7b95016047c 100644
--- a/sql-plugin/src/main/spark321db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala
+++ b/sql-plugin/src/main/spark321db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala
@@ -88,8 +88,8 @@ object RapidsErrorUtils {
QueryCompilationErrors.tableIdentifierExistsError(tableIdentifier)
}
- def parseUrlWrongNumArgs(actual: Int): TypeCheckResult = {
- TypeCheckResult.TypeCheckFailure(s"parse_url function requires two or three arguments")
+ def parseUrlWrongNumArgs(actual: Int): Option[TypeCheckResult] = {
+ Some(TypeCheckResult.TypeCheckFailure(s"parse_url function requires two or three arguments"))
}
def invalidUrlException(url: UTF8String, e: URISyntaxException): Throwable = {
diff --git a/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala b/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala
index a0e827150d5a..841b3a960782 100644
--- a/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala
+++ b/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala
@@ -84,8 +84,8 @@ object RapidsErrorUtils extends RapidsErrorUtilsFor330plus {
new ArrayIndexOutOfBoundsException("SQL array indices start at 1")
}
- def parseUrlWrongNumArgs(actual: Int): TypeCheckResult = {
- TypeCheckResult.TypeCheckFailure(s"parse_url function requires two or three arguments")
+ def parseUrlWrongNumArgs(actual: Int): Option[TypeCheckResult] = {
+ Some(TypeCheckResult.TypeCheckFailure(s"parse_url function requires two or three arguments"))
}
def invalidUrlException(url: UTF8String, e: URISyntaxException): Throwable = {
diff --git a/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala b/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala
index 43be246548ac..aaaecc86256a 100644
--- a/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala
+++ b/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala
@@ -92,8 +92,8 @@ object RapidsErrorUtils extends RapidsErrorUtilsFor330plus {
QueryExecutionErrors.intervalDividedByZeroError(origin.context)
}
- def parseUrlWrongNumArgs(actual: Int): TypeCheckResult = {
- TypeCheckResult.TypeCheckFailure(s"parse_url function requires two or three arguments")
+ def parseUrlWrongNumArgs(actual: Int): Option[TypeCheckResult] = {
+ Some(TypeCheckResult.TypeCheckFailure(s"parse_url function requires two or three arguments"))
}
def invalidUrlException(url: UTF8String, e: URISyntaxException): Throwable = {
diff --git a/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala b/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala
index 6dde4e0a8f98..032fed254efd 100644
--- a/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala
+++ b/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala
@@ -25,7 +25,7 @@ import java.net.URISyntaxException
import org.apache.spark.SparkDateTimeException
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.trees.{Origin, SQLQueryContext}
-import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, Decimal, DecimalType}
import org.apache.spark.unsafe.types.UTF8String
@@ -92,10 +92,11 @@ object RapidsErrorUtils extends RapidsErrorUtilsFor330plus {
QueryExecutionErrors.intervalDividedByZeroError(origin.context)
}
- def parseUrlWrongNumArgs(actual: Int): Throwable = {
+ def parseUrlWrongNumArgs(actual: Int): Option[TypeCheckResult] = {
throw QueryCompilationErrors.wrongNumArgsError(
- "parse_url", Seq("[2, 3]"), actualNumber
+ "parse_url", Seq("[2, 3]"), actual
)
+ None
}
def invalidUrlException(url: UTF8String, e: URISyntaxException): Throwable = {
diff --git a/tools/generated_files/supportedExprs.csv b/tools/generated_files/supportedExprs.csv
index 4f01e3b73793..98437662e2cc 100644
--- a/tools/generated_files/supportedExprs.csv
+++ b/tools/generated_files/supportedExprs.csv
@@ -649,16 +649,16 @@ Last,S,`last`; `last_value`,None,reduction,input,S,S,S,S,S,S,S,S,PS,S,S,S,S,NS,P
Last,S,`last`; `last_value`,None,reduction,result,S,S,S,S,S,S,S,S,PS,S,S,S,S,NS,PS,PS,PS,NS
Last,S,`last`; `last_value`,None,window,input,S,S,S,S,S,S,S,S,PS,S,S,S,S,NS,PS,PS,PS,NS
Last,S,`last`; `last_value`,None,window,result,S,S,S,S,S,S,S,S,PS,S,S,S,S,NS,PS,PS,PS,NS
-Max,S,`max`,None,aggregation,input,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,NS,NA,PS,NS
-Max,S,`max`,None,aggregation,result,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,NS,NA,PS,NS
-Max,S,`max`,None,reduction,input,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,NS,NA,PS,NS
-Max,S,`max`,None,reduction,result,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,NS,NA,PS,NS
+Max,S,`max`,None,aggregation,input,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,PS,NA,PS,NS
+Max,S,`max`,None,aggregation,result,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,PS,NA,PS,NS
+Max,S,`max`,None,reduction,input,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,PS,NA,PS,NS
+Max,S,`max`,None,reduction,result,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,PS,NA,PS,NS
Max,S,`max`,None,window,input,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,NS,NA,NS,NS
Max,S,`max`,None,window,result,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,NS,NA,NS,NS
-Min,S,`min`,None,aggregation,input,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,NS,NA,PS,NS
-Min,S,`min`,None,aggregation,result,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,NS,NA,PS,NS
-Min,S,`min`,None,reduction,input,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,NS,NA,PS,NS
-Min,S,`min`,None,reduction,result,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,NS,NA,PS,NS
+Min,S,`min`,None,aggregation,input,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,PS,NA,PS,NS
+Min,S,`min`,None,aggregation,result,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,PS,NA,PS,NS
+Min,S,`min`,None,reduction,input,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,PS,NA,PS,NS
+Min,S,`min`,None,reduction,result,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,PS,NA,PS,NS
Min,S,`min`,None,window,input,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,NS,NA,NS,NS
Min,S,`min`,None,window,result,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,NS,NA,NS,NS
PivotFirst,S, ,None,aggregation,pivotColumn,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,NS,NS,NS,NS