Skip to content

Commit

Permalink
public randGamma API with shape/scale expression
Browse files Browse the repository at this point in the history
  • Loading branch information
zeotuan committed Oct 11, 2024
1 parent 03c951f commit 178acc1
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@ import org.apache.spark.util.Utils
object functions {
/** Wraps the given Catalyst expression in a Column. */
private def withExpr(expr: Expression): Column = {
  Column(expr)
}

/**
 * Wraps a `RandGamma` expression built from literal parameters in a Column.
 *
 * @param seed  seed passed to `RandGamma`
 * @param shape shape parameter passed to `RandGamma`
 * @param scale scale parameter passed to `RandGamma`
 * @return the resulting Column, aliased as "gamma_random"
 */
def randGamma(seed: Long, shape: Double, scale: Double): Column = {
  val gammaExpr = RandGamma(seed, shape, scale)
  withExpr(gammaExpr).alias("gamma_random")
}
/**
 * Variant of `randGamma` that draws its seed from `Utils.random`.
 *
 * A new seed is drawn on every call, so two calls with the same shape and
 * scale are seeded differently.
 *
 * @param shape shape parameter forwarded to the seeded overload
 * @param scale scale parameter forwarded to the seeded overload
 */
def randGamma(shape: Double, scale: Double): Column = {
  val seed: Long = Utils.random.nextLong
  randGamma(seed, shape, scale)
}
/** Variant of `randGamma` that uses shape 1.0 and scale 1.0. */
def randGamma(): Column = {
  val defaultShape = 1.0
  val defaultScale = 1.0
  randGamma(defaultShape, defaultScale)
}
/**
 * Wraps a `RandGamma` expression built from literal parameters in a Column.
 *
 * @param seed  seed passed to `RandGamma`
 * @param shape shape parameter passed to `RandGamma`
 * @param scale scale parameter passed to `RandGamma`
 * @return the resulting Column, aliased as "gamma_random"
 */
def randGamma(seed: Long, shape: Double, scale: Double): Column = {
  val gammaExpr = RandGamma(seed, shape, scale)
  withExpr(gammaExpr).alias("gamma_random")
}
/**
 * Wraps a `RandGamma` expression built from column-valued parameters in a Column.
 *
 * Each argument's underlying expression (`.expr`) is handed to `RandGamma`.
 *
 * @param seed  column supplying the seed expression
 * @param shape column supplying the shape expression
 * @param scale column supplying the scale expression
 * @return the resulting Column, aliased as "gamma_random"
 */
def randGamma(seed: Column, shape: Column, scale: Column): Column = {
  val gammaExpr = RandGamma(seed.expr, shape.expr, scale.expr)
  withExpr(gammaExpr).alias("gamma_random")
}
/**
 * Variant of `randGamma` that draws its seed from `Utils.random`.
 *
 * A new seed is drawn on every call, so two calls with the same shape and
 * scale are seeded differently.
 *
 * @param shape shape parameter forwarded to the seeded overload
 * @param scale scale parameter forwarded to the seeded overload
 */
def randGamma(shape: Double, scale: Double): Column = {
  val seed: Long = Utils.random.nextLong
  randGamma(seed, shape, scale)
}
/**
 * Column-based variant of `randGamma` that draws its seed from `Utils.random`.
 *
 * The seed is drawn once per call and passed on as a literal column.
 *
 * @param shape column forwarded as the shape parameter
 * @param scale column forwarded as the scale parameter
 */
def randGamma(shape: Column, scale: Column): Column = {
  val seed: Column = lit(Utils.random.nextLong)
  randGamma(seed, shape, scale)
}
/** Variant of `randGamma` that uses shape 1.0 and scale 1.0. */
def randGamma(): Column = {
  val defaultShape = 1.0
  val defaultScale = 1.0
  randGamma(defaultShape, defaultScale)
}

def randLaplace(seed: Long, mu: Double, beta: Double): Column = {
val mu_ = lit(mu)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,28 @@ object functionsTests extends TestSuite with DataFrameComparer with ColumnCompar
assert(math.abs(gammaMean - 4.0) < 0.5)
assert(math.abs(gammaStddev - math.sqrt(8.0)) < 0.5)
}

"has correct mean and standard deviation from shape/scale column" - {
  // Exercise the Column-based overload: shape and scale come from DataFrame
  // columns rather than from Scala literals.
  val sourceDF = spark
    .range(100000)
    .withColumn("shape", lit(2.0))
    .withColumn("scale", lit(2.0))
    // Pass the "scale" column as the second argument; passing "shape" twice
    // would leave the scale column unused and the scale path untested.
    .select(randGamma(col("shape"), col("scale")))
  val stats = sourceDF
    .agg(
      mean("gamma_random").as("mean"),
      stddev("gamma_random").as("stddev")
    )
    .collect()(0)

  val gammaMean = stats.getAs[Double]("mean")
  val gammaStddev = stats.getAs[Double]("stddev")

  // Gamma distribution with shape=2.0 and scale=2.0 has mean=4.0 and stddev=sqrt(8.0)
  assert(gammaMean > 0)
  assert(math.abs(gammaMean - 4.0) < 0.5)
  assert(math.abs(gammaStddev - math.sqrt(8.0)) < 0.5)
}
}

'rand_laplace - {
Expand Down

0 comments on commit 178acc1

Please sign in to comment.