From a0a86ea6bcef2f176573510401f13d16a5dd9ae6 Mon Sep 17 00:00:00 2001 From: Joel Berkeley <16429957+joelberkeley@users.noreply.github.com> Date: Fri, 1 Jul 2022 22:46:40 +0100 Subject: [PATCH] add stateful uniform sampling of `U64` --- .../compiler/xla/client/lib/prng.cpp | 65 ++++- .../tensorflow/compiler/xla/client/lib/prng.h | 9 + src/Compiler/Eval.idr | 10 + src/Compiler/Expr.idr | 11 + .../Compiler/Xla/Client/Lib/PRNG.idr | 15 +- .../Compiler/Xla/Client/Lib/PRNG.idr | 22 +- src/Tensor.idr | 115 +++++--- test/Unit/TestTensor.idr | 261 ++++++++++++------ 8 files changed, 364 insertions(+), 144 deletions(-) diff --git a/backend/src/tensorflow/compiler/xla/client/lib/prng.cpp b/backend/src/tensorflow/compiler/xla/client/lib/prng.cpp index 48e450187..925e23695 100644 --- a/backend/src/tensorflow/compiler/xla/client/lib/prng.cpp +++ b/backend/src/tensorflow/compiler/xla/client/lib/prng.cpp @@ -36,6 +36,32 @@ xla::BitGeneratorTy BitGenerator(int bit_generator) { return bit_generator_; } +RngOutput* UniformDistribution( + std::function f, + XlaOp& key, + XlaOp& initial_state, + int bit_generator, + XlaOp& minval, + XlaOp& maxval, + Shape& shape +) { + auto& key_ = reinterpret_cast(key); + auto& initial_state_ = reinterpret_cast(initial_state); + xla::BitGeneratorTy bit_generator_ = BitGenerator(bit_generator); + auto& minval_ = reinterpret_cast(minval); + auto& maxval_ = reinterpret_cast(maxval); + auto& shape_ = reinterpret_cast(shape); + + xla::RngOutput res = f(key_, initial_state_, bit_generator_, minval_, maxval_, shape_); + + return new RngOutput { + value: reinterpret_cast(new xla::XlaOp(res.value)), + state: reinterpret_cast(new xla::XlaOp(res.state)) + }; +} + extern "C" { RngOutput* UniformFloatingPointDistribution( XlaOp& key, @@ -45,21 +71,34 @@ extern "C" { XlaOp& maxval, Shape& shape ) { - auto& key_ = reinterpret_cast(key); - auto& initial_state_ = reinterpret_cast(initial_state); - xla::BitGeneratorTy bit_generator_ = BitGenerator(bit_generator); - auto& minval_ = reinterpret_cast(minval); - auto& maxval_ = reinterpret_cast(maxval); - auto& shape_ = reinterpret_cast(shape); - - xla::RngOutput res = xla::UniformFloatingPointDistribution( - key_, initial_state_, bit_generator_, minval_, maxval_, shape_ + return UniformDistribution( + xla::UniformFloatingPointDistribution, + key, + initial_state, + bit_generator, + minval, + maxval, + shape ); + } - return new RngOutput { - value: reinterpret_cast(new xla::XlaOp(res.value)), - state: reinterpret_cast(new xla::XlaOp(res.state)) - }; + RngOutput* UniformIntDistribution( + XlaOp& key, + XlaOp& initial_state, + int bit_generator, + XlaOp& minval, + XlaOp& maxval, + Shape& shape + ) { + return UniformDistribution( + xla::UniformIntDistribution, + key, + initial_state, + bit_generator, + minval, + maxval, + shape + ); } RngOutput* NormalFloatingPointDistribution( diff --git a/backend/src/tensorflow/compiler/xla/client/lib/prng.h b/backend/src/tensorflow/compiler/xla/client/lib/prng.h index cf5328ab0..6e1e9d73c 100644 --- a/backend/src/tensorflow/compiler/xla/client/lib/prng.h +++ b/backend/src/tensorflow/compiler/xla/client/lib/prng.h @@ -35,6 +35,15 @@ extern "C" { Shape& shape ); + RngOutput* UniformIntDistribution( + XlaOp& key, + XlaOp& initial_state, + int bit_generator, + XlaOp& minval, + XlaOp& maxval, + Shape& shape + ); + RngOutput* NormalFloatingPointDistribution( XlaOp& key, XlaOp& initial_state, int bit_generator, Shape& shape ); diff --git a/src/Compiler/Eval.idr b/src/Compiler/Eval.idr index a36b9a2de..17f3f08f5 100644 --- a/src/Compiler/Eval.idr +++ b/src/Compiler/Eval.idr @@ -198,6 +198,16 @@ enqueue e@(UniformFloatingPoint key initialState minval maxval shape) = cached e !(mkShape {dtype=F64} shape) (builder, _) <- get tuple builder [value rngOutput, state rngOutput] +enqueue e@(UniformUInt key initialState minval maxval shape) = cached e $ do + rngOutput <- uniformIntDistribution + !(enqueue key) + !(enqueue initialState) + ThreeFry + !(enqueue minval) + !(enqueue maxval) + !(mkShape {dtype=U64} shape) + (builder, _) <- get + tuple builder [value rngOutput, state rngOutput] enqueue e@(NormalFloatingPoint key initialState shape) = cached e $ do rngOutput <- normalFloatingPointDistribution !(enqueue key) !(enqueue initialState) ThreeFry !(mkShape {dtype=F64} shape) diff --git a/src/Compiler/Expr.idr b/src/Compiler/Expr.idr index b7d7e310c..1ffb373bc 100644 --- a/src/Compiler/Expr.idr +++ b/src/Compiler/Expr.idr @@ -97,6 +97,7 @@ data Expr : Type where Cholesky : Expr -> Expr TriangularSolve : Expr -> Expr -> Bool -> Expr UniformFloatingPoint : Expr -> Expr -> Expr -> Expr -> Shape -> Expr + UniformUInt : Expr -> Expr -> Expr -> Expr -> Shape -> Expr NormalFloatingPoint : Expr -> Expr -> Shape -> Expr export @@ -200,6 +201,9 @@ Prelude.Eq Expr where (UniformFloatingPoint key initialState minval maxval shape) == (UniformFloatingPoint key' initialState' minval' maxval' shape') = key == key' && initialState == initialState' && minval == minval' && maxval == maxval' + (UniformUInt key initialState minval maxval shape) == + (UniformUInt key' initialState' minval' maxval' shape') = + key == key' && initialState == initialState' && minval == minval' && maxval == maxval' (NormalFloatingPoint key initialState shape) == (NormalFloatingPoint key' initialState' shape') = key == key' && initialState == initialState' _ == _ = False @@ -318,6 +322,13 @@ Hashable Expr where `hashWithSalt` minval `hashWithSalt` maxval `hashWithSalt` shape + hashWithSalt salt (UniformUInt key initialState minval maxval shape) = salt + `hashWithSalt` "UniformUInt" + `hashWithSalt` key + `hashWithSalt` initialState + `hashWithSalt` minval + `hashWithSalt` maxval + `hashWithSalt` shape hashWithSalt salt (NormalFloatingPoint key initialState shape) = salt `hashWithSalt` "NormalFloatingPoint" `hashWithSalt` key diff --git a/src/Compiler/Xla/Prim/TensorFlow/Compiler/Xla/Client/Lib/PRNG.idr b/src/Compiler/Xla/Prim/TensorFlow/Compiler/Xla/Client/Lib/PRNG.idr index 13fe26529..4116e4a90 100644 --- a/src/Compiler/Xla/Prim/TensorFlow/Compiler/Xla/Client/Lib/PRNG.idr +++ b/src/Compiler/Xla/Prim/TensorFlow/Compiler/Xla/Client/Lib/PRNG.idr @@ -20,18 +20,23 @@ import System.FFI import Compiler.Xla.Prim.Util public export -RngOutput : Type -RngOutput = Struct "RngOutput" [("value", AnyPtr), ("state", AnyPtr)] +PrimRngOutput : Type +PrimRngOutput = Struct "RngOutput" [("value", AnyPtr), ("state", AnyPtr)] export %foreign (libxla "delete_RngOutput") -prim__delete : RngOutput -> PrimIO () +prim__delete : PrimRngOutput -> PrimIO () export %foreign (libxla "UniformFloatingPointDistribution") prim__uniformFloatingPointDistribution: - GCAnyPtr -> GCAnyPtr -> Int -> GCAnyPtr -> GCAnyPtr -> GCAnyPtr -> PrimIO RngOutput + GCAnyPtr -> GCAnyPtr -> Int -> GCAnyPtr -> GCAnyPtr -> GCAnyPtr -> PrimIO PrimRngOutput + +export +%foreign (libxla "UniformIntDistribution") +prim__uniformIntDistribution: + GCAnyPtr -> GCAnyPtr -> Int -> GCAnyPtr -> GCAnyPtr -> GCAnyPtr -> PrimIO PrimRngOutput export %foreign (libxla "NormalFloatingPointDistribution") -prim__normalFloatingPointDistribution: GCAnyPtr -> GCAnyPtr -> Int -> GCAnyPtr -> PrimIO RngOutput +prim__normalFloatingPointDistribution: GCAnyPtr -> GCAnyPtr -> Int -> GCAnyPtr -> PrimIO PrimRngOutput diff --git a/src/Compiler/Xla/TensorFlow/Compiler/Xla/Client/Lib/PRNG.idr b/src/Compiler/Xla/TensorFlow/Compiler/Xla/Client/Lib/PRNG.idr index 4f61b5e6a..b4e324ca9 100644 --- a/src/Compiler/Xla/TensorFlow/Compiler/Xla/Client/Lib/PRNG.idr +++ b/src/Compiler/Xla/TensorFlow/Compiler/Xla/Client/Lib/PRNG.idr @@ -28,26 +28,24 @@ Cast BitGenerator Int where cast ThreeFry = 0 cast Philox = 1 -%hide Compiler.Xla.Prim.TensorFlow.Compiler.Xla.Client.Lib.PRNG.RngOutput - public export record RngOutput where constructor MkRngOutput value : XlaOp state : XlaOp -export -uniformFloatingPointDistribution : +uniformDistribution : + (GCAnyPtr -> GCAnyPtr -> Int -> GCAnyPtr -> GCAnyPtr -> GCAnyPtr -> PrimIO PrimRngOutput) -> HasIO io => XlaOp -> XlaOp -> BitGenerator -> XlaOp -> XlaOp -> Shape -> io RngOutput -uniformFloatingPointDistribution +uniformDistribution + f (MkXlaOp key) (MkXlaOp initialState) bitGenerator (MkXlaOp minval) (MkXlaOp maxval) (MkShape shape) = do - rngOutput <- primIO $ prim__uniformFloatingPointDistribution - key initialState (cast bitGenerator) minval maxval shape + rngOutput <- primIO $ f key initialState (cast bitGenerator) minval maxval shape let value = getField rngOutput "value" state = getField rngOutput "state" primIO $ prim__delete rngOutput @@ -55,6 +53,16 @@ uniformFloatingPointDistribution state <- onCollectAny state XlaOp.delete pure (MkRngOutput {value = MkXlaOp value} {state = MkXlaOp state}) +export +uniformFloatingPointDistribution : + HasIO io => XlaOp -> XlaOp -> BitGenerator -> XlaOp -> XlaOp -> Shape -> io RngOutput +uniformFloatingPointDistribution = uniformDistribution prim__uniformFloatingPointDistribution + +export +uniformIntDistribution: + HasIO io => XlaOp -> XlaOp -> BitGenerator -> XlaOp -> XlaOp -> Shape -> io RngOutput +uniformIntDistribution = uniformDistribution prim__uniformIntDistribution + export normalFloatingPointDistribution : HasIO io => XlaOp -> XlaOp -> BitGenerator -> Shape -> io RngOutput diff --git a/src/Tensor.idr b/src/Tensor.idr index 8ea0cf494..bae593743 100644 --- a/src/Tensor.idr +++ b/src/Tensor.idr @@ -1161,48 +1161,79 @@ Rand = State (Tensor [1] U64) inf : Tensor [] F64 inf = fromDouble (1.0 / 0.0) -||| Generate independent and identically distributed (IID) uniform samples bounded element-wise -||| between `bound` and `bound'`. -||| -||| `bound` and `bound'` need not be ordered, and samples will be generated, elementwise, in -||| [min bound bound', max bound bound'). The exception is where the bounds are equal, in which -||| case: if the bounds are finite, samples are generated at the common bound, else samples are NaN. -||| -||| The generated samples are a deterministic function of the input key and state, but may vary -||| between backends and library versions. -||| -||| Example usage, multiplying two uniform samples -||| ``` -||| x : Tensor [3] F64 -||| x = let key = fromLiteral 2 -||| rng = uniform key (fill 0.0) (fill 1.0) -||| initialState = fromLiteral [0] -||| in evalState initialState [| rng * rng |] -||| ``` -||| -||| @key Determines the stream of generated samples. -||| @bound A bound of the samples. See full docstring for details. -||| @bound' A bound of the samples. See full docstring for details. -export -uniform : - {shape : _} -> - (key : Tensor [] U64) -> - (bound, bound' : Tensor shape F64) -> - Rand (Tensor shape F64) -uniform (MkTensor key) bound bound' = - let minval@(MkTensor minvalExpr) = min bound bound' - maxval@(MkTensor maxvalExpr) = max bound bound' - in ST $ \(MkTensor initialState) => - let valueState = UniformFloatingPoint key initialState minvalExpr maxvalExpr shape - value = MkTensor $ GetTupleElement 0 valueState - -- workaround for XLA bug https://github.com/tensorflow/tensorflow/issues/56663 - -- samples between -inf and 0 should be at -inf, but XLA produces nan - -- similarly, samples in (inf, inf) should be at inf and respectively for -inf - inf = broadcast inf - value = select (minval == - inf && maxval == fill 0) (- inf) value - value = select (minval == inf && maxval == inf) inf value - value = select (minval == - inf && maxval == - inf) (- inf) value - in Id (MkTensor $ GetTupleElement 1 valueState, value) +namespace F64 + ||| Generate independent and identically distributed (IID) uniform samples bounded element-wise + ||| between `bound` and `bound'`. `bound` and `bound'` need not be ordered, and are both + ||| inclusive bounds. + ||| + ||| The generated samples are a deterministic function of the input key and state, but may vary + ||| between backends and library versions. + ||| + ||| Example usage, multiplying two uniform samples + ||| ``` + ||| x : Tensor [3] F64 + ||| x = let key = fromLiteral 2 + ||| rng = uniform key (fill 0.0) (fill 1.0) + ||| initialState = fromLiteral [0] + ||| in evalState initialState [| rng * rng |] + ||| ``` + ||| + ||| @key Determines the stream of generated samples. + ||| @bound A bound of the samples. See full docstring for details. + ||| @bound' A bound of the samples. See full docstring for details. + export + uniform : + {shape : _} -> + (key : Tensor [] U64) -> + (bound, bound' : Tensor shape F64) -> + Rand (Tensor shape F64) + uniform (MkTensor key) bound bound' = + let minval@(MkTensor minvalExpr) = min bound bound' + maxval@(MkTensor maxvalExpr) = max bound bound' + in ST $ \(MkTensor initialState) => + let valueState = UniformFloatingPoint key initialState minvalExpr maxvalExpr shape + value = MkTensor $ GetTupleElement 0 valueState + -- workaround for XLA bug https://github.com/tensorflow/tensorflow/issues/56663 + -- samples between -inf and 0 should be at -inf, but XLA produces nan + -- similarly, samples in (inf, inf) should be at inf and respectively for -inf + inf = broadcast inf + value = select (minval == - inf && maxval == fill 0) (- inf) value + value = select (minval == inf && maxval == inf) inf value + value = select (minval == - inf && maxval == - inf) (- inf) value + in Id (MkTensor $ GetTupleElement 1 valueState, value) + +namespace U64 + ||| Generate independent and identically distributed (IID) uniform samples bounded element-wise + ||| between `bound` and `bound'`. `bound` and `bound'` need not be ordered, and are both + ||| inclusive bounds. + ||| + ||| The generated samples are a deterministic function of the input key and state, but may vary + ||| between backends and library versions. + ||| + ||| Example usage, multiplying two uniform samples + ||| ``` + ||| x : Tensor [3] U64 + ||| x = let key = fromLiteral 2 + ||| rng = uniform key (fill 0) (fill 100) + ||| initialState = fromLiteral [0] + ||| in evalState initialState [| rng * rng |] + ||| ``` + ||| + ||| @key Determines the stream of generated samples. + ||| @bound A bound of the samples. See full docstring for details. + ||| @bound' A bound of the samples. See full docstring for details. + export + uniform : + {shape : _} -> + (key : Tensor [] U64) -> + (bound, bound' : Tensor shape U64) -> + Rand (Tensor shape U64) + uniform (MkTensor key) bound bound' = + let MkTensor minval = min bound bound' + MkTensor maxval = max bound bound' + in ST $ \(MkTensor initialState) => + let valueState = UniformUInt key initialState minval maxval shape + in Id (MkTensor $ GetTupleElement 1 valueState, MkTensor $ GetTupleElement 0 valueState) ||| Generate independent and identically distributed (IID) samples from the standard normal ||| distribution. diff --git a/test/Unit/TestTensor.idr b/test/Unit/TestTensor.idr index a47aac22d..2a0c47767 100644 --- a/test/Unit/TestTensor.idr +++ b/test/Unit/TestTensor.idr @@ -1108,105 +1108,207 @@ range n = cast (Vect.range n) product1 : (x : Nat) -> product (the (List Nat) [x]) = x product1 x = rewrite plusZeroRightNeutral x in Refl +export iidKolmogorovSmirnov : - {shape : _} -> Tensor shape F64 -> (Tensor shape F64 -> Tensor shape F64) -> Tensor [] F64 + {shape : _} -> Tensor shape dtype -> (Tensor shape dtype -> Tensor shape F64) -> Tensor [] F64 iidKolmogorovSmirnov samples cdf = let n : Nat n = product shape indices : Tensor [n] F64 := cast (fromLiteral {dtype=U64} (range n)) sampleSize : Tensor [] F64 := cast (fromLiteral {dtype=U64} (Scalar n)) - samplesFlat := reshape {sizesEqual=sym (product1 n)} {to=[n]} (cdf samples) - deviationFromCDF : Tensor [n] F64 := indices / sampleSize - (sort (<) 0 samplesFlat) + cdfs := reshape {sizesEqual=sym (product1 n)} {to=[n]} (cdf samples) + deviationFromCDF : Tensor [n] F64 := indices / sampleSize - (sort (<) 0 cdfs) in reduce @{Max} [0] (abs deviationFromCDF) -covering -uniform : Property -uniform = withTests 20 . property $ do - bound <- forAll (literal [5] finiteDoubles) - bound' <- forAll (literal [5] finiteDoubles) - key <- forAll (literal [] nats) - seed <- forAll (literal [1] nats) +namespace F64 + export covering + uniform : Property + uniform = withTests 20 . property $ do + bound <- forAll (literal [5] finiteDoubles) + bound' <- forAll (literal [5] finiteDoubles) + key <- forAll (literal [] nats) + seed <- forAll (literal [1] nats) - let bound = fromLiteral bound - bound' = fromLiteral bound' - key = fromLiteral key - seed = fromLiteral seed - samples = evalState seed (uniform key (broadcast bound) (broadcast bound')) + let bound = fromLiteral bound + bound' = fromLiteral bound' + key = fromLiteral key + seed = fromLiteral seed + samples = evalState seed (uniform key (broadcast bound) (broadcast bound')) - uniformCdf : Tensor [2000, 5] F64 -> Tensor [2000, 5] F64 - uniformCdf x = (x - broadcast bound) / broadcast (bound' - bound) + uniformCdf : Tensor [2000, 5] F64 -> Tensor [2000, 5] F64 + uniformCdf x = (x - broadcast bound) / broadcast {to=[2000, 5]} (bound' - bound) - ksTest := iidKolmogorovSmirnov samples uniformCdf + ksTest := iidKolmogorovSmirnov samples uniformCdf - diff (toLiteral ksTest) (<) 0.01 + diff (toLiteral ksTest) (<) 0.01 -covering -uniformForNonFiniteBounds : Property -uniformForNonFiniteBounds = property $ do - key <- forAll (literal [] nats) - seed <- forAll (literal [1] nats) + export covering + uniformForNonFiniteBounds : Property + uniformForNonFiniteBounds = property $ do + key <- forAll (literal [] nats) + seed <- forAll (literal [1] nats) - let bound = fromLiteral [0.0, 0.0, 0.0, -inf, -inf, -inf, inf, inf, nan] - bound' = fromLiteral [-inf, inf, nan, -inf, inf, nan, inf, nan, nan] - key = fromLiteral key - seed = fromLiteral seed - samples = evalState seed (uniform key (broadcast bound) (broadcast bound')) + let bound = fromLiteral [0.0, 0.0, 0.0, -inf, -inf, -inf, inf, inf, nan] + bound' = fromLiteral [-inf, inf, nan, -inf, inf, nan, inf, nan, nan] + key = fromLiteral key + seed = fromLiteral seed + samples = evalState seed (uniform key (broadcast bound) (broadcast bound')) - samples ===# fromLiteral [-inf, inf, nan, -inf, nan, nan, inf, nan, nan] + samples ===# fromLiteral [-inf, inf, nan, -inf, nan, nan, inf, nan, nan] -covering -uniformForFiniteEqualBounds : Property -uniformForFiniteEqualBounds = withTests 20 . property $ do - key <- forAll (literal [] nats) - seed <- forAll (literal [1] nats) + export covering + uniformForFiniteEqualBounds : Property + uniformForFiniteEqualBounds = withTests 20 . property $ do + key <- forAll (literal [] nats) + seed <- forAll (literal [1] nats) - let bound = fromLiteral [-1.0, 0.0, 1.0] - key = fromLiteral key - seed = fromLiteral seed - samples = evalState seed (uniform key bound bound) + let bound = fromLiteral [-1.0, 0.0, 1.0] + key = fromLiteral key + seed = fromLiteral seed + samples = evalState seed (uniform key bound bound) - samples ===# fromLiteral [-1.0, 0.0, 1.0] + samples ===# fromLiteral [-1.0, 0.0, 1.0] -covering -uniformSeedIsUpdated : Property -uniformSeedIsUpdated = withTests 20 . property $ do - bound <- forAll (literal [10] doubles) - bound' <- forAll (literal [10] doubles) - key <- forAll (literal [] nats) - seed <- forAll (literal [1] nats) + export covering + uniformSeedIsUpdated : Property + uniformSeedIsUpdated = withTests 20 . property $ do + bound <- forAll (literal [10] doubles) + bound' <- forAll (literal [10] doubles) + key <- forAll (literal [] nats) + seed <- forAll (literal [1] nats) + + let bound = fromLiteral bound + bound' = fromLiteral bound' + key = fromLiteral key + seed = fromLiteral seed + + rng = F64.uniform key {shape=[10]} (broadcast bound) (broadcast bound') + (seed', sample) = runState seed rng + (seed'', sample') = runState seed' rng + + diff (toLiteral seed') (/=) (toLiteral seed) + diff (toLiteral seed'') (/=) (toLiteral seed') + diff (toLiteral sample') (/=) (toLiteral sample) - let bound = fromLiteral bound - bound' = fromLiteral bound' - key = fromLiteral key - seed = fromLiteral seed + export covering + uniformIsReproducible : Property + uniformIsReproducible = withTests 20 . property $ do + bound <- forAll (literal [10] doubles) + bound' <- forAll (literal [10] doubles) + key <- forAll (literal [] nats) + seed <- forAll (literal [1] nats) - rng = uniform key {shape=[10]} (broadcast bound) (broadcast bound') - (seed', sample) = runState seed rng - (seed'', sample') = runState seed' rng + let bound = fromLiteral bound + bound' = fromLiteral bound' + key = fromLiteral key + seed = fromLiteral seed - diff (toLiteral seed') (/=) (toLiteral seed) - diff (toLiteral seed'') (/=) (toLiteral seed') - diff (toLiteral sample') (/=) (toLiteral sample) + rng = F64.uniform {shape=[10]} key (broadcast bound) (broadcast bound') + sample = evalState seed rng + sample' = evalState seed rng -covering -uniformIsReproducible : Property -uniformIsReproducible = withTests 20 . property $ do - bound <- forAll (literal [10] doubles) - bound' <- forAll (literal [10] doubles) - key <- forAll (literal [] nats) - seed <- forAll (literal [1] nats) + sample ===# sample' - let bound = fromLiteral bound - bound' = fromLiteral bound' - key = fromLiteral key - seed = fromLiteral seed +namespace U64 + export covering + uniform : Property + uniform = withTests 20 . property $ do + bound <- forAll (literal [10] nats) + bound' <- forAll (literal [10] nats) + key <- forAll (literal [] nats) + seed <- forAll (literal [1] nats) - rng = uniform {shape=[10]} key (broadcast bound) (broadcast bound') - sample = evalState seed rng - sample' = evalState seed rng + let bound = fromLiteral {dtype=U64} bound + bound' = fromLiteral {dtype=U64} bound' + key = fromLiteral key + seed = fromLiteral seed - sample ===# sample' + samples : Tensor [1, 10] U64 := + evalState seed (uniform key (broadcast bound) (broadcast bound')) + + uniformCdf : Tensor [1, 10] U64 -> Tensor [1, 10] F64 + uniformCdf x = + let bound : Tensor [10] F64 = cast bound + bound' : Tensor [10] F64 = cast bound' + distance : Tensor [1, 10] F64 = cast x - broadcast bound + in distance / (broadcast (bound - bound')) + + ksTest := iidKolmogorovSmirnov samples uniformCdf + + samples ===# fill 0 + diff (toLiteral ksTest) (<) 0.01 + + export covering + uniformBoundsAreInclusive : Property + uniformBoundsAreInclusive = property $ do + bound <- forAll (literal [100] nats) + key <- forAll (literal [] nats) + seed <- forAll (literal [1] nats) + + let bound = fromLiteral bound + bound' = bound + fill 2 + key = fromLiteral key + seed = fromLiteral seed + + samples = evalState seed (U64.uniform key bound bound') + + diff (toLiteral samples) (\x, y => any id [| x == y |]) (toLiteral bound) + diff (toLiteral samples) (\x, y => any id [| x == y |]) (toLiteral bound') + + export covering + uniformForEqualBounds : Property + uniformForEqualBounds = withTests 20 . property $ do + key <- forAll (literal [] nats) + seed <- forAll (literal [1] nats) + + let bound = fromLiteral $ the (Literal [3] Nat) [0, 1, 5] + key = fromLiteral key + seed = fromLiteral seed + + samples : Tensor [3] U64 = evalState seed (uniform key bound bound) + + samples ===# fromLiteral [0, 1, 5] + + export covering + uniformSeedIsUpdated : Property + uniformSeedIsUpdated = withTests 20 . property $ do + bound <- forAll (literal [10] nats) + bound' <- forAll (literal [10] nats) + key <- forAll (literal [] nats) + seed <- forAll (literal [1] nats) + + let bound = fromLiteral bound + bound' = fromLiteral bound' + key = fromLiteral key + seed = fromLiteral seed + + rng = U64.uniform key {shape=[10]} (broadcast bound) (broadcast bound') + (seed', sample) = runState seed rng + (seed'', sample') = runState seed' rng + + diff (toLiteral seed') (/=) (toLiteral seed) + diff (toLiteral seed'') (/=) (toLiteral seed') + diff (toLiteral sample') (/=) (toLiteral sample) + + export covering + uniformIsReproducible : Property + uniformIsReproducible = withTests 20 . property $ do + bound <- forAll (literal [10] nats) + bound' <- forAll (literal [10] nats) + key <- forAll (literal [] nats) + seed <- forAll (literal [1] nats) + + let bound = fromLiteral bound + bound' = fromLiteral bound' + key = fromLiteral key + seed = fromLiteral seed + + rng = U64.uniform {shape=[10]} key (broadcast bound) (broadcast bound') + sample = evalState seed rng + sample' = evalState seed rng + + sample ===# sample' covering normal : Property @@ -1311,11 +1413,16 @@ group = MkGroup "Tensor" $ [ , (#"(|\) and (/|) result and inverse"#, triangularSolveResultAndInverse) , (#"(|\) and (/|) ignore opposite elements"#, triangularSolveIgnoresOppositeElems) , ("trace", trace) - , ("uniform", uniform) - , ("uniform for infinite and NaN bounds", uniformForNonFiniteBounds) - , ("uniform is not NaN for finite equal bounds", uniformForFiniteEqualBounds) - , ("uniform updates seed", uniformSeedIsUpdated) - , ("uniform produces same samples for same seed", uniformIsReproducible) + , ("uniform F64", F64.uniform) + , ("uniform F64 for infinite and NaN bounds", uniformForNonFiniteBounds) + , ("uniform F64 is not NaN for finite equal bounds", uniformForFiniteEqualBounds) + , ("uniform F64 updates seed", F64.uniformSeedIsUpdated) + , ("uniform F64 produces same samples for same seed", F64.uniformIsReproducible) + , ("uniform U64", U64.uniform) + , ("uniform U64 bounds are inclusive", uniformBoundsAreInclusive) + , ("uniform U64 for equal bounds", U64.uniformForEqualBounds) + , ("uniform U64 updates seed", U64.uniformSeedIsUpdated) + , ("uniform U64 produces same samples for same seed", U64.uniformIsReproducible) , ("normal", normal) , ("normal updates seed", normalSeedIsUpdated) , ("normal produces same samples for same seed", normalIsReproducible)