diff --git a/hail/python/hail/expr/__init__.py b/hail/python/hail/expr/__init__.py
index bff07932fd9..37d82d3117f 100644
--- a/hail/python/hail/expr/__init__.py
+++ b/hail/python/hail/expr/__init__.py
@@ -218,6 +218,8 @@
     rand_gamma,
     rand_int32,
     rand_int64,
+    rand_hyper,
+    rand_multi_hyper,
     rand_norm,
     rand_norm2d,
     rand_pois,
@@ -395,6 +397,8 @@
     'rand_gamma',
     'rand_cat',
     'rand_dirichlet',
+    'rand_hyper',
+    'rand_multi_hyper',
     'sqrt',
     'corr',
     'str',
diff --git a/hail/python/hail/expr/functions.py b/hail/python/hail/expr/functions.py
index 0558be5d0f5..bce627436f4 100644
--- a/hail/python/hail/expr/functions.py
+++ b/hail/python/hail/expr/functions.py
@@ -3300,6 +3300,58 @@ def rand_dirichlet(a, seed=None) -> ArrayExpression:
     return hl.bind(lambda x: x / hl.sum(x), a.map(lambda p: hl.if_else(p == 0.0, 0.0, hl.rand_gamma(p, 1, seed=seed))))
 
 
+@typecheck(ngood=expr_int32, nbad=expr_int32, nsample=expr_int32, seed=nullable(int))
+def rand_hyper(ngood, nbad, nsample, seed=None) -> Int32Expression:
+    """Samples from a hypergeometric distribution.
+
+    The result is the number of "good" items obtained when drawing `nsample`
+    items without replacement from a population of `ngood` good items and
+    `nbad` bad items.
+
+    Parameters
+    ----------
+    ngood : :obj:`int` or :class:`.Int32Expression`
+        Number of good items in the population.
+    nbad : :obj:`int` or :class:`.Int32Expression`
+        Number of bad items in the population.
+    nsample : :obj:`int` or :class:`.Int32Expression`
+        Number of items drawn.
+    seed : :obj:`int`, optional
+        Random seed.
+
+    Returns
+    -------
+    :class:`.Int32Expression`
+    """
+    return _seeded_func("rand_hyper", tint32, seed, ngood, nbad, nsample)
+
+
+@typecheck(colors=expr_array(expr_int32), nsample=expr_int32, seed=nullable(int))
+def rand_multi_hyper(colors, nsample, seed=None) -> ArrayExpression:
+    """Samples from a multivariate hypergeometric distribution.
+
+    Element ``i`` of the result is the number of items of color ``i``
+    obtained when drawing `nsample` items without replacement from a
+    population containing ``colors[i]`` items of each color ``i``.
+
+    Parameters
+    ----------
+    colors : :class:`.ArrayExpression` of type :py:data:`.tint32`
+        Number of items of each color in the population. Must not contain
+        missing values.
+    nsample : :obj:`int` or :class:`.Int32Expression`
+        Number of items drawn.
+    seed : :obj:`int`, optional
+        Random seed.
+
+    Returns
+    -------
+    :class:`.ArrayExpression` of type :py:data:`.tint32`
+        One count per entry of `colors`, in the same order.
+    """
+    return _seeded_func("rand_multi_hyper", tarray(tint32), seed, colors, nsample)
+
+
 @typecheck(x=oneof(expr_float64, expr_ndarray(expr_float64)))
 @ndarray_broadcasting
 def sqrt(x) -> Float64Expression:
diff --git a/hail/python/test/hail/expr/test_expr.py b/hail/python/test/hail/expr/test_expr.py
index 1a9477533e7..af09796cb7a 100644
--- a/hail/python/test/hail/expr/test_expr.py
+++ b/hail/python/test/hail/expr/test_expr.py
@@ -69,6 +69,8 @@ def test_random_function(rand_f):
         test_random_function(lambda: hl.rand_gamma(1, 1))
         test_random_function(lambda: hl.rand_cat(hl.array([1, 1, 1, 1])))
         test_random_function(lambda: hl.rand_dirichlet(hl.array([1, 1, 1, 1])))
