diff --git a/docs/supported_ops.md b/docs/supported_ops.md index 953597c8016d..b201d6952829 100644 --- a/docs/supported_ops.md +++ b/docs/supported_ops.md @@ -17195,7 +17195,7 @@ are limited. S NS NS -NS +PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types BINARY, CALENDAR, UDT
PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types BINARY, CALENDAR, UDT
NS @@ -17216,7 +17216,7 @@ are limited. S NS NS -NS +PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types BINARY, CALENDAR, UDT
PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types BINARY, CALENDAR, UDT
NS @@ -17238,7 +17238,7 @@ are limited. S NS NS -NS +PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types BINARY, CALENDAR, UDT
PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types BINARY, CALENDAR, UDT
NS @@ -17259,7 +17259,7 @@ are limited. S NS NS -NS +PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types BINARY, CALENDAR, UDT
PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types BINARY, CALENDAR, UDT
NS @@ -17328,7 +17328,7 @@ are limited. S NS NS -NS +PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types BINARY, CALENDAR, UDT
PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types BINARY, CALENDAR, UDT
NS @@ -17349,7 +17349,7 @@ are limited. S NS NS -NS +PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types BINARY, CALENDAR, UDT
PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types BINARY, CALENDAR, UDT
NS @@ -17371,7 +17371,7 @@ are limited. S NS NS -NS +PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types BINARY, CALENDAR, UDT
PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types BINARY, CALENDAR, UDT
NS @@ -17392,7 +17392,7 @@ are limited. S NS NS -NS +PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types BINARY, CALENDAR, UDT
PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types BINARY, CALENDAR, UDT
NS diff --git a/integration_tests/DATA_GEN.md b/integration_tests/DATA_GEN.md new file mode 100644 index 000000000000..3db42716df2b --- /dev/null +++ b/integration_tests/DATA_GEN.md @@ -0,0 +1,444 @@ +# Big Data Generation + +In order to do scale testing we need a way to generate lots of data in a +deterministic way that gives us control over the number of unique values +in a column, the skew of the values in a column, and the correlation of +data between tables for joins. To accomplish this we wrote +`org.apache.spark.sql.tests.datagen`. + +## Setup Environment + +To get started with big data generation the first thing you need to do is +to include the appropriate jar on the classpath for your version of Apache Spark. +Note that this does not run on the GPU, but it does use +parts of the shim framework that the RAPIDS Accelerator does, so it is currently in +the integration tests jar for the RAPIDS Accelerator. The jar is specific to the +version of Spark you are using and is not pushed to Maven Central. Because of this +you will have to build it from source yourself. + +```shell +cd integration_tests +mvn clean package -Drat.skip -DskipTests -Dbuildver=$SPARK_VERSION +``` + +Where `$SPARK_VERSION` is a compressed version number, like 330 for Spark 3.3.0. + +If you are building with a jdk version that is not 8, you will need to add in the +corresponding profile flag `-P` + +After this the jar should be at +`target/rapids-4-spark-integration-tests_2.12-$PLUGIN_VERSION-spark$SPARK_VERSION.jar` +for example a Spark 3.3.0 jar for the 23.08.0 release would be +`target/rapids-4-spark-integration-tests_2.12-23.08.0-spark330.jar` + +To get a spark shell with this you can run +```shell +spark-shell --jars target/rapids-4-spark-integration-tests_2.12-23.08.0-spark330.jar +``` + +After that you should be good to go. + +## Generate Some Data + +The first thing to do is to import the classes + +```scala +import org.apache.spark.sql.tests.datagen._ +``` + +After this the main entry point is `DBGen`. `DBGen` provides a way to generate +multiple tables, but we are going to start off with just one. To do this we can +call `addTable` on it. + +```scala +val dataTable = DBGen().addTable("data", "a string, b byte", 5) +dataTable.toDF(spark).show() ++----------+----+ +| a| b| ++----------+----+ +|t=qIHOf:O)| 47| +|yXT-j", 3) +dataTable("a").setLength(1) +dataTable("b").setLength(2) +dataTable("b")("data").setLength(3) +dataTable.toDF(spark).show(false) ++---+----------+ +|a |b | ++---+----------+ +|t |[X]6, /= minRow) { + null + } else { + gen(rowLoc) + } + } +} + +... + +dataTable("a").setNullGen(MyNullGen(1024, 9999999L)) +``` + +Similarly, if you have a requirement to generate JSON formatted strings that +follow a given pattern you can do that. Or provide a distribution where a very +specific seed shows up 99% of the time, and the rest of the time it falls back +to the regular `FlatDistribution`, you can also do that. It is designed to be very +flexible. 
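As a minimal sketch of that last idea (the `MostlyOneSeed` name and the 99%/seed-42
numbers below are made up for illustration, not part of the library), a custom
`LocationToSeedMapping` can send most rows to a single "hot" seed and let the rest
fall back to a regular `FlatDistribution`:

```scala
// Hypothetical skewed mapping: ~99% of rows share one hot seed, the remaining
// rows keep the seed a FlatDistribution would have produced.
case class MostlyOneSeed(hotSeed: Long,
    flat: FlatDistribution = FlatDistribution()) extends LocationToSeedMapping {

  override def withColumnConf(colConf: ColumnConf): LocationToSeedMapping =
    MostlyOneSeed(hotSeed, flat.withColumnConf(colConf))

  override def apply(rowLoc: RowLocation): Long = {
    // Decide deterministically, from the row's flat seed, whether this row
    // gets the hot seed or keeps its flat seed.
    val base = flat(rowLoc)
    if (DataGen.getRandomFor(base).nextDouble() < 0.99) hotSeed else base
  }
}

dataTable("a").setSeedMapping(MostlyOneSeed(42))
```

Because the choice is derived from the row's own seed, the same rows land on the hot
seed every time the table is generated, so the skewed output is still deterministic.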
\ No newline at end of file diff --git a/integration_tests/pom.xml b/integration_tests/pom.xml index 8b742a020afe..9cea03a0fd15 100644 --- a/integration_tests/pom.xml +++ b/integration_tests/pom.xml @@ -228,12 +228,24 @@ ${spark.version} provided + + org.apache.spark + spark-unsafe_${scala.binary.version} + ${spark.version} + provided + org.scala-lang scala-reflect ${scala.version} provided + + com.esotericsoftware.kryo + kryo-shaded-db + ${spark.version} + provided + org.apache.arrow arrow-format diff --git a/integration_tests/src/main/python/hash_aggregate_test.py b/integration_tests/src/main/python/hash_aggregate_test.py index fd7ff4a4b494..ac96f02d01d2 100644 --- a/integration_tests/src/main/python/hash_aggregate_test.py +++ b/integration_tests/src/main/python/hash_aggregate_test.py @@ -688,7 +688,7 @@ def test_hash_groupby_collect_set(data_gen): @ignore_order(local=True) @pytest.mark.parametrize('data_gen', _gen_data_for_collect_set_op, ids=idfn) -@pytest.mark.xfail(condition=is_before_spark_330(), reason='https://github.com/NVIDIA/spark-rapids/issues/8716') +@pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/8716') def test_hash_groupby_collect_set_on_nested_type(data_gen): assert_gpu_and_cpu_are_equal_collect( lambda spark: gen_df(spark, data_gen, length=100) @@ -731,7 +731,7 @@ def test_hash_reduction_collect_set(data_gen): @ignore_order(local=True) @pytest.mark.parametrize('data_gen', _gen_data_for_collect_set_op, ids=idfn) -@pytest.mark.xfail(condition=is_before_spark_330(), reason='https://github.com/NVIDIA/spark-rapids/issues/8716') +@pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/8716') def test_hash_reduction_collect_set_on_nested_type(data_gen): assert_gpu_and_cpu_are_equal_collect( lambda spark: gen_df(spark, data_gen, length=100) @@ -1869,7 +1869,7 @@ def test_std_variance_partial_replace_fallback(data_gen, conf=local_conf) # -# test min/max aggregations for structs +# Test min/max aggregations on simple type (integer) keys and nested type values. # gens_for_max_min = [byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen, @@ -1877,16 +1877,11 @@ def test_std_variance_partial_replace_fallback(data_gen, date_gen, timestamp_gen, DecimalGen(precision=12, scale=2), DecimalGen(precision=36, scale=5), - null_gen] + array_gens_sample # + struct_gens_sample -# Nested structs have issues, https://github.com/NVIDIA/spark-rapids/issues/8702 + null_gen] + array_gens_sample + struct_gens_sample @ignore_order(local=True) @pytest.mark.parametrize('data_gen', gens_for_max_min, ids=idfn) -def test_min_max_for_struct(data_gen): - df_gen = [ - ('a', StructGen([ - ('aa', data_gen), - ('ab', data_gen)])), - ('b', RepeatSeqGen(IntegerGen(), length=20))] +def test_min_max_in_groupby_and_reduction(data_gen): + df_gen = [('a', data_gen), ('b', RepeatSeqGen(IntegerGen(), length=20))] # test max assert_gpu_and_cpu_are_equal_sql( diff --git a/integration_tests/src/main/python/parquet_write_test.py b/integration_tests/src/main/python/parquet_write_test.py index 8c426aa1a921..159e77d13306 100644 --- a/integration_tests/src/main/python/parquet_write_test.py +++ b/integration_tests/src/main/python/parquet_write_test.py @@ -730,20 +730,44 @@ def write_partitions(spark, table_path): conf={} ) -# Test to avoid regression on a known bug in Spark. 
For details please visit https://github.com/NVIDIA/spark-rapids/issues/8693 -def test_hive_timestamp_value(spark_tmp_table_factory, spark_tmp_path): +def hive_timestamp_value(spark_tmp_table_factory, spark_tmp_path, ts_rebase, func): + conf={'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': ts_rebase, + 'spark.sql.legacy.parquet.int96RebaseModeInWrite': ts_rebase} def create_table(spark, path): tmp_table = spark_tmp_table_factory.get() spark.sql(f"CREATE TABLE {tmp_table} STORED AS PARQUET " + - f""" LOCATION '{path}' AS SELECT CAST('2015-01-01 00:00:00' AS TIMESTAMP) as t; """) + f""" LOCATION '{path}' AS SELECT CAST('2015-01-01 00:00:00' AS TIMESTAMP) as t; """) def read_table(spark, path): return spark.read.parquet(path) data_path = spark_tmp_path + '/PARQUET_DATA' - assert_gpu_and_cpu_writes_are_equal_collect(create_table, read_table, data_path) - assert_gpu_and_cpu_are_equal_collect(lambda spark: spark.read.parquet(data_path + '/CPU')) + + func(create_table, read_table, data_path, conf) + +# Test to avoid regression on a known bug in Spark. For details please visit https://github.com/NVIDIA/spark-rapids/issues/8693 +def test_hive_timestamp_value(spark_tmp_table_factory, spark_tmp_path): + + def func_test(create_table, read_table, data_path, conf): + assert_gpu_and_cpu_writes_are_equal_collect(create_table, read_table, data_path, conf=conf) + assert_gpu_and_cpu_are_equal_collect(lambda spark: spark.read.parquet(data_path + '/CPU')) + + hive_timestamp_value(spark_tmp_table_factory, spark_tmp_path, 'CORRECTED', func_test) + +# Test to avoid regression on a known bug in Spark. For details please visit https://github.com/NVIDIA/spark-rapids/issues/8693 +@allow_non_gpu('DataWritingCommandExec', 'WriteFilesExec') +def test_hive_timestamp_value_fallback(spark_tmp_table_factory, spark_tmp_path): + + def func_test(create_table, read_table, data_path, conf): + assert_gpu_fallback_write( + create_table, + read_table, + data_path, + ['DataWritingCommandExec'], + conf) + + hive_timestamp_value(spark_tmp_table_factory, spark_tmp_path, 'LEGACY', func_test) @ignore_order @pytest.mark.skipif(is_before_spark_340(), reason="`spark.sql.optimizer.plannedWrite.enabled` is only supported in Spark 340+") diff --git a/integration_tests/src/main/scala/org/apache/spark/sql/tests/datagen/bigDataGen.scala b/integration_tests/src/main/scala/org/apache/spark/sql/tests/datagen/bigDataGen.scala new file mode 100644 index 000000000000..27b47e7900c2 --- /dev/null +++ b/integration_tests/src/main/scala/org/apache/spark/sql/tests/datagen/bigDataGen.scala @@ -0,0 +1,1974 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.tests.datagen + +import java.math.{BigDecimal => JavaBigDecimal} +import java.sql.{Date, Timestamp} +import java.time.{Instant, LocalDate} + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer +import scala.math.BigDecimal.RoundingMode +import scala.util.Random + +import org.apache.spark.sql.{Column, DataFrame, SparkSession} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Expression, XXH64} +import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils} +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.random.XORShiftRandom + +/** + * Holds a representation of the current row location. + * @param rowNum the top level row number (starting at 0 but this might change) + * @param subRows a path to the child rows in an array starting at 0 for each array. + */ +class RowLocation(val rowNum: Long, val subRows: Array[Int] = null) { + def withNewChild(): RowLocation = { + val newSubRows = if (subRows == null) { + new Array[Int](1) + } else { + val tmp = new Array[Int](subRows.length + 1) + subRows.copyToArray(tmp, 0, subRows.length) + tmp + } + new RowLocation(rowNum, newSubRows) + } + + def setLastChildIndex(index: Int): Unit = + subRows(subRows.length - 1) = index + + /** + * Hash the location into a single long value. + */ + def hashLoc(seed: Long): Long = { + var tmp = XXH64.hashLong(rowNum, seed) + if (subRows != null) { + var i = 0 + while (i < subRows.length) { + tmp = XXH64.hashLong(subRows(i), tmp) + i += 1 + } + } + tmp + } +} + +/** + * Holds a representation of the current columns location. This is computed by walking + * the tree. The path is depth first, but as we pass a new node we assign it an ID as + * we go starting at 0. If this is a part of a correlated key group then the + * correlatedKeyGroup will be set and the columnNum will be ignored when computing the + * hash. This makes the generated data correlated for all column/child columns. + * @param tableNum a unique ID for the table this is a part of. + * @param columnNum the location of the column in the data being generated + * @param correlatedKeyGroup the correlated key group this column is a part of, if any. + */ +case class ColumnLocation(tableNum: Int, columnNum: Int, correlatedKeyGroup: Option[Long] = None) { + def forNextColumn(): ColumnLocation = ColumnLocation(tableNum, columnNum + 1) + + + /** + * Create a new ColumnLocation that is specifically for a given key group + */ + def forCorrelatedKeyGroup(keyGroup: Long): ColumnLocation = + ColumnLocation(tableNum, columnNum, Some(keyGroup)) + + /** + * Hash the location into a single long value. + */ + lazy val hashLoc: Long = XXH64.hashLong(tableNum, correlatedKeyGroup.getOrElse(columnNum)) +} + +/** + * Holds configuration for a given column, or sub-column. + * @param columnLoc the location of the column + * @param nullable are nulls supported in the output or not. + * @param numRows the number of rows that this table is going to have, not necessarily the number + * that the column is going to have. 
+ * @param minSeed the minimum seed value allowed to be returned + * @param maxSeed the maximum seed value allowed to be returned + */ +case class ColumnConf(columnLoc: ColumnLocation, + nullable: Boolean, + numTableRows: Long, + minSeed: Long = Long.MinValue, + maxSeed: Long = Long.MaxValue) { + + def forNextColumn(nullable: Boolean): ColumnConf = + ColumnConf(columnLoc.forNextColumn(), nullable, numTableRows) + + /** + * Create a new configuration based on this, but for a given correlated key group. + */ + def forCorrelatedKeyGroup(correlatedKeyGroup: Long): ColumnConf = { + ColumnConf(columnLoc.forCorrelatedKeyGroup(correlatedKeyGroup), + nullable, + numTableRows, + minSeed, + maxSeed) + } + + /** + * Create a new configuration based on this, but for a given seed range. + */ + def forSeedRange(min: Long, max: Long): ColumnConf = + ColumnConf(columnLoc, nullable, numTableRows, min, max) + + /** + * Create a new configuration based on this, but for the null generator. + */ + def forNulls: ColumnConf = { + assert(nullable, "Should not get a conf for nulls from a non-nullable column conf") + ColumnConf(columnLoc, nullable, numTableRows) + } +} + +/** + * Provides a mapping between a location + configuration and a seed. This provides a random + * looking mapping between a location and a seed that will be used to deterministically + * generate data. This is also where we can change the distribution and cardinality of data. + * A min/max seed allows for changes to the cardinality of values generated and the exact + * mapping done can inject skew or various types of data distributions. + */ +trait LocationToSeedMapping extends Serializable { + /** + * Set the config for the column that should be used along with a min/max type seed range. + * This will be called before apply is ever called. + */ + def withColumnConf(colConf: ColumnConf): LocationToSeedMapping + + /** + * Given a row location + the previously set config produce a seed to be used. + * @param rowLoc the row location + * @return the seed to use + */ + def apply(rowLoc: RowLocation): Long +} + +object LocationToSeedMapping { + /** + * Return a function that should remap the full range of long seed + * (Long.MinValue to Long.MaxValue) to a new range as described by the conf in a way that + * does not really impact the distribution. This is not 100% correct + * because if the min/max is not the full range the mapping cannot always guarantee that + * each option gets exactly the same probabilty of being produced up. + * But it should be good enough. + */ + def remapRangeFunc(colConf: ColumnConf): Long => Long = { + val minSeed = colConf.minSeed + val maxSeed = colConf.maxSeed + if (minSeed == Long.MinValue && maxSeed == Long.MaxValue) { + n => n + } else { + // We generate numbers between minSeed and maxSeed + 1, and then take the floor so that + // values on the ends have the same range as those in the middle + val scaleFactor = (BigDecimal(maxSeed) + 1 - minSeed) / + (BigDecimal(Long.MaxValue) - Long.MinValue) + val minBig = BigDecimal(Long.MinValue) + n => { + (((n - minBig) * scaleFactor) + minSeed).setScale(0, RoundingMode.FLOOR).toLong + } + } + } +} + +/** + * The idea is that we have an address (`RowLocation`, `ColumnLocation`) + * that will uniquely identify the location of a piece of data to be generated. + * A `GeneratorFunction` should distinctly map this address + configs to a + * generated value. 
The column location is computed from the schema + configs + * and will be set by calling `setColumnConf` along with other configs before + * the function is used to generate any data. The row locations is created as needed + * and could be reused in between calls to `apply`. This should keep no + * state in between calls to `apply`. + * + * But we also want to have control over various aspects of how the data + * is generated. We want + * 1. The data to appear random (pseudo random) + * 2. The number of unique values (cardinality) + * 3. The distribution of those values (skew/etc) + * + * To accomplish this we also provide a location to seed mapping through the + * `setLocationToSeedMapping` API. This too is guaranteed to be called before + * `apply` is called. + */ +trait GeneratorFunction extends Serializable { + /** + * Set a location mapping function. This will be called before apply is ever called. + */ + def withLocationToSeedMapping(mapping: LocationToSeedMapping): GeneratorFunction + + /** + * Set a new value range. This will be called if a value range is set by the user before + * apply is called. If this generator does not support a value range, then an exception + * should be thrown. + */ + def withValueRange(min: Any, max: Any): GeneratorFunction + + /** + * Set a LengthGeneratorFunction for this generator. Not all types need a length so this + * should only be overridden if it is needed. This will be called before apply is called. + */ + def withLengthGeneratorFunction(lengthGen: LengthGeneratorFunction): GeneratorFunction = { + this + } + + /** + * Actually generate the data. The data should be based off of the mapping and the row location + * passed in. The output data needs to correspond to something that Spark will be able + * to understand properly as the datatype this generator is for. Because we use an expression + * to call this, it needs to be one of the "internal" types that Spark expects. + */ + def apply(rowLoc: RowLocation): Any +} + +/** + * We want a way to be able to insert nulls as needed. We don't treat nulls + * the same as other values and they are generally off by default. This + * becomes a wrapper around another GeneratorFunction and decides + * if a null should be returned before calling into the other function or not. + * This will only be put into places where the schema allows it to be, and + * where the user has configured it to be. + */ +trait NullGeneratorFunction extends GeneratorFunction { + /** + * Create a new NullGeneratorFunction based on this one, but wrapping the passed in + * generator. This will be called before apply is called. + */ + def withWrapped(gen: GeneratorFunction): NullGeneratorFunction + + override def withValueRange(min: Any, max: Any): GeneratorFunction = + throw new IllegalArgumentException("Null Generators should not have a value range set for them") + + /** + * Create a new NullGeneratorFunction bases on this one, but with the given mapping. + * This is guaranteed to be called before apply is called. + */ + override def withLocationToSeedMapping(mapping: LocationToSeedMapping): NullGeneratorFunction +} + +/** + * Used to generate the length of a String, Array, Map, or any other variable length data. + * This is treated as a separate concern from the data generation because it is hard to + * not mess up the cardinality using a naive approach to generating a length. If you just + * use something like Random.nextInt(maxLen) then you can get skew because each length + * has a different maximum cardinality possible. 
This causes skew. For example if you want + * length between 0 and 1, then half of the rows generated will be length 0 and half + * will be length 1. So half the data will be a single value. + */ +trait LengthGeneratorFunction { + def withLocationToSeedMapping(mapping: LocationToSeedMapping): LengthGeneratorFunction + def apply(rowLoc: RowLocation): Int +} + +/** + * Just generate all of the data with a single length. This is the simplest way to avoid + * skew because of the different possible cardinality for the different lengths. + */ +case class FixedLengthGeneratorFunction(length: Int) extends LengthGeneratorFunction { + override def withLocationToSeedMapping(mapping: LocationToSeedMapping): LengthGeneratorFunction = + this + + override def apply(rowLoc: RowLocation): Int = length +} + +/** + * Generate nulls with a given probability. + * @param prob 0.0 to 1.0 for how often nulls should appear in the output. + */ +case class NullProbabilityGenerationFunction(prob: Double, + gen: GeneratorFunction = null, + mapping: LocationToSeedMapping = null) extends NullGeneratorFunction { + + override def withWrapped(gen: GeneratorFunction): NullProbabilityGenerationFunction = + NullProbabilityGenerationFunction(prob, gen, mapping) + + override def withLocationToSeedMapping( + mapping: LocationToSeedMapping): NullProbabilityGenerationFunction = + NullProbabilityGenerationFunction(prob, gen, mapping) + + override def apply(rowLoc: RowLocation): Any = { + val r = DataGen.getRandomFor(rowLoc, mapping) + if (r.nextDouble() <= prob) { + null + } else { + gen(rowLoc) + } + } +} + +/** + * A LocationToSeedMapping that generates seeds with an approximately equal chance for + * all values. + */ +case class FlatDistribution(colLocSeed: Long = 0L, + remapRangeFunc: Long => Long = n => n) extends LocationToSeedMapping { + + override def withColumnConf(colConf: ColumnConf): FlatDistribution = { + val colLocSeed = colConf.columnLoc.hashLoc + val remapRangeFunc = LocationToSeedMapping.remapRangeFunc(colConf) + FlatDistribution(colLocSeed, remapRangeFunc) + } + + override def apply(rowLoc: RowLocation): Long = + remapRangeFunc(rowLoc.hashLoc(colLocSeed)) +} + +/** + * A LocationToSeedMapping that generates a unique seed per row. The order in which the values are + * generated is some what of a random like order. This should *NOT* be applied to any + * columns that are under an array because it ues rowNum and assumes that this maps correctly + * to the number of rows being generated. 
+ */ +case class DistinctDistribution(numTableRows: Long = Long.MaxValue, + columnLocSeed: Long = 0L, + minSeed: Long = Long.MinValue, + maxSeed: Long = Long.MaxValue) extends LocationToSeedMapping { + + override def withColumnConf(colConf: ColumnConf): DistinctDistribution = { + val numTableRows = colConf.numTableRows + val maskLen = 64 - java.lang.Long.numberOfLeadingZeros(numTableRows) + val maxMask = (1 << maskLen) - 1 + val columnLocSeed = colConf.columnLoc.hashLoc & maxMask + val minSeed = colConf.minSeed + val maxSeed = colConf.maxSeed + DistinctDistribution(numTableRows, columnLocSeed, minSeed, maxSeed) + } + + override def apply(rowLoc: RowLocation): Long = { + assert(rowLoc.subRows == null, + "DistinctDistribution cannot be applied to columns under an Array") + val range = BigInt(maxSeed) - minSeed + val modded = BigInt(rowLoc.rowNum).mod(range) + val rowSeed = (modded + minSeed).toLong + val ret = rowSeed ^ columnLocSeed + if (ret > numTableRows) { + rowSeed + } else { + ret + } + } +} + +object DataGen { + val rLocal = new ThreadLocal[Random] { + override def initialValue(): Random = new XORShiftRandom() + } + + /** + * Get a Random instance that is based off of a given seed. This should not be kept in + * between calls to apply, and should not be used after calling a child methods apply. + */ + def getRandomFor(locSeed: Long): Random = { + val r = rLocal.get() + r.setSeed(locSeed) + r + } + + /** + * Get a Random instance that is based off of the given location and mapping. + */ + def getRandomFor(rowLoc: RowLocation, mapping: LocationToSeedMapping): Random = + getRandomFor(mapping(rowLoc)) + + /** + * Get a long value that is mapped from the given location and mapping. + */ + def nextLong(rowLoc: RowLocation, mapping: LocationToSeedMapping): Long = + getRandomFor(rowLoc, mapping).nextLong() + + /** + * Produce a remapping function that remaps a full range of long values to a new range evenly. + */ + def remapRangeFunc(typeMinVal: Long, typeMaxVal: Long): Long => Long = { + if (typeMinVal == Long.MinValue && typeMaxVal == Long.MaxValue) { + n => n + } else { + // We generate numbers between typeMinVal and typeMaxVal + 1, and then take the floor so that + // all numbers get an even chance. + val scaleFactor = (BigDecimal(typeMaxVal) + 1 - typeMinVal) / + (BigDecimal(Long.MaxValue) - Long.MinValue) + val minBig = BigDecimal(Long.MinValue) + n => { + (((n - minBig) * scaleFactor) + typeMinVal).setScale(0, RoundingMode.FLOOR).toLong + } + } + } +} + +/** + * An expression that actually does the data generation based on an input row number. + */ +case class DataGenExpr(child: Expression, + override val dataType: DataType, + canHaveNulls: Boolean, + f: GeneratorFunction) extends DataGenExprBase { + + override def nullable: Boolean = canHaveNulls + + override def inputTypes: Seq[AbstractDataType] = Seq(LongType) + + override def eval(input: InternalRow): Any = { + val rowLoc = new RowLocation(child.eval(input).asInstanceOf[Long]) + f(rowLoc) + } +} + +/** + * Base class for generating a column/sub-column. 
This holds configuration for the column, + * and handles what is needed to convert it into GeneratorFunction + */ +abstract class DataGen(var conf: ColumnConf, + defaultValueRange: Option[(Any, Any)], + var seedMapping: LocationToSeedMapping = FlatDistribution(), + var nullMapping: LocationToSeedMapping = FlatDistribution(), + var lengthGen: LengthGeneratorFunction = FixedLengthGeneratorFunction(10)) { + protected var userProvidedValueGen: Option[GeneratorFunction] = None + protected var userProvidedNullGen: Option[NullGeneratorFunction] = None + protected var valueRange: Option[(Any, Any)] = defaultValueRange + + /** + * Set a value range for this data gen. + */ + def setValueRange(min: Any, max: Any): DataGen = { + valueRange = Some((min, max)) + this + } + + /** + * Set a custom GeneratorFunction to use for this column. + */ + def setValueGen(f: GeneratorFunction): DataGen = { + userProvidedValueGen = Some(f) + this + } + + /** + * Set a NullGeneratorFunction for this column. This will not be used + * if the column is not nullable. + */ + def setNullGen(f: NullGeneratorFunction): DataGen = { + this.userProvidedNullGen = Some(f) + this + } + + /** + * Set the probability of a null appearing in the output. The probability should be + * 0.0 to 1.0. + */ + def setNullProbability(probability: Double): DataGen = { + this.userProvidedNullGen = Some(NullProbabilityGenerationFunction(probability)) + this + } + + /** + * Set a specific location to seed mapping for the value generation. + */ + def setSeedMapping(seedMapping: LocationToSeedMapping): DataGen = { + this.seedMapping = seedMapping + this + } + + /** + * Set a specific location to seed mapping for the null generation. + */ + def setNullMapping(nullMapping: LocationToSeedMapping): DataGen = { + this.nullMapping = nullMapping + this + } + + /** + * Set a specific LengthGeneratorFunction to use. This will only be used if + * the datatype needs a length. + */ + def setLengthGen(lengthGen: LengthGeneratorFunction): DataGen = { + this.lengthGen = lengthGen + this + } + + /** + * Set the length generation to be a fixed length. + */ + def setLength(len: Int): DataGen = { + this.lengthGen = FixedLengthGeneratorFunction(len) + this + } + + /** + * Add this column to a specific correlated key group. This should not be + * called directly by users. + */ + def setCorrelatedKeyGroup(keyGroup: Long, + minSeed: Long, maxSeed: Long, + seedMapping: LocationToSeedMapping): DataGen = { + conf = conf.forCorrelatedKeyGroup(keyGroup) + .forSeedRange(minSeed, maxSeed) + this.seedMapping = seedMapping + this + } + + /** + * Set a range of seed values that should be returned by the LocationToSeedMapping + */ + def setSeedRange(min: Long, max: Long): DataGen = { + conf = conf.forSeedRange(min, max) + this + } + + /** + * Get the default value generator for this specific data gen. + */ + protected def getValGen: GeneratorFunction + + /** + * Get the final ready to use GeneratorFunction for the data generator. 
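   * The pieces are combined in a fixed order: the seed mapping and length
   * generator are bound first, then any configured value range is applied, and
   * the value generator is wrapped in the null generator only when the column
   * is nullable and a null generator was provided.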
+ */ + def getGen: GeneratorFunction = { + val sm = seedMapping.withColumnConf(conf) + val lg = lengthGen.withLocationToSeedMapping(sm) + var valGen = userProvidedValueGen.getOrElse(getValGen) + .withLocationToSeedMapping(sm) + .withLengthGeneratorFunction(lg) + valueRange.foreach { + case (min, max) => + valGen = valGen.withValueRange(min, max) + } + if (nullable && userProvidedNullGen.isDefined) { + val nullColConf = conf.forNulls + val nm = nullMapping.withColumnConf(nullColConf) + userProvidedNullGen.get + .withWrapped(valGen) + .withLocationToSeedMapping(nm) + } else { + valGen + } + } + + /** + * Get the data type for this column + */ + def dataType: DataType + + /** + * Is this column nullable or not. + */ + def nullable: Boolean = conf.nullable + + /** + * Get a child column for a given name, if it has one. + */ + final def apply(name: String): DataGen = { + get(name).getOrElse{ + throw new IllegalStateException(s"Could not find a child $name for $this") + } + } + + def get(name: String): Option[DataGen] = None +} + +/** + * A Value generator for a single value + * @param v the value to return + */ +case class SingleValGenFunc(v: Any) extends GeneratorFunction { + override def withLocationToSeedMapping(mapping: LocationToSeedMapping): SingleValGenFunc = { + this + } + override def apply(rowLoc: RowLocation): Any = v + + override def withValueRange(min: Any, max: Any): GeneratorFunction = { + if (min != max && min != v) { + throw new IllegalArgumentException(s"The only value supported for this range is $v") + } + this + } +} + +/** + * Pick a value out of the list of possible values and return that. + */ +case class EnumValGenFunc(values: Array[_ <: Any], + mapping: LocationToSeedMapping = null) extends GeneratorFunction { + require(values != null && values.length > 0) + + override def withLocationToSeedMapping(mapping: LocationToSeedMapping): EnumValGenFunc = { + EnumValGenFunc(values, mapping) + } + + override def withValueRange(min: Any, max: Any): GeneratorFunction = + throw new IllegalArgumentException(s"value ranges are not supported for EnumValueGenFunc") + + override def apply(rowLoc: RowLocation): Any = { + val r = DataGen.getRandomFor(rowLoc, mapping) + values(r.nextInt(values.length)) + } +} + +/** + * A value generator for booleans. 
+ */ +case class BooleanGenFunc(mapping: LocationToSeedMapping = null) extends GeneratorFunction { + override def withLocationToSeedMapping(mapping: LocationToSeedMapping): BooleanGenFunc = + BooleanGenFunc(mapping) + + override def apply(rowLoc: RowLocation): Any = + DataGen.getRandomFor(rowLoc, mapping).nextBoolean() + + override def withValueRange(min: Any, max: Any): GeneratorFunction = { + val minb = BooleanGen.asBoolean(min) + val maxb = BooleanGen.asBoolean(max) + if (minb == maxb) { + SingleValGenFunc(minb) + } else { + this + } + } +} + +object BooleanGen { + def asBoolean(a: Any): Boolean = a match { + case b: Boolean => b + case other => + throw new IllegalArgumentException(s"a boolean value range only supports " + + s"boolean values found $other") + } +} + +class BooleanGen(conf: ColumnConf, + defaultValueRange: Option[(Any, Any)]) + extends DataGen(conf, defaultValueRange) { + + override def dataType: DataType = BooleanType + + override protected def getValGen: GeneratorFunction = BooleanGenFunc() +} + +/** + * A value generator for Bytes + */ +case class ByteGenFunc(mapping: LocationToSeedMapping = null, + min: Byte = Byte.MinValue, + max: Byte = Byte.MaxValue) extends GeneratorFunction { + + private lazy val valueRemapping: Long => Byte = ByteGen.remapRangeFunc(min, max) + + override def apply(rowLoc: RowLocation): Any = valueRemapping(DataGen.nextLong(rowLoc, mapping)) + + override def withLocationToSeedMapping(mapping: LocationToSeedMapping): ByteGenFunc = + ByteGenFunc(mapping, min, max) + + override def withValueRange(min: Any, max: Any): GeneratorFunction = { + var bmin = ByteGen.asByte(min) + var bmax = ByteGen.asByte(max) + if (bmin > bmax) { + val tmp = bmin + bmin = bmax + bmax = tmp + } + + if (bmin == bmax) { + SingleValGenFunc(bmin) + } else { + ByteGenFunc(mapping, bmin, bmax) + } + } +} + +object ByteGen { + def asByte(a: Any): Byte = a match { + case n: Byte => n + case n: Short if n == n.toByte => n.toByte + case n: Int if n == n.toByte => n.toByte + case n: Long if n == n.toByte => n.toByte + case other => + throw new IllegalArgumentException(s"a byte value range only supports " + + s"byte values found $other") + } + + def remapRangeFunc(min: Byte, max: Byte): Long => Byte = { + val wrapped = DataGen.remapRangeFunc(min, max) + n => wrapped(n).toByte + } +} + +class ByteGen(conf: ColumnConf, + defaultValueRange: Option[(Any, Any)]) + extends DataGen(conf, defaultValueRange) { + override def getValGen: GeneratorFunction = ByteGenFunc() + override def dataType: DataType = ByteType +} + +/** + * A value generator for Shorts + */ +case class ShortGenFunc(mapping: LocationToSeedMapping = null, + min: Short = Short.MinValue, + max: Short = Short.MaxValue) extends GeneratorFunction { + + private lazy val valueRemapping: Long => Short = + ShortGen.remapRangeFunc(min, max) + + override def apply(rowLoc: RowLocation): Any = valueRemapping(DataGen.nextLong(rowLoc, mapping)) + + override def withLocationToSeedMapping(mapping: LocationToSeedMapping): ShortGenFunc = + ShortGenFunc(mapping, min, max) + + override def withValueRange(min: Any, max: Any): GeneratorFunction = { + var smin = ShortGen.asShort(min) + var smax = ShortGen.asShort(max) + if (smin > smax) { + val tmp = smin + smin = smax + smax = tmp + } + + if (smin == smax) { + SingleValGenFunc(smin) + } else { + ShortGenFunc(mapping, smin, smax) + } + } +} + +object ShortGen { + def asShort(a: Any): Short = a match { + case n: Byte => n.toShort + case n: Short => n + case n: Int if n == n.toShort => n.toShort + 
case n: Long if n == n.toShort => n.toShort + case other => + throw new IllegalArgumentException(s"a short value range only supports " + + s"short values found $other") + } + + def remapRangeFunc(min: Short, max: Short): Long => Short = { + val wrapped = DataGen.remapRangeFunc(min, max) + n => wrapped(n).toShort + } +} + +class ShortGen(conf: ColumnConf, + defaultValueRange: Option[(Any, Any)]) + extends DataGen(conf, defaultValueRange) { + override def getValGen: GeneratorFunction = ShortGenFunc() + + override def dataType: DataType = ShortType +} + +/** + * A value generator for Ints + */ +case class IntGenFunc(mapping: LocationToSeedMapping = null, + min: Int = Int.MinValue, + max: Int = Int.MaxValue) extends GeneratorFunction { + + private lazy val valueRemapping: Long => Int = IntGen.remapRangeFunc(min, max) + + override def apply(rowLoc: RowLocation): Any = valueRemapping(DataGen.nextLong(rowLoc, mapping)) + + override def withLocationToSeedMapping(mapping: LocationToSeedMapping): IntGenFunc = + IntGenFunc(mapping, min, max) + + override def withValueRange(min: Any, max: Any): GeneratorFunction = { + var imin = IntGen.asInt(min) + var imax = IntGen.asInt(max) + if (imin > imax) { + val tmp = imin + imin = imax + imax = tmp + } + + if (imin == imax) { + SingleValGenFunc(imin) + } else { + IntGenFunc(mapping, imin, imax) + } + } +} + +object IntGen { + def asInt(a: Any): Int = a match { + case n: Byte => n.toInt + case n: Short => n.toInt + case n: Int => n + case n: Long if n == n.toInt => n.toInt + case other => + throw new IllegalArgumentException(s"a int value range only supports " + + s"int values found $other") + } + + def remapRangeFunc(min: Int, max: Int): Long => Int = { + val wrapped = DataGen.remapRangeFunc(min, max) + n => wrapped(n).toInt + } +} + +class IntGen(conf: ColumnConf, + defaultValueRange: Option[(Any, Any)]) + extends DataGen(conf, defaultValueRange) { + override def getValGen: GeneratorFunction = IntGenFunc() + + override def dataType: DataType = IntegerType +} + +/** + * A value generator for Longs + */ +case class LongGenFunc(mapping: LocationToSeedMapping = null, + min: Long = Long.MinValue, + max: Long = Long.MaxValue) extends GeneratorFunction { + + protected lazy val valueRemapping: Long => Long = LongGen.remapRangeFunc(min, max) + + override def apply(rowLoc: RowLocation): Any = valueRemapping(DataGen.nextLong(rowLoc, mapping)) + + override def withLocationToSeedMapping(mapping: LocationToSeedMapping): LongGenFunc = + LongGenFunc(mapping, min, max) + + override def withValueRange(min: Any, max: Any): GeneratorFunction = { + var lmin = LongGen.asLong(min) + var lmax = LongGen.asLong(max) + if (lmin > lmax) { + val tmp = lmin + lmin = lmax + lmax = tmp + } + + if (lmin == lmax) { + SingleValGenFunc(lmin) + } else { + LongGenFunc(mapping, lmin, lmax) + } + } +} + +object LongGen { + def asLong(a: Any): Long = a match { + case n: Byte => n.toLong + case n: Short => n.toLong + case n: Int => n.toLong + case n: Long => n + case other => + throw new IllegalArgumentException(s"a long value range only supports " + + s"long values found $other") + } + + def remapRangeFunc(min: Long, max: Long): Long => Long = + DataGen.remapRangeFunc(min, max) +} + +class LongGen(conf: ColumnConf, + defaultValueRange: Option[(Any, Any)]) + extends DataGen(conf, defaultValueRange) { + override def getValGen: GeneratorFunction = LongGenFunc() + + override def dataType: DataType = LongType +} + +case class Decimal32GenFunc( + precision: Int, + scale: Int, + unscaledMin: Int, + 
unscaledMax: Int, + mapping: LocationToSeedMapping = null) extends GeneratorFunction { + + private lazy val valueRemapping: Long => Int = IntGen.remapRangeFunc(unscaledMin, unscaledMax) + + override def apply(rowLoc: RowLocation): Any = + Decimal(valueRemapping(DataGen.nextLong(rowLoc, mapping)), precision, scale) + + override def withLocationToSeedMapping(mapping: LocationToSeedMapping): GeneratorFunction = + Decimal32GenFunc(precision, scale, unscaledMin, unscaledMax, mapping) + + override def withValueRange(min: Any, max: Any): GeneratorFunction = { + var imin = DecimalGen.asUnscaledInt(min, precision, scale) + var imax = DecimalGen.asUnscaledInt(max, precision, scale) + if (imin > imax) { + val tmp = imin + imin = imax + imax = tmp + } + + if (imin == imax) { + SingleValGenFunc(Decimal(imin, precision, scale)) + } else { + Decimal32GenFunc(precision, scale, imin, imax, mapping) + } + } +} + +case class Decimal64GenFunc( + precision: Int, + scale: Int, + unscaledMin: Long, + unscaledMax: Long, + mapping: LocationToSeedMapping = null) extends GeneratorFunction { + + private lazy val valueRemapping: Long => Long = LongGen.remapRangeFunc(unscaledMin, unscaledMax) + + override def apply(rowLoc: RowLocation): Any = + Decimal(valueRemapping(DataGen.nextLong(rowLoc, mapping)), precision, scale) + + override def withLocationToSeedMapping(mapping: LocationToSeedMapping): GeneratorFunction = + Decimal64GenFunc(precision, scale, unscaledMin, unscaledMax, mapping) + + override def withValueRange(min: Any, max: Any): GeneratorFunction = { + var lmin = DecimalGen.asUnscaledLong(min, precision, scale) + var lmax = DecimalGen.asUnscaledLong(max, precision, scale) + if (lmin > lmax) { + val tmp = lmin + lmin = lmax + lmax = tmp + } + + if (lmin == lmax) { + SingleValGenFunc(Decimal(lmin, precision, scale)) + } else { + Decimal64GenFunc(precision, scale, lmin, lmax, mapping) + } + } +} + +case class DecimalGenFunc( + precision: Int, + scale: Int, + unscaledMin: BigInt, + unscaledMax: BigInt, + mapping: LocationToSeedMapping = null) extends GeneratorFunction { + + private lazy val valueRemapping: Long => BigInt = + DecimalGen.remapRangeFunc(unscaledMin, unscaledMax) + + override def apply(rowLoc: RowLocation): Any = { + val bi = valueRemapping(DataGen.nextLong(rowLoc, mapping)).bigInteger + Decimal(new JavaBigDecimal(bi, scale)) + } + + override def withLocationToSeedMapping(mapping: LocationToSeedMapping): GeneratorFunction = + DecimalGenFunc(precision, scale, unscaledMin, unscaledMax, mapping) + + override def withValueRange(min: Any, max: Any): GeneratorFunction = { + var lmin = DecimalGen.asUnscaled(min, precision, scale) + var lmax = DecimalGen.asUnscaled(max, precision, scale) + if (lmin > lmax) { + val tmp = lmin + lmin = lmax + lmax = tmp + } + + if (lmin == lmax) { + SingleValGenFunc(Decimal(new JavaBigDecimal(lmin.bigInteger, scale))) + } else { + DecimalGenFunc(precision, scale, lmin, lmax, mapping) + } + } +} + +object DecimalGen { + def toUnscaledInt(d: JavaBigDecimal, precision: Int, scale: Int): Int = { + val tmp = d.setScale(scale) + require(tmp.precision() <= precision, "The value is not in the supported precision range") + tmp.unscaledValue().intValueExact() + } + + def asUnscaledInt(a: Any, precision: Int, scale: Int): Int = a match { + case d: BigDecimal => + val tmp = d.setScale(scale) + toUnscaledInt(tmp.bigDecimal, precision, scale) + case d: JavaBigDecimal => + toUnscaledInt(d, precision, scale) + case d: Decimal => + toUnscaledInt(d.toJavaBigDecimal, precision, scale) + case n: 
Byte => + toUnscaledInt(new JavaBigDecimal(n), precision, scale) + case n: Short => + toUnscaledInt(new JavaBigDecimal(n), precision, scale) + case n: Int => + toUnscaledInt(new JavaBigDecimal(n), precision, scale) + case n: Long => + toUnscaledInt(new JavaBigDecimal(n), precision, scale) + case n: Float => + toUnscaledInt(new JavaBigDecimal(n), precision, scale) + case n: Double => + toUnscaledInt(new JavaBigDecimal(n), precision, scale) + case other => + throw new IllegalArgumentException(s"a decimal 32 value range only supports " + + s"decimal values found $other") + } + + def toUnscaledLong(d: JavaBigDecimal, precision: Int, scale: Int): Long = { + val tmp = d.setScale(scale) + require(tmp.precision() <= precision, "The value is not in the supported precision range") + tmp.unscaledValue().longValueExact() + } + + def asUnscaledLong(a: Any, precision: Int, scale: Int): Long = a match { + case d: BigDecimal => + val tmp = d.setScale(scale) + toUnscaledLong(tmp.bigDecimal, precision, scale) + case d: JavaBigDecimal => + toUnscaledLong(d, precision, scale) + case d: Decimal => + toUnscaledLong(d.toJavaBigDecimal, precision, scale) + case n: Byte => + toUnscaledLong(new JavaBigDecimal(n), precision, scale) + case n: Short => + toUnscaledLong(new JavaBigDecimal(n), precision, scale) + case n: Int => + toUnscaledLong(new JavaBigDecimal(n), precision, scale) + case n: Long => + toUnscaledLong(new JavaBigDecimal(n), precision, scale) + case n: Float => + toUnscaledLong(new JavaBigDecimal(n), precision, scale) + case n: Double => + toUnscaledLong(new JavaBigDecimal(n), precision, scale) + case other => + throw new IllegalArgumentException(s"a decimal 64 value range only supports " + + s"decimal values found $other") + } + + def toUnscaled(d: JavaBigDecimal, precision: Int, scale: Int): BigInt = { + val tmp = d.setScale(scale) + require(tmp.precision() <= precision, "The value is not in the supported precision range") + tmp.unscaledValue() + } + + def asUnscaled(a: Any, precision: Int, scale: Int): BigInt = a match { + case d: BigDecimal => + val tmp = d.setScale(scale) + toUnscaled(tmp.bigDecimal, precision, scale) + case d: JavaBigDecimal => + toUnscaled(d, precision, scale) + case d: Decimal => + toUnscaled(d.toJavaBigDecimal, precision, scale) + case n: Byte => + toUnscaled(new JavaBigDecimal(n), precision, scale) + case n: Short => + toUnscaled(new JavaBigDecimal(n), precision, scale) + case n: Int => + toUnscaled(new JavaBigDecimal(n), precision, scale) + case n: Long => + toUnscaled(new JavaBigDecimal(n), precision, scale) + case n: Float => + toUnscaled(new JavaBigDecimal(n), precision, scale) + case n: Double => + toUnscaled(new JavaBigDecimal(n), precision, scale) + case other => + throw new IllegalArgumentException(s"a decimal value range only supports " + + s"decimal values found $other") + } + + def genMaxUnscaledInt(precision: Int): Int = { + require(precision >= 0 && precision <= Decimal.MAX_INT_DIGITS) + genMaxUnscaled(precision).toInt + } + + def genMaxUnscaledLong(precision: Int): Long = { + require(precision > Decimal.MAX_INT_DIGITS && precision <= Decimal.MAX_LONG_DIGITS) + genMaxUnscaled(precision).toLong + } + + def genMaxUnscaled(precision: Int): BigInt = + BigInt(10).pow(precision) - 1 + + def remapRangeFunc(minVal: BigInt, maxVal: BigInt): Long => BigInt = { + // We generate numbers between minVal and maxVal + 1, and then take the floor so that + // all numbers get an even chance. 
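    // Each value in [minVal, maxVal] ends up with an equal-width slice of the
    // full Long seed range, so the remapping keeps the distribution uniform.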
+ val minValBD = BigDecimal(minVal) + val scaleFactor = (BigDecimal(maxVal) + 1 - minValBD) / + (BigDecimal(Long.MaxValue) - Long.MinValue) + val minLong = BigDecimal(Long.MinValue) + n => { + (((n - minLong) * scaleFactor) + minValBD).setScale(0, RoundingMode.FLOOR).toBigInt() + } + } +} + +class DecimalGen(dt: DecimalType, + conf: ColumnConf, + defaultValueRange: Option[(Any, Any)]) + extends DataGen(conf, defaultValueRange) { + + override def dataType: DataType = dt + + override protected def getValGen: GeneratorFunction = + if (dt.precision <= Decimal.MAX_INT_DIGITS) { + val max = DecimalGen.genMaxUnscaledInt(dt.precision) + Decimal32GenFunc(dt.precision, dt.scale, max, -max) + } else if (dt.precision <= Decimal.MAX_LONG_DIGITS) { + val max = DecimalGen.genMaxUnscaledLong(dt.precision) + Decimal64GenFunc(dt.precision, dt.scale, max, -max) + } else { + val max = DecimalGen.genMaxUnscaled(dt.precision) + DecimalGenFunc(dt.precision, dt.scale, -max, max) + } +} + +/** + * A value generator for Timestamps + */ +case class TimestampGenFunc(mapping: LocationToSeedMapping = null, + min: Long = Long.MinValue, + max: Long = Long.MaxValue) extends GeneratorFunction { + + private lazy val valueRemapping: Long => Long = LongGen.remapRangeFunc(min, max) + + override def apply(rowLoc: RowLocation): Any = valueRemapping(DataGen.nextLong(rowLoc, mapping)) + + override def withLocationToSeedMapping(mapping: LocationToSeedMapping): TimestampGenFunc = + TimestampGenFunc(mapping, min, max) + + override def withValueRange(min: Any, max: Any): GeneratorFunction = { + var lmin = TimestampGen.asLong(min) + var lmax = TimestampGen.asLong(max) + if (lmin > lmax) { + val tmp = lmin + lmin = lmax + lmax = tmp + } + + if (lmin == lmax) { + SingleValGenFunc(lmin) + } else { + TimestampGenFunc(mapping, lmin, lmax) + } + } +} + +object TimestampGen { + def asLong(a: Any): Long = a match { + case n: Byte => n.toLong + case n: Short => n.toLong + case n: Int => n.toLong + case n: Long => n + case i: Instant => DateTimeUtils.instantToMicros(i) + case t: Timestamp => DateTimeUtils.fromJavaTimestamp(t) + case other => + throw new IllegalArgumentException(s"a timestamp value range only supports " + + s"timestamp or integral values found $other") + } +} + +class TimestampGen(conf: ColumnConf, + defaultValueRange: Option[(Any, Any)]) + extends DataGen(conf, defaultValueRange) { + override protected def getValGen: GeneratorFunction = TimestampGenFunc() + + override def dataType: DataType = TimestampType +} + +/** + * A value generator for Dates + */ +case class DateGenFunc(mapping: LocationToSeedMapping = null, + min: Int = Int.MinValue, + max: Int = Int.MaxValue) extends GeneratorFunction { + + private lazy val valueRemapping: Long => Int = IntGen.remapRangeFunc(min, max) + + override def apply(rowLoc: RowLocation): Any = valueRemapping(DataGen.nextLong(rowLoc, mapping)) + + override def withLocationToSeedMapping(mapping: LocationToSeedMapping): DateGenFunc = + DateGenFunc(mapping, min, max) + + override def withValueRange(min: Any, max: Any): GeneratorFunction = { + var imin = DateGen.asInt(min) + var imax = DateGen.asInt(max) + if (imin > imax) { + val tmp = imin + imin = imax + imax = tmp + } + + if (imin == imax) { + SingleValGenFunc(imin) + } else { + DateGenFunc(mapping, imin, imax) + } + } +} + +object DateGen { + def asInt(a: Any): Int = a match { + case n: Byte => n.toInt + case n: Short => n.toInt + case n: Int => n + case n: Long if n <= Int.MaxValue && n >= Int.MinValue => n.toInt + case ld: LocalDate => 
ld.toEpochDay.toInt + case d: Date => DateTimeUtils.fromJavaDate(d) + case other => + throw new IllegalArgumentException(s"a date value range only supports " + + s"date or int values found $other") + } +} + +class DateGen(conf: ColumnConf, + defaultValueRange: Option[(Any, Any)]) + extends DataGen(conf, defaultValueRange) { + override protected def getValGen: GeneratorFunction = DateGenFunc() + + override def dataType: DataType = DateType +} + +/** + * A value generator for Doubles + */ +case class DoubleGenFunc(mapping: LocationToSeedMapping = null) extends GeneratorFunction { + override def apply(rowLoc: RowLocation): Any = + java.lang.Double.longBitsToDouble(DataGen.nextLong(rowLoc, mapping)) + + override def withLocationToSeedMapping(mapping: LocationToSeedMapping): DoubleGenFunc = + DoubleGenFunc(mapping) + + override def withValueRange(min: Any, max: Any): GeneratorFunction = + throw new IllegalStateException("value ranges are not supported for Double yet") +} + +class DoubleGen(conf: ColumnConf, defaultValueRange: Option[(Any, Any)]) + extends DataGen(conf, defaultValueRange) { + + override def dataType: DataType = DoubleType + + override protected def getValGen: GeneratorFunction = DoubleGenFunc() +} + +/** + * A value generator for Floats + */ +case class FloatGenFunc(mapping: LocationToSeedMapping = null) extends GeneratorFunction { + override def apply(rowLoc: RowLocation): Any = + java.lang.Float.intBitsToFloat(DataGen.getRandomFor(mapping(rowLoc)).nextInt()) + + override def withLocationToSeedMapping(mapping: LocationToSeedMapping): FloatGenFunc = + FloatGenFunc(mapping) + + override def withValueRange(min: Any, max: Any): GeneratorFunction = + throw new IllegalStateException("value ranges are not supported for Float yet") +} + +class FloatGen(conf: ColumnConf, defaultValueRange: Option[(Any, Any)]) + extends DataGen(conf, defaultValueRange) { + + override def dataType: DataType = FloatType + + override protected def getValGen: GeneratorFunction = FloatGenFunc() +} + +case class ASCIIGenFunc( + lengthGen: LengthGeneratorFunction = null, + mapping: LocationToSeedMapping = null) extends GeneratorFunction { + + override def apply(rowLoc: RowLocation): Any = { + val len = lengthGen(rowLoc) + val r = DataGen.getRandomFor(rowLoc, mapping) + val buffer = new Array[Byte](len) + var at = 0 + while (at < len) { + // Value range is 32 (Space) to 126 (~) + buffer(at) = (r.nextInt(126 - 31) + 32).toByte + at += 1 + } + UTF8String.fromBytes(buffer, 0, len) + } + + override def withLengthGeneratorFunction(lengthGen: LengthGeneratorFunction): GeneratorFunction = + ASCIIGenFunc(lengthGen, mapping) + + override def withLocationToSeedMapping(mapping: LocationToSeedMapping): GeneratorFunction = + ASCIIGenFunc(lengthGen, mapping) + + override def withValueRange(min: Any, max: Any): GeneratorFunction = + throw new IllegalArgumentException("value ranges are not supported for strings") +} + +class StringGen(conf: ColumnConf, defaultValueRange: Option[(Any, Any)]) + extends DataGen(conf, defaultValueRange) { + + override def dataType: DataType = StringType + + override protected def getValGen: GeneratorFunction = ASCIIGenFunc() +} + +case class StructGenFunc(childGens: Array[GeneratorFunction]) extends GeneratorFunction { + override def apply(rowLoc: RowLocation): Any = { + // The row location does not change for a struct + val data = childGens.map(_.apply(rowLoc)) + InternalRow(data: _*) + } + + override def withLocationToSeedMapping(mapping: LocationToSeedMapping): GeneratorFunction = + this + + 
override def withValueRange(min: Any, max: Any): GeneratorFunction = + throw new IllegalArgumentException("value ranges are not supported by structs") +} + +class StructGen(val children: Seq[(String, DataGen)], + conf: ColumnConf, + defaultValueRange: Option[(Any, Any)]) + extends DataGen(conf, defaultValueRange) { + + private lazy val dt = { + val childrenFields = children.map { + case (name, gen) => + StructField(name, gen.dataType) + } + StructType(childrenFields) + } + override def dataType: DataType = dt + + override def setCorrelatedKeyGroup(keyGroup: Long, + minSeed: Long, maxSeed: Long, + seedMapping: LocationToSeedMapping): DataGen = { + super.setCorrelatedKeyGroup(keyGroup, minSeed, maxSeed, seedMapping) + children.foreach { + case (_, gen) => + gen.setCorrelatedKeyGroup(keyGroup, minSeed, maxSeed, seedMapping) + } + this + } + + override def get(name: String): Option[DataGen] = + children.collectFirst { + case (childName, dataGen) if childName.equalsIgnoreCase(name) => dataGen + } + + override protected def getValGen: GeneratorFunction = { + val childGens = children.map(c => c._2.getGen).toArray + StructGenFunc(childGens) + } +} + +case class ArrayGenFunc( + child: GeneratorFunction, + lengthGen: LengthGeneratorFunction = null, + mapping: LocationToSeedMapping = null) extends GeneratorFunction { + + override def apply(rowLoc: RowLocation): Any = { + val len = lengthGen(rowLoc) + val data = new Array[Any](len) + val childRowLoc = rowLoc.withNewChild() + var i = 0 + while (i < len) { + childRowLoc.setLastChildIndex(i) + data(i) = child(childRowLoc) + i += 1 + } + ArrayData.toArrayData(data) + } + + override def withLengthGeneratorFunction(lengthGen: LengthGeneratorFunction): GeneratorFunction = + ArrayGenFunc(child, lengthGen, mapping) + + override def withLocationToSeedMapping(mapping: LocationToSeedMapping): GeneratorFunction = + ArrayGenFunc(child, lengthGen, mapping) + + override def withValueRange(min: Any, max: Any): GeneratorFunction = + throw new IllegalArgumentException("value ranges are not supported for arrays") +} + +class ArrayGen(child: DataGen, + conf: ColumnConf, + defaultValueRange: Option[(Any, Any)]) + extends DataGen(conf, defaultValueRange) { + + override def setCorrelatedKeyGroup(keyGroup: Long, + minSeed: Long, maxSeed: Long, + seedMapping: LocationToSeedMapping): DataGen = { + super.setCorrelatedKeyGroup(keyGroup, minSeed, maxSeed, seedMapping) + child.setCorrelatedKeyGroup(keyGroup, minSeed, maxSeed, seedMapping) + this + } + + override protected def getValGen: GeneratorFunction = ArrayGenFunc(child.getGen) + + override def dataType: DataType = ArrayType(child.dataType, containsNull = child.nullable) + + override def get(name: String): Option[DataGen] = { + if ("data".equalsIgnoreCase(name) || "child".equalsIgnoreCase(name)) { + Some(child) + } else { + None + } + } +} + +object ColumnGen { + private def genInternal(rowNumber: Column, + dataType: DataType, + nullable: Boolean, + gen: GeneratorFunction): Column = { + Column(DataGenExpr(rowNumber.expr, dataType, nullable, gen)) + } +} + +/** + * Generates a top level column in a data set. 
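 * A ColumnGen is normally looked up by name from a `TableGen`, configured through
 * the setters below (seed range, null probability, value range, length, ...), and
 * finally turned into a Spark `Column` with `gen(rowNumber)`.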
+ */ +class ColumnGen(val dataGen: DataGen) { + def setCorrelatedKeyGroup(kg: Long, + minSeed: Long, maxSeed: Long, + seedMapping: LocationToSeedMapping): ColumnGen = { + dataGen.setCorrelatedKeyGroup(kg, minSeed, maxSeed, seedMapping) + this + } + + def setSeedRange(min: Long, max: Long): ColumnGen = { + dataGen.setSeedRange(min, max) + this + } + + def setSeedMapping(seedMapping: LocationToSeedMapping): ColumnGen = { + dataGen.setSeedMapping(seedMapping) + this + } + + def setNullSeedMapping(seedMapping: LocationToSeedMapping): ColumnGen = { + dataGen.setNullMapping(seedMapping) + this + } + + def setValueRange(min: Any, max: Any): ColumnGen = { + dataGen.setValueRange(min, max) + this + } + + def setNullProbability(probability: Double): ColumnGen = { + dataGen.setNullProbability(probability) + this + } + + def setNullGen(f: NullGeneratorFunction): ColumnGen = { + dataGen.setNullGen(f) + this + } + + def setValueGen(f: GeneratorFunction): ColumnGen = { + dataGen.setValueGen(f) + this + } + + def setLengthGen(lengthGen: LengthGeneratorFunction): ColumnGen = { + dataGen.setLengthGen(lengthGen) + this + } + + def setLength(len: Int): ColumnGen = { + dataGen.setLength(len) + this + } + + final def apply(name: String): DataGen = { + get(name).getOrElse { + throw new IllegalArgumentException(s"$name not a child of $this") + } + } + + def get(name: String): Option[DataGen] = dataGen.get(name) + + def gen(rowNumber: Column): Column = { + ColumnGen.genInternal(rowNumber, dataGen.dataType, dataGen.nullable, dataGen.getGen) + } +} + +sealed trait KeyGroupType + +/** + * A key group where all of the columns and sub-columns use the key group id instead of + * the column number when calculating the seed for the desired value. This makes all of + * the generator functions compute the exact same seed and produce correlated results. + * For this to work the columns that will be joined need to be configured with the same + * generators and the same value ranges (if any), but the distribution of the seeds and + * the seed range do not have to correspond. + * @param id the id of the key group that replaces the column number. + * @param minSeed the min seed value for this key group + * @param maxSeed the max seed value for this key group + */ +case class CorrelatedKeyGroup(id: Long, minSeed: Long, maxSeed: Long) extends KeyGroupType + +/** + * A key group where the seed range is calculated based on the number of columns to + * get approximately the desired number of unique combinations. This will not be exact + * and does not work with nested types, or if the cardinality of the columns involved + * is not large enough to support the desired key range. Much of this can be fixed in + * the future. + * + * If you want to get partially overlapping groups you can increase the number of + * combinations, but again this is approximate, so it might not do exactly what you want. + * @param startSeed the seed that all combinatorial groups should start with + * @param combinations the desired number of unique combinations. + */ +case class CombinatorialKeyGroup(startSeed: Long, combinations: Long) extends KeyGroupType + +/** + * Used to generate a table in a database. + */ +class TableGen(val columns: Seq[(String, ColumnGen)], numRows: Long) { + /** + * A key group allows you to set up a multi-column key for things like a join where you + * want the keys to be generated as a group. + * @param names the names of the columns in the group. A column can only be a part of a + * single group.
+ * @param mapping the distribution/mapping for this key group + * @param groupType the type of key group that this is for. + * @return this for chaining + */ + def configureKeyGroup( + names: Seq[String], + groupType: KeyGroupType, + mapping: LocationToSeedMapping): TableGen = { + groupType match { + case CorrelatedKeyGroup(id, minSeed, maxSeed) => + names.foreach { name => + val col = apply(name) + col.setCorrelatedKeyGroup(id, minSeed, maxSeed, mapping) + } + case CombinatorialKeyGroup(startSeed, combinations) => + var numberOfCombinationsRemaining = combinations + var numberOfColumnsRemaining = names.length + names.foreach { name => + val choices = math.pow(numberOfCombinationsRemaining.toDouble, + 1.0 / numberOfColumnsRemaining).toLong + if (choices <= 0) { + throw new IllegalArgumentException("Could not find a way to split up " + + s"${names.length} columns to produce $combinations unique combinations of values") + } + val column = apply(name) + column.setSeedRange(startSeed, startSeed + choices - 1) + numberOfColumnsRemaining -= 1 + numberOfCombinationsRemaining = numberOfCombinationsRemaining / choices + } + } + this + } + + /** + * Convert this table into a `DataFrame` that can be + * written out or used directly. Writing it out to Parquet + * or ORC first is usually better, because the data is generated + * inline as the query runs and the generation does not run on + * the GPU, so performance can suffer. + * @param spark the session to use. + * @param numParts the number of parts to use (if > 0) + */ + def toDF(spark: SparkSession, numParts: Int = 0): DataFrame = { + val id = col("id") + val allGens = columns.map { + case (name, childGen) => + childGen.gen(id).alias(name) + } + + val range = if (numParts > 0) { + spark.range(0, numRows, 1, numParts) + } else { + spark.range(numRows) + } + range.select(allGens: _*) + } + + /** + * Get a ColumnGen for a named column. + * @param name the name of the column to look for + * @return the corresponding column gen + */ + def apply(name: String): ColumnGen = { + get(name).getOrElse { + throw new IllegalArgumentException(s"$name not found: ${columns.map(_._1).mkString(" ")}") + } + } + + def get(name: String): Option[ColumnGen] = + columns.collectFirst { + case (childName, colGen) if childName.equalsIgnoreCase(name) => colGen + } +} + +/** + * Provides a way to map a DataType to a specific DataGen instance. + */ +trait TypeMapping { + /** + * Does this TypeMapping support the given data type?
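+ * @param dt the data type to check + * @param subTypeMapping the mapping to use when checking any nested child types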
+ * @return true if it does, else false + */ + def canMap(dt: DataType, subTypeMapping: TypeMapping): Boolean + + /** + * If canMap returned true, then do the mapping + * @return the DataGen along with the last ColumnConf that was used by a column + */ + def map(dt: DataType, + conf: ColumnConf, + defaultRanges: mutable.HashMap[DataType, (Any, Any)], + subTypeMapping: TypeMapping): (DataGen, ColumnConf) +} + +object DefaultTypeMapping extends TypeMapping { + override def canMap(dataType: DataType, subTypeMapping: TypeMapping): Boolean = dataType match { + case BooleanType => true + case ByteType => true + case ShortType => true + case IntegerType => true + case LongType => true + case _: DecimalType => true + case FloatType => true + case DoubleType => true + case StringType => true + case TimestampType => true + case DateType => true + case st: StructType => + st.forall(child => subTypeMapping.canMap(child.dataType, subTypeMapping)) + case at: ArrayType => + subTypeMapping.canMap(at.elementType, subTypeMapping) + case _ => false + } + + override def map(dataType: DataType, + conf: ColumnConf, + defaultRanges: mutable.HashMap[DataType, (Any, Any)], + subTypeMapping: TypeMapping): (DataGen, ColumnConf) = dataType match { + case BooleanType => + (new BooleanGen(conf, defaultRanges.get(dataType)), conf) + case ByteType => + (new ByteGen(conf, defaultRanges.get(dataType)), conf) + case ShortType => + (new ShortGen(conf, defaultRanges.get(dataType)), conf) + case IntegerType => + (new IntGen(conf, defaultRanges.get(dataType)), conf) + case LongType => + (new LongGen(conf, defaultRanges.get(dataType)), conf) + case dt: DecimalType => + (new DecimalGen(dt, conf, defaultRanges.get(dataType)), conf) + case FloatType => + (new FloatGen(conf, defaultRanges.get(dataType)), conf) + case DoubleType => + (new DoubleGen(conf, defaultRanges.get(dataType)), conf) + case StringType => + (new StringGen(conf, defaultRanges.get(dataType)), conf) + case TimestampType => + (new TimestampGen(conf, defaultRanges.get(dataType)), conf) + case DateType => + (new DateGen(conf, defaultRanges.get(dataType)), conf) + case st: StructType => + var tmpConf = conf + val fields = st.map { sf => + tmpConf = tmpConf.forNextColumn(sf.nullable) + val genNCol = subTypeMapping.map(sf.dataType, tmpConf, defaultRanges, subTypeMapping) + tmpConf = genNCol._2 + (sf.name, genNCol._1) + } + (new StructGen(fields, conf, defaultRanges.get(dataType)), tmpConf) + case at: ArrayType => + val childConf = conf.forNextColumn(at.containsNull) + val child = subTypeMapping.map(at.elementType, childConf, defaultRanges, subTypeMapping) + (new ArrayGen(child._1, conf, defaultRanges.get(dataType)), child._2) + case other => + throw new IllegalArgumentException(s"$other is not a supported type yet") + } +} + +case class OrderedTypeMapping(ordered: Array[TypeMapping]) extends TypeMapping { + override def canMap(dt: DataType, subTypeMapping: TypeMapping): Boolean = { + ordered.foreach { mapping => + if (mapping.canMap(dt, subTypeMapping)) { + return true + } + } + false + } + + override def map(dt: DataType, + conf: ColumnConf, + defaultRanges: mutable.HashMap[DataType, (Any, Any)], + subTypeMapping: TypeMapping): (DataGen, ColumnConf) = { + ordered.foreach { mapping => + if (mapping.canMap(dt, subTypeMapping)) { + return mapping.map(dt, conf, defaultRanges, subTypeMapping) + } + } + // This should not be reachable + throw new IllegalStateException(s"$dt is not currently supported") + } +} + +object DBGen { + def empty: DBGen = new DBGen() + def apply(): 
DBGen = new DBGen() + + private def dtToTopLevelGen( + st: StructType, + tableId: Int, + defaultRanges: mutable.HashMap[DataType, (Any, Any)], + numRows: Long, + mapping: OrderedTypeMapping): Seq[(String, ColumnGen)] = { + // A bit of a hack with the column number so that we update it before each use... + var conf = ColumnConf(ColumnLocation(tableId, -1), true, numRows) + st.toArray.map { sf => + if (!mapping.canMap(sf.dataType, mapping)) { + throw new IllegalArgumentException(s"$sf is not supported at this time") + } + conf = conf.forNextColumn(sf.nullable) + val tmp = mapping.map(sf.dataType, conf, defaultRanges, mapping) + conf = tmp._2 + (sf.name, new ColumnGen(tmp._1)) + } + } +} + +/** + * Set up the schema for different tables and the relationship between various keys/columns + * in the tables. + */ +class DBGen { + private var tableId = 0 + private val tables = mutable.HashMap.empty[String, TableGen] + private val defaultRanges = mutable.HashMap.empty[DataType, (Any, Any)] + private val mappings = ArrayBuffer[TypeMapping](DefaultTypeMapping) + + /** + * Set a default value range for all generators of a given type. Note that this only impacts + * tables that have not been added yet. Some generators don't support value ranges; setting + * this for a type that does not support them will result in an error when the generator is + * created. Some generators can also be configured to ignore this, for example if you pass in + * your own function for data generation, in which case this may be ignored. + */ + def setDefaultValueRange(dt: DataType, min: Any, max: Any): DBGen = { + defaultRanges.put(dt, (min, max)) + this + } + + /** + * Add a new table with the given columns + * @param name the name of the table (must be unique) + * @param columns the generators that will be used for this table + * @return the TableGen that was added + */ + private def addTable(name: String, + columns: Seq[(String, ColumnGen)], + numRows: Long): TableGen = { + val lowerName = name.toLowerCase + if (lowerName.contains(".")) { + // Not sure if there are other forbidden characters, but we can check on that later + throw new IllegalArgumentException("Name cannot contain '.' character") + } + if (tables.contains(lowerName)) { + throw new IllegalArgumentException("Cannot add duplicate tables (even if case is different)") + } + val ret = new TableGen(columns, numRows) + tables.put(lowerName, ret) + ret + } + + /** + * Add a new table with the given type + * @param name the name of the table (must be unique) + * @param st the type for this table. + * @param numRows the number of rows for this table + * @return the TableGen that was just added + */ + def addTable(name: String, + st: StructType, + numRows: Long): TableGen = { + val localTableId = tableId + tableId += 1 + val mapping = OrderedTypeMapping(mappings.toArray) + val sg = DBGen.dtToTopLevelGen(st, localTableId, defaultRanges, numRows, mapping) + addTable(name, sg, numRows) + } + + /** + * Add a new table with a type defined by the DDL + * @param name the name of the table (must be unique) + * @param ddl the DDL that describes the type for the table.
+ * @param numRows the number of rows for this table + * @return the TableGen that was just created + */ + def addTable(name: String, + ddl: String, + numRows: Long): TableGen = { + val localTableId = tableId + tableId += 1 + val mapping = OrderedTypeMapping(mappings.toArray) + val sg = DBGen.dtToTopLevelGen( + DataType.fromDDL(ddl).asInstanceOf[StructType], localTableId, defaultRanges, numRows, mapping) + addTable(name, sg, numRows) + } + + /** + * Get a TableGen by name + * @param name the name of the table to look for + * @return the corresponding TableGen + */ + def apply(name: String): TableGen = tables(name.toLowerCase) + + def get(name: String): Option[TableGen] = tables.get(name.toLowerCase) + + /** + * Names of the tables. + */ + def tableNames: Seq[String] = tables.keys.toSeq + + /** + * Get an immutable map out of this. + */ + def toMap: Map[String, TableGen] = tables.toMap + + /** + * Convert all of the tables into dataframes + * @param spark the session to use for the conversion + * @param numParts the number of parts (tasks) to use. <= 0 uses the same number of tasks + * as the cluster has. + * @return a Map of the name of the table to the dataframe. + */ + def toDF(spark: SparkSession, numParts: Int = 0): Map[String, DataFrame] = + toMap.map { + case (name, gen) => (name, gen.toDF(spark, numParts)) + } + + /** + * Take all of the tables and create or replace temp views with them using the given name. + * The data will be generated inline as the query runs. + * @param spark the session to use + * @param numParts the number of parts to use, if > 0 + */ + def createOrReplaceTempViews(spark: SparkSession, numParts: Int = 0): Unit = + toDF(spark, numParts).foreach { + case (name, df) => df.createOrReplaceTempView(name) + } + + /** + * Write all of the tables out as parquet under path/table_name and overwrite anything that is + * already there. + * @param spark the session to use + * @param path the base path to write the data under + * @param numParts the number of parts to use, if > 0 + * @param overwrite if true will overwrite existing data + */ + def writeParquet(spark: SparkSession, + path: String, + numParts: Int = 0, + overwrite: Boolean = false): Unit = { + toDF(spark, numParts).foreach { + case (name, df) => + val subPath = path + "/" + name + var writer = df.write + if (overwrite) { + writer = writer.mode("overwrite") + } + writer.parquet(subPath) + } + } + + /** + * Create or replace temp views for all of the tables assuming that they were already written + * out as parquet under path/table_name. + * @param spark the session to use + * @param path the base path to read the data from + */ + def createOrReplaceTempViewsFromParquet(spark: SparkSession, path: String): Unit = { + tables.foreach { + case (name, _) => + val subPath = path + "/" + name + spark.read.parquet(subPath).createOrReplaceTempView(name) + } + } + + /** + * Write all of the tables out as parquet under path/table_name and then create or replace temp + * views for each of the tables using the given name. 
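+ * + * A minimal sketch of the intended flow, assuming a spark-shell style `spark` session; the + * table name, DDL, and path below are illustrative only: + * {{{ + * val dbgen = DBGen() + * dbgen.addTable("facts", "a long, b string", 1000000) + * dbgen.writeParquetAndReplaceTempViews(spark, "/tmp/scale_test_data") + * spark.sql("SELECT COUNT(*) FROM facts").show() + * }}}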
+ * + * @param spark the session to use + * @param path the base path to write the data under + * @param numParts the number of parts to use, if > 0 + * @param overwrite if true will overwrite existing data + */ + def writeParquetAndReplaceTempViews(spark: SparkSession, + path: String, + numParts: Int = 0, + overwrite: Boolean = false): Unit = { + writeParquet(spark, path, numParts, overwrite) + createOrReplaceTempViewsFromParquet(spark, path) + } + + /** + * Write all of the tables out as orc under path/table_name and overwrite anything that is + * already there. + * + * @param spark the session to use + * @param path the base path to write the data under + * @param numParts the number of parts to use, if > 0 + * @param overwrite if true will overwrite existing data + */ + def writeOrc(spark: SparkSession, + path: String, + numParts: Int = 0, + overwrite: Boolean = false): Unit = { + toDF(spark, numParts).foreach { + case (name, df) => + val subPath = path + "/" + name + var writer = df.write + if (overwrite) { + writer = writer.mode("overwrite") + } + writer.orc(subPath) + } + } + + /** + * Create or replace temp views for all of the tables assuming that they were already written + * out as orc under path/table_name. + * @param spark the session to use + * @param path the base path to read the data from + */ + def createOrReplaceTempViewsFromOrc(spark: SparkSession, path: String): Unit = { + tables.foreach { + case (name, _) => + val subPath = path + "/" + name + spark.read.orc(subPath).createOrReplaceTempView(name) + } + } + + /** + * Write all of the tables out as orc under path/table_name and then create or replace temp + * views for each of the tables using the given name. + * + * @param spark the session to use + * @param path the base path to write the data under + * @param numParts the number of parts to use, if > 0 + * @param overwrite if true will overwrite existing data + */ + def writeOrcAndReplaceTempViews(spark: SparkSession, + path: String, + numParts: Int = 0, + overwrite: Boolean = false): Unit = { + writeOrc(spark, path, numParts, overwrite) + createOrReplaceTempViewsFromOrc(spark, path) + } + + /** + * Add a new user controlled type mapping. This allows + * the user to totally override the handling for any or all types. + * @param mapping the new mapping to add with highest priority. + */ + def addTypeMapping(mapping: TypeMapping): Unit = { + // Insert this mapping in front of the others + mappings.insert(0, mapping) + } +} diff --git a/integration_tests/src/main/spark311/scala/org/apache/spark/sql/tests/datagen/DataGenExprBase.scala b/integration_tests/src/main/spark311/scala/org/apache/spark/sql/tests/datagen/DataGenExprBase.scala new file mode 100644 index 000000000000..d50008f7fb72 --- /dev/null +++ b/integration_tests/src/main/spark311/scala/org/apache/spark/sql/tests/datagen/DataGenExprBase.scala @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/*** spark-rapids-shim-json-lines +{"spark": "311"} +{"spark": "312"} +{"spark": "313"} +spark-rapids-shim-json-lines ***/ +package org.apache.spark.sql.tests.datagen + +import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, UnaryExpression} +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback + +trait DataGenExprBase extends UnaryExpression with ExpectsInputTypes with CodegenFallback diff --git a/integration_tests/src/main/spark320/scala/org/apache/spark/sql/tests/datagen/DataGenExprBase.scala b/integration_tests/src/main/spark320/scala/org/apache/spark/sql/tests/datagen/DataGenExprBase.scala new file mode 100644 index 000000000000..3d2e03c50f1e --- /dev/null +++ b/integration_tests/src/main/spark320/scala/org/apache/spark/sql/tests/datagen/DataGenExprBase.scala @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*** spark-rapids-shim-json-lines +{"spark": "320"} +{"spark": "321"} +{"spark": "321cdh"} +{"spark": "321db"} +{"spark": "322"} +{"spark": "323"} +{"spark": "324"} +{"spark": "330"} +{"spark": "330cdh"} +{"spark": "330db"} +{"spark": "331"} +{"spark": "332"} +{"spark": "332db"} +{"spark": "333"} +{"spark": "340"} +{"spark": "341"} +spark-rapids-shim-json-lines ***/ +package org.apache.spark.sql.tests.datagen + +import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, UnaryExpression} +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback + +trait DataGenExprBase extends UnaryExpression with ExpectsInputTypes with CodegenFallback { + override def withNewChildInternal(newChild: Expression): Expression = + legacyWithNewChildren(Seq(newChild)) +} diff --git a/pom.xml b/pom.xml index d12c535e44cb..c5ee3efecade 100644 --- a/pom.xml +++ b/pom.xml @@ -524,7 +524,6 @@ 11 11 11 - true @@ -534,7 +533,6 @@ 17 17 17 - true diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index a724afb45b84..5705e38f8c93 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -1128,7 +1128,7 @@ object GpuOverrides extends Logging { }), expr[SecondsToTimestamp]( "Converts the number of seconds from unix epoch to a timestamp", - ExprChecks.unaryProject(TypeSig.TIMESTAMP, TypeSig.TIMESTAMP, + ExprChecks.unaryProject(TypeSig.TIMESTAMP, TypeSig.TIMESTAMP, TypeSig.gpuNumeric, TypeSig.cpuNumeric), (a, conf, p, r) => new UnaryExprMeta[SecondsToTimestamp](a, conf, p, r) { override def convertToGpu(child: Expression): GpuExpression = @@ -1136,7 +1136,7 @@ object GpuOverrides extends Logging { }), expr[MillisToTimestamp]( "Converts the number of milliseconds from unix epoch to a timestamp", - ExprChecks.unaryProject(TypeSig.TIMESTAMP, TypeSig.TIMESTAMP, + ExprChecks.unaryProject(TypeSig.TIMESTAMP, TypeSig.TIMESTAMP, TypeSig.integral, TypeSig.integral), (a, 
conf, p, r) => new UnaryExprMeta[MillisToTimestamp](a, conf, p, r) { override def convertToGpu(child: Expression): GpuExpression = @@ -1144,7 +1144,7 @@ object GpuOverrides extends Logging { }), expr[MicrosToTimestamp]( "Converts the number of microseconds from unix epoch to a timestamp", - ExprChecks.unaryProject(TypeSig.TIMESTAMP, TypeSig.TIMESTAMP, + ExprChecks.unaryProject(TypeSig.TIMESTAMP, TypeSig.TIMESTAMP, TypeSig.integral, TypeSig.integral), (a, conf, p, r) => new UnaryExprMeta[MicrosToTimestamp](a, conf, p, r) { override def convertToGpu(child: Expression): GpuExpression = @@ -2108,14 +2108,12 @@ object GpuOverrides extends Logging { "Max aggregate operator", ExprChecksImpl( ExprChecks.reductionAndGroupByAgg( - (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.STRUCT) - .nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + - TypeSig.STRUCT + TypeSig.ARRAY), + (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.STRUCT + + TypeSig.ARRAY).nested(), TypeSig.orderable, Seq(ParamCheck("input", - (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.STRUCT) - .nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + - TypeSig.STRUCT + TypeSig.ARRAY), + (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.STRUCT + + TypeSig.ARRAY).nested(), TypeSig.orderable))).asInstanceOf[ExprChecksImpl].contexts ++ ExprChecks.windowOnly( @@ -2135,14 +2133,12 @@ object GpuOverrides extends Logging { "Min aggregate operator", ExprChecksImpl( ExprChecks.reductionAndGroupByAgg( - (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.STRUCT) - .nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + - TypeSig.STRUCT + TypeSig.ARRAY), + (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.STRUCT + + TypeSig.ARRAY).nested(), TypeSig.orderable, Seq(ParamCheck("input", - (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.STRUCT) - .nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + - TypeSig.STRUCT + TypeSig.ARRAY), + (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.STRUCT + + TypeSig.ARRAY).nested(), TypeSig.orderable))).asInstanceOf[ExprChecksImpl].contexts ++ ExprChecks.windowOnly( diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala index d1dbcd3a0863..73b280aadd5c 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala @@ -60,10 +60,12 @@ case class GpuParseUrl(children: Seq[Expression], override def checkInputDataTypes(): TypeCheckResult = { if (children.size > 3 || children.size < 2) { - RapidsErrorUtils.parseUrlWrongNumArgs(children.size) - } else { - super[ExpectsInputTypes].checkInputDataTypes() + RapidsErrorUtils.parseUrlWrongNumArgs(children.size) match { + case res: Some[TypeCheckResult] => return res.get + case _ => // error message has been thrown + } } + super[ExpectsInputTypes].checkInputDataTypes() } private def getPattern(key: UTF8String): RegexProgram = { diff --git a/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala b/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala index 9259f0d3e90c..5d74bbac86aa 100644 --- 
a/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala +++ b/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala @@ -84,11 +84,11 @@ object RapidsErrorUtils { throw new AnalysisException(s"$tableIdentifier already exists.") } - def parseUrlWrongNumArgs(actual: Int): TypeCheckResult = { - TypeCheckResult.TypeCheckFailure(s"parse_url function requires two or three arguments") + def parseUrlWrongNumArgs(actual: Int): Option[TypeCheckResult] = { + Some(TypeCheckResult.TypeCheckFailure(s"parse_url function requires two or three arguments")) } - def invalidUrlException(url: UFT8String, e: Throwable): Throwable = { + def invalidUrlException(url: UTF8String, e: Throwable): Throwable = { new IllegalArgumentException(s"Find an invaild url string ${url.toString}", e) } } diff --git a/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala b/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala index 6a569d98248f..115d4e93ba80 100644 --- a/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala +++ b/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala @@ -90,8 +90,8 @@ object RapidsErrorUtils { QueryCompilationErrors.tableIdentifierExistsError(tableIdentifier) } - def parseUrlWrongNumArgs(actual: Int): TypeCheckResult = { - TypeCheckResult.TypeCheckFailure(s"parse_url function requires two or three arguments") + def parseUrlWrongNumArgs(actual: Int): Option[TypeCheckResult] = { + Some(TypeCheckResult.TypeCheckFailure(s"parse_url function requires two or three arguments")) } def invalidUrlException(url: UTF8String, e: URISyntaxException): Throwable = { diff --git a/sql-plugin/src/main/spark321db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala b/sql-plugin/src/main/spark321db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala index b5997e4e6dbb..b7b95016047c 100644 --- a/sql-plugin/src/main/spark321db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala +++ b/sql-plugin/src/main/spark321db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala @@ -88,8 +88,8 @@ object RapidsErrorUtils { QueryCompilationErrors.tableIdentifierExistsError(tableIdentifier) } - def parseUrlWrongNumArgs(actual: Int): TypeCheckResult = { - TypeCheckResult.TypeCheckFailure(s"parse_url function requires two or three arguments") + def parseUrlWrongNumArgs(actual: Int): Option[TypeCheckResult] = { + Some(TypeCheckResult.TypeCheckFailure(s"parse_url function requires two or three arguments")) } def invalidUrlException(url: UTF8String, e: URISyntaxException): Throwable = { diff --git a/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala b/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala index a0e827150d5a..841b3a960782 100644 --- a/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala +++ b/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala @@ -84,8 +84,8 @@ object RapidsErrorUtils extends RapidsErrorUtilsFor330plus { new ArrayIndexOutOfBoundsException("SQL array indices start at 1") } - def parseUrlWrongNumArgs(actual: Int): TypeCheckResult = { - TypeCheckResult.TypeCheckFailure(s"parse_url function requires two or three arguments") + def parseUrlWrongNumArgs(actual: Int): Option[TypeCheckResult] = { + 
Some(TypeCheckResult.TypeCheckFailure(s"parse_url function requires two or three arguments")) } def invalidUrlException(url: UTF8String, e: URISyntaxException): Throwable = { diff --git a/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala b/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala index 43be246548ac..aaaecc86256a 100644 --- a/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala +++ b/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala @@ -92,8 +92,8 @@ object RapidsErrorUtils extends RapidsErrorUtilsFor330plus { QueryExecutionErrors.intervalDividedByZeroError(origin.context) } - def parseUrlWrongNumArgs(actual: Int): TypeCheckResult = { - TypeCheckResult.TypeCheckFailure(s"parse_url function requires two or three arguments") + def parseUrlWrongNumArgs(actual: Int): Option[TypeCheckResult] = { + Some(TypeCheckResult.TypeCheckFailure(s"parse_url function requires two or three arguments")) } def invalidUrlException(url: UTF8String, e: URISyntaxException): Throwable = { diff --git a/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala b/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala index 6dde4e0a8f98..032fed254efd 100644 --- a/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala +++ b/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala @@ -25,7 +25,7 @@ import java.net.URISyntaxException import org.apache.spark.SparkDateTimeException import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.trees.{Origin, SQLQueryContext} -import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, Decimal, DecimalType} import org.apache.spark.unsafe.types.UTF8String @@ -92,10 +92,11 @@ object RapidsErrorUtils extends RapidsErrorUtilsFor330plus { QueryExecutionErrors.intervalDividedByZeroError(origin.context) } - def parseUrlWrongNumArgs(actual: Int): Throwable = { + def parseUrlWrongNumArgs(actual: Int): Option[TypeCheckResult] = { throw QueryCompilationErrors.wrongNumArgsError( - "parse_url", Seq("[2, 3]"), actualNumber + "parse_url", Seq("[2, 3]"), actual ) + None } def invalidUrlException(url: UTF8String, e: URISyntaxException): Throwable = { diff --git a/tools/generated_files/supportedExprs.csv b/tools/generated_files/supportedExprs.csv index 4f01e3b73793..98437662e2cc 100644 --- a/tools/generated_files/supportedExprs.csv +++ b/tools/generated_files/supportedExprs.csv @@ -649,16 +649,16 @@ Last,S,`last`; `last_value`,None,reduction,input,S,S,S,S,S,S,S,S,PS,S,S,S,S,NS,P Last,S,`last`; `last_value`,None,reduction,result,S,S,S,S,S,S,S,S,PS,S,S,S,S,NS,PS,PS,PS,NS Last,S,`last`; `last_value`,None,window,input,S,S,S,S,S,S,S,S,PS,S,S,S,S,NS,PS,PS,PS,NS Last,S,`last`; `last_value`,None,window,result,S,S,S,S,S,S,S,S,PS,S,S,S,S,NS,PS,PS,PS,NS -Max,S,`max`,None,aggregation,input,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,NS,NA,PS,NS -Max,S,`max`,None,aggregation,result,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,NS,NA,PS,NS -Max,S,`max`,None,reduction,input,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,NS,NA,PS,NS -Max,S,`max`,None,reduction,result,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,NS,NA,PS,NS 
+Max,S,`max`,None,aggregation,input,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,PS,NA,PS,NS +Max,S,`max`,None,aggregation,result,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,PS,NA,PS,NS +Max,S,`max`,None,reduction,input,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,PS,NA,PS,NS +Max,S,`max`,None,reduction,result,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,PS,NA,PS,NS Max,S,`max`,None,window,input,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,NS,NA,NS,NS Max,S,`max`,None,window,result,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,NS,NA,NS,NS -Min,S,`min`,None,aggregation,input,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,NS,NA,PS,NS -Min,S,`min`,None,aggregation,result,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,NS,NA,PS,NS -Min,S,`min`,None,reduction,input,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,NS,NA,PS,NS -Min,S,`min`,None,reduction,result,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,NS,NA,PS,NS +Min,S,`min`,None,aggregation,input,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,PS,NA,PS,NS +Min,S,`min`,None,aggregation,result,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,PS,NA,PS,NS +Min,S,`min`,None,reduction,input,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,PS,NA,PS,NS +Min,S,`min`,None,reduction,result,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,PS,NA,PS,NS Min,S,`min`,None,window,input,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,NS,NA,NS,NS Min,S,`min`,None,window,result,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,NS,NA,NS,NS PivotFirst,S, ,None,aggregation,pivotColumn,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,NS,NS,NS,NS