Skip to content

Commit

Permalink
Let big data gen set nullability recursively (NVIDIA#10728)
Browse files Browse the repository at this point in the history
Signed-off-by: Robert (Bobby) Evans <[email protected]>
  • Loading branch information
revans2 authored Apr 29, 2024
1 parent df3f0af commit dfa2ec7
Showing 1 changed file with 49 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,15 @@ abstract class DataGen(var conf: ColumnConf,
this
}

def setNullProbabilityRecursively(probability: Double): DataGen = {
this.userProvidedNullGen = Some(NullProbabilityGenerationFunction(probability))
children.foreach {
case (_, dataGen) =>
dataGen.setNullProbabilityRecursively(probability)
}
this
}

/**
* Set a specific location to seed mapping for the value generation.
*/
Expand Down Expand Up @@ -672,6 +681,7 @@ abstract class DataGen(var conf: ColumnConf,
* Get the default value generator for this specific data gen.
*/
protected def getValGen: GeneratorFunction
def children: Seq[(String, DataGen)]

/**
* Get the final ready to use GeneratorFunction for the data generator.
Expand Down Expand Up @@ -823,6 +833,8 @@ class BooleanGen(conf: ColumnConf,
override def dataType: DataType = BooleanType

override protected def getValGen: GeneratorFunction = BooleanGenFunc()

override def children: Seq[(String, DataGen)] = Seq.empty
}

/**
Expand Down Expand Up @@ -878,6 +890,8 @@ class ByteGen(conf: ColumnConf,
extends DataGen(conf, defaultValueRange) {
override def getValGen: GeneratorFunction = ByteGenFunc()
override def dataType: DataType = ByteType

override def children: Seq[(String, DataGen)] = Seq.empty
}

/**
Expand Down Expand Up @@ -935,6 +949,8 @@ class ShortGen(conf: ColumnConf,
override def getValGen: GeneratorFunction = ShortGenFunc()

override def dataType: DataType = ShortType

override def children: Seq[(String, DataGen)] = Seq.empty
}

/**
Expand Down Expand Up @@ -991,6 +1007,8 @@ class IntGen(conf: ColumnConf,
override def getValGen: GeneratorFunction = IntGenFunc()

override def dataType: DataType = IntegerType

override def children: Seq[(String, DataGen)] = Seq.empty
}

/**
Expand Down Expand Up @@ -1045,6 +1063,8 @@ class LongGen(conf: ColumnConf,
override def getValGen: GeneratorFunction = LongGenFunc()

override def dataType: DataType = LongType

override def children: Seq[(String, DataGen)] = Seq.empty
}

case class Decimal32GenFunc(
Expand Down Expand Up @@ -1284,6 +1304,8 @@ class DecimalGen(dt: DecimalType,
val max = DecimalGen.genMaxUnscaled(dt.precision)
DecimalGenFunc(dt.precision, dt.scale, -max, max)
}

override def children: Seq[(String, DataGen)] = Seq.empty
}

/**
Expand Down Expand Up @@ -1341,6 +1363,8 @@ class TimestampGen(conf: ColumnConf,
override protected def getValGen: GeneratorFunction = TimestampGenFunc()

override def dataType: DataType = TimestampType

override def children: Seq[(String, DataGen)] = Seq.empty
}

object BigDataGenConsts {
Expand Down Expand Up @@ -1418,6 +1442,8 @@ class DateGen(conf: ColumnConf,
override protected def getValGen: GeneratorFunction = DateGenFunc()

override def dataType: DataType = DateType

override def children: Seq[(String, DataGen)] = Seq.empty
}

/**
Expand All @@ -1440,6 +1466,8 @@ class DoubleGen(conf: ColumnConf, defaultValueRange: Option[(Any, Any)])
override def dataType: DataType = DoubleType

override protected def getValGen: GeneratorFunction = DoubleGenFunc()

override def children: Seq[(String, DataGen)] = Seq.empty
}

/**
Expand All @@ -1462,6 +1490,8 @@ class FloatGen(conf: ColumnConf, defaultValueRange: Option[(Any, Any)])
override def dataType: DataType = FloatType

override protected def getValGen: GeneratorFunction = FloatGenFunc()

override def children: Seq[(String, DataGen)] = Seq.empty
}

trait JSONType {
Expand Down Expand Up @@ -1648,6 +1678,8 @@ class StringGen(conf: ColumnConf, defaultValueRange: Option[(Any, Any)])
override def dataType: DataType = StringType

override protected def getValGen: GeneratorFunction = ASCIIGenFunc()

override def children: Seq[(String, DataGen)] = Seq.empty
}

case class StructGenFunc(childGens: Array[GeneratorFunction]) extends GeneratorFunction {
Expand Down Expand Up @@ -1752,6 +1784,8 @@ class ArrayGen(child: DataGen,
None
}
}

override def children: Seq[(String, DataGen)] = Seq(("data", child))
}

case class MapGenFunc(
Expand Down Expand Up @@ -1816,6 +1850,8 @@ class MapGen(key: DataGen,
None
}
}

override def children: Seq[(String, DataGen)] = Seq(("key", key), ("value", value))
}


Expand Down Expand Up @@ -1864,6 +1900,11 @@ class ColumnGen(val dataGen: DataGen) {
this
}

def setNullProbabilityRecursively(probability: Double): ColumnGen = {
dataGen.setNullProbabilityRecursively(probability)
this
}

def setNullGen(f: NullGeneratorFunction): ColumnGen = {
dataGen.setNullGen(f)
this
Expand Down Expand Up @@ -1973,6 +2014,14 @@ class TableGen(val columns: Seq[(String, ColumnGen)], numRows: Long) {
this
}

def setNullProbabilityRecursively(probability: Double): TableGen = {
columns.foreach {
case (_, columnGen) =>
columnGen.setNullProbabilityRecursively(probability)
}
this
}

/**
* Convert this table into a `DataFrame` that can be
* written out or used directly. Writing it out to parquet
Expand Down

0 comments on commit dfa2ec7

Please sign in to comment.