+        test_random_function(lambda: hl.rand_hyper(5, 10, 4))
+        test_random_function(lambda: hl.rand_multi_hyper([5, 2, 8], 4))
 
     def test_range(self):
         def same_as_python(*args):
diff --git a/hail/src/main/scala/is/hail/expr/ir/functions/RandomSeededFunctions.scala b/hail/src/main/scala/is/hail/expr/ir/functions/RandomSeededFunctions.scala
index 9ba56ad97b6..39917358e44 100644
--- a/hail/src/main/scala/is/hail/expr/ir/functions/RandomSeededFunctions.scala
+++ b/hail/src/main/scala/is/hail/expr/ir/functions/RandomSeededFunctions.scala
@@ -389,6 +389,83 @@ object RandomSeededFunctions extends RegistryFunctions {
       primitive(cb.memoize(rng.invoke[Double, Double, Double]("rgamma", a.value, scale.value)))
   }
 
+  // rand_hyper(ngood, nbad, nsample): a single "rhyper" draw from the seeded
+  // Threefry engine. "rhyper" takes and returns doubles, so the Int32 arguments
+  // are widened on the way in and the result is converted back to Int32.
+  registerSCode4(
+    "rand_hyper",
+    TRNGState,
+    TInt32,
+    TInt32,
+    TInt32,
+    TInt32,
+    {
+      case (_: Type, _: SType, _: SType, _: SType, _: SType) => SInt32
+    },
+  ) {
+    case (
+          _,
+          cb,
+          _,
+          rngState: SRNGStateValue,
+          nGood: SInt32Value,
+          nBad: SInt32Value,
+          nSample: SInt32Value,
+          _,
+        ) =>
+      val rng = cb.emb.getThreefryRNG()
+      rngState.copyIntoEngine(cb, rng)
+      primitive(cb.memoize(rng.invoke[Double, Double, Double, Double](
+        "rhyper",
+        nGood.value.toD,
+        nBad.value.toD,
+        nSample.value.toD,
+      ).toI))
+  }
+
+  // rand_multi_hyper(colors, nsample): fills an array with one count per color
+  // by chaining "rhyper" draws, one color at a time.
+  registerSCode3(
+    "rand_multi_hyper",
+    TRNGState,
+    TArray(TInt32),
+    TInt32,
+    TArray(TInt32),
+    {
+      case (_: Type, _: SType, _: SType, _: SType) =>
+        SIndexablePointer(PCanonicalArray(PInt32(required = true)))
+    },
+  ) {
+    case (r, cb, _, rngState: SRNGStateValue, colors: SIndexableValue, nSample: SInt32Value, _) =>
+      val rng = cb.emb.getThreefryRNG()
+      rngState.copyIntoEngine(cb, rng)
+      val (push, finish) = PCanonicalArray(PInt32(required = true))
+        .constructFromFunctions(cb, r.region, colors.loadLength(), deepCopy = false)
+      cb.if_(
+        colors.hasMissingValues(cb),
+        cb._fatal("rand_multi_hyper: colors may not contain missing values"),
+      )
+      // Total population size: the sum of all color counts.
+      val remaining = cb.newLocal[Int]("rand_multi_hyper_N", 0)
+      val toSample = cb.newLocal[Int]("rand_multi_hyper_toSample", nSample.value)
+      colors.forEachDefined(cb)((cb, _, n) => cb.assign(remaining, remaining + n.asInt.value))
+      colors.forEachDefined(cb) { (cb, _, n) =>
+        // After this subtraction `remaining` holds only the colors not yet
+        // visited, i.e. the "bad" pool for this color's draw.
+        cb.assign(remaining, remaining - n.asInt.value)
+        val drawn = cb.memoize(rng.invoke[Double, Double, Double, Double](
+          "rhyper",
+          n.asInt.value.toD,
+          remaining.toD,
+          toSample.toD,
+        ).toI)
+        // Items drawn for this color are no longer left to sample.
+        cb.assign(toSample, toSample - drawn)
+        push(cb, IEmitCode.present(cb, primitive(drawn)))
+      }
+      finish(cb)
+  }
+
   registerSCode2(
     "rand_cat",
     TRNGState,