From ba5cee138cc3babdb71cd78ee8fc2e31c60bb0af Mon Sep 17 00:00:00 2001 From: Joel Berkeley <16429957+joelberkeley@users.noreply.github.com> Date: Fri, 1 Jul 2022 22:46:40 +0100 Subject: [PATCH] add stateful uniform sampling of `U64` --- .../compiler/xla/client/lib/prng.cpp | 65 +++- .../tensorflow/compiler/xla/client/lib/prng.h | 9 + src/Compiler/Eval.idr | 9 + src/Compiler/Expr.idr | 1 + .../Compiler/Xla/Client/Lib/PRNG.idr | 15 +- .../Compiler/Xla/Client/Lib/PRNG.idr | 22 +- src/Literal.idr | 10 + src/Tensor.idr | 131 +++++--- test/Unit/TestTensor/Sampling.idr | 296 ++++++++++++------ 9 files changed, 384 insertions(+), 174 deletions(-) diff --git a/backend/src/tensorflow/compiler/xla/client/lib/prng.cpp b/backend/src/tensorflow/compiler/xla/client/lib/prng.cpp index 48e450187..925e23695 100644 --- a/backend/src/tensorflow/compiler/xla/client/lib/prng.cpp +++ b/backend/src/tensorflow/compiler/xla/client/lib/prng.cpp @@ -36,6 +36,32 @@ xla::BitGeneratorTy BitGenerator(int bit_generator) { return bit_generator_; } +RngOutput* UniformDistribution( + std::function f, + XlaOp& key, + XlaOp& initial_state, + int bit_generator, + XlaOp& minval, + XlaOp& maxval, + Shape& shape +) { + auto& key_ = reinterpret_cast(key); + auto& initial_state_ = reinterpret_cast(initial_state); + xla::BitGeneratorTy bit_generator_ = BitGenerator(bit_generator); + auto& minval_ = reinterpret_cast(minval); + auto& maxval_ = reinterpret_cast(maxval); + auto& shape_ = reinterpret_cast(shape); + + xla::RngOutput res = f(key_, initial_state_, bit_generator_, minval_, maxval_, shape_); + + return new RngOutput { + value: reinterpret_cast(new xla::XlaOp(res.value)), + state: reinterpret_cast(new xla::XlaOp(res.state)) + }; +} + extern "C" { RngOutput* UniformFloatingPointDistribution( XlaOp& key, @@ -45,21 +71,34 @@ extern "C" { XlaOp& maxval, Shape& shape ) { - auto& key_ = reinterpret_cast(key); - auto& initial_state_ = reinterpret_cast(initial_state); - xla::BitGeneratorTy bit_generator_ = BitGenerator(bit_generator); - auto& minval_ = reinterpret_cast(minval); - auto& maxval_ = reinterpret_cast(maxval); - auto& shape_ = reinterpret_cast(shape); - - xla::RngOutput res = xla::UniformFloatingPointDistribution( - key_, initial_state_, bit_generator_, minval_, maxval_, shape_ + return UniformDistribution( + xla::UniformFloatingPointDistribution, + key, + initial_state, + bit_generator, + minval, + maxval, + shape ); + } - return new RngOutput { - value: reinterpret_cast(new xla::XlaOp(res.value)), - state: reinterpret_cast(new xla::XlaOp(res.state)) - }; + RngOutput* UniformIntDistribution( + XlaOp& key, + XlaOp& initial_state, + int bit_generator, + XlaOp& minval, + XlaOp& maxval, + Shape& shape + ) { + return UniformDistribution( + xla::UniformIntDistribution, + key, + initial_state, + bit_generator, + minval, + maxval, + shape + ); } RngOutput* NormalFloatingPointDistribution( diff --git a/backend/src/tensorflow/compiler/xla/client/lib/prng.h b/backend/src/tensorflow/compiler/xla/client/lib/prng.h index cf5328ab0..6e1e9d73c 100644 --- a/backend/src/tensorflow/compiler/xla/client/lib/prng.h +++ b/backend/src/tensorflow/compiler/xla/client/lib/prng.h @@ -35,6 +35,15 @@ extern "C" { Shape& shape ); + RngOutput* UniformIntDistribution( + XlaOp& key, + XlaOp& initial_state, + int bit_generator, + XlaOp& minval, + XlaOp& maxval, + Shape& shape + ); + RngOutput* NormalFloatingPointDistribution( XlaOp& key, XlaOp& initial_state, int bit_generator, Shape& shape ); diff --git a/src/Compiler/Eval.idr b/src/Compiler/Eval.idr index cc10d86ac..99ee4b032 100644 --- a/src/Compiler/Eval.idr +++ b/src/Compiler/Eval.idr @@ -190,6 +190,15 @@ enqueue builder (UniformFloatingPoint key initialState minval maxval shape) = do !(lookup maxval) !(mkShape {dtype=F64} shape) tuple builder [value rngOutput, state rngOutput] +enqueue builder (UniformUInt key initialState minval maxval shape) = do + rngOutput <- uniformIntDistribution + !(lookup key) + !(lookup initialState) + ThreeFry + !(lookup minval) + !(lookup maxval) + !(mkShape {dtype=U64} shape) + tuple builder [value rngOutput, state rngOutput] enqueue builder (NormalFloatingPoint key initialState shape) = do rngOutput <- normalFloatingPointDistribution !(lookup key) !(lookup initialState) ThreeFry !(mkShape {dtype=F64} shape) diff --git a/src/Compiler/Expr.idr b/src/Compiler/Expr.idr index 13e749dea..ec093a1b8 100644 --- a/src/Compiler/Expr.idr +++ b/src/Compiler/Expr.idr @@ -119,4 +119,5 @@ data Expr : Type where Cholesky : Nat -> Expr TriangularSolve : Nat -> Nat -> Bool -> Expr UniformFloatingPoint : Nat -> Nat -> Nat -> Nat -> Shape -> Expr + UniformUInt : Nat -> Nat -> Nat -> Nat -> Shape -> Expr NormalFloatingPoint : Nat -> Nat -> Shape -> Expr diff --git a/src/Compiler/Xla/Prim/TensorFlow/Compiler/Xla/Client/Lib/PRNG.idr b/src/Compiler/Xla/Prim/TensorFlow/Compiler/Xla/Client/Lib/PRNG.idr index 13fe26529..4116e4a90 100644 --- a/src/Compiler/Xla/Prim/TensorFlow/Compiler/Xla/Client/Lib/PRNG.idr +++ b/src/Compiler/Xla/Prim/TensorFlow/Compiler/Xla/Client/Lib/PRNG.idr @@ -20,18 +20,23 @@ import System.FFI import Compiler.Xla.Prim.Util public export -RngOutput : Type -RngOutput = Struct "RngOutput" [("value", AnyPtr), ("state", AnyPtr)] +PrimRngOutput : Type +PrimRngOutput = Struct "RngOutput" [("value", AnyPtr), ("state", AnyPtr)] export %foreign (libxla "delete_RngOutput") -prim__delete : RngOutput -> PrimIO () +prim__delete : PrimRngOutput -> PrimIO () export %foreign (libxla "UniformFloatingPointDistribution") prim__uniformFloatingPointDistribution: - GCAnyPtr -> GCAnyPtr -> Int -> GCAnyPtr -> GCAnyPtr -> GCAnyPtr -> PrimIO RngOutput + GCAnyPtr -> GCAnyPtr -> Int -> GCAnyPtr -> GCAnyPtr -> GCAnyPtr -> PrimIO PrimRngOutput + +export +%foreign (libxla "UniformIntDistribution") +prim__uniformIntDistribution: + GCAnyPtr -> GCAnyPtr -> Int -> GCAnyPtr -> GCAnyPtr -> GCAnyPtr -> PrimIO PrimRngOutput export %foreign (libxla "NormalFloatingPointDistribution") -prim__normalFloatingPointDistribution: GCAnyPtr -> GCAnyPtr -> Int -> GCAnyPtr -> PrimIO RngOutput +prim__normalFloatingPointDistribution: GCAnyPtr -> GCAnyPtr -> Int -> GCAnyPtr -> PrimIO PrimRngOutput diff --git a/src/Compiler/Xla/TensorFlow/Compiler/Xla/Client/Lib/PRNG.idr b/src/Compiler/Xla/TensorFlow/Compiler/Xla/Client/Lib/PRNG.idr index 4f61b5e6a..b4e324ca9 100644 --- a/src/Compiler/Xla/TensorFlow/Compiler/Xla/Client/Lib/PRNG.idr +++ b/src/Compiler/Xla/TensorFlow/Compiler/Xla/Client/Lib/PRNG.idr @@ -28,26 +28,24 @@ Cast BitGenerator Int where cast ThreeFry = 0 cast Philox = 1 -%hide Compiler.Xla.Prim.TensorFlow.Compiler.Xla.Client.Lib.PRNG.RngOutput - public export record RngOutput where constructor MkRngOutput value : XlaOp state : XlaOp -export -uniformFloatingPointDistribution : +uniformDistribution : + (GCAnyPtr -> GCAnyPtr -> Int -> GCAnyPtr -> GCAnyPtr -> GCAnyPtr -> PrimIO PrimRngOutput) -> HasIO io => XlaOp -> XlaOp -> BitGenerator -> XlaOp -> XlaOp -> Shape -> io RngOutput -uniformFloatingPointDistribution +uniformDistribution + f (MkXlaOp key) (MkXlaOp initialState) bitGenerator (MkXlaOp minval) (MkXlaOp maxval) (MkShape shape) = do - rngOutput <- primIO $ prim__uniformFloatingPointDistribution - key initialState (cast bitGenerator) minval maxval shape + rngOutput <- primIO $ f key initialState (cast bitGenerator) minval maxval shape let value = getField rngOutput "value" state = getField rngOutput "state" primIO $ prim__delete rngOutput @@ -55,6 +53,16 @@ uniformFloatingPointDistribution state <- onCollectAny state XlaOp.delete pure (MkRngOutput {value = MkXlaOp value} {state = MkXlaOp state}) +export +uniformFloatingPointDistribution : + HasIO io => XlaOp -> XlaOp -> BitGenerator -> XlaOp -> XlaOp -> Shape -> io RngOutput +uniformFloatingPointDistribution = uniformDistribution prim__uniformFloatingPointDistribution + +export +uniformIntDistribution: + HasIO io => XlaOp -> XlaOp -> BitGenerator -> XlaOp -> XlaOp -> Shape -> io RngOutput +uniformIntDistribution = uniformDistribution prim__uniformIntDistribution + export normalFloatingPointDistribution : HasIO io => XlaOp -> XlaOp -> BitGenerator -> Shape -> io RngOutput diff --git a/src/Literal.idr b/src/Literal.idr index 4838b7f82..600069000 100644 --- a/src/Literal.idr +++ b/src/Literal.idr @@ -36,6 +36,9 @@ data Literal : Shape -> Type -> Type where Nil : Literal (0 :: ds) a (::) : Literal ds a -> Literal (d :: ds) a -> Literal (S d :: ds) a +export +(.shape) : (xs : Literal shape a) -> (s ** s = shape) + export fromInteger : Integer -> Literal [] Int32 fromInteger = Scalar . cast {to=Int32} @@ -210,3 +213,10 @@ namespace All Scalar : forall x . p x -> All p (Scalar x) Nil : All p [] (::) : All p x -> All p xs -> All p (x :: xs) + +namespace Compare + public export + data Compare : (0 p : a -> b -> Type) -> Literal shape a -> Literal shape b -> Type where + Scalar : forall a, b. p a b -> Compare p (Scalar a) (Scalar b) + Nil : Compare p [] [] + (::) : Compare p as bs -> Compare p ass bss -> Compare p (as :: ass) (bs :: bss) diff --git a/src/Tensor.idr b/src/Tensor.idr index 2076841ac..578245dc6 100644 --- a/src/Tensor.idr +++ b/src/Tensor.idr @@ -1393,52 +1393,91 @@ Rand = StateT (Tensor [1] U64) Ref inf : Ref $ Tensor [] F64 inf = fromDouble (1.0 / 0.0) -||| Generate independent and identically distributed (IID) uniform samples bounded element-wise -||| between `bound` and `bound'`. -||| -||| `bound` and `bound'` need not be ordered, and samples will be generated, elementwise, in -||| [min !bound !bound', max !bound !bound'). The exception is where the bounds are equal, in which -||| case: if the bounds are finite, samples are generated at the common bound, else samples are NaN. -||| -||| The generated samples are a deterministic function of the input key and state, but may vary -||| between backends and library versions. -||| -||| Example usage, multiplying two uniform samples -||| ``` -||| x : Ref $ Tensor [3] F64 -||| x = do key <- tensor (Scalar 2) -||| rng <- uniform key !(fill 0.0) !(fill 1.0) -||| initialState <- tensor [Scalar 0] -||| evalStateT initialState (do lift $ pure !rng * pure !rng) -||| ``` -||| -||| @key Determines the stream of generated samples. -||| @bound A bound of the samples. See full docstring for details. -||| @bound' A bound of the samples. See full docstring for details. -export -uniform : - {shape : _} -> - (key : Tensor [] U64) -> - (bound, bound' : Tensor shape F64) -> - Ref $ Rand $ Tensor shape F64 -uniform (MkTensor iKey envKey) bound bound' = do - minval@(MkTensor iMinval envMinval) <- min bound bound' - maxval@(MkTensor iMaxval envMaxval) <- max bound bound' - let inf = broadcast !inf - let env = mergeLeft (mergeLeft envKey envMinval) envMaxval - pure $ ST $ \(MkTensor iState envState) => do - i <- new - let env = mergeLeft envState env - env = insert i (UniformFloatingPoint iKey iState iMinval iMaxval shape) env - state = env `end` GetTupleElement 1 i - value = env `end` GetTupleElement 0 i - -- workaround for XLA bug https://github.com/tensorflow/tensorflow/issues/56663 - -- samples between -inf and 0 should be at -inf, but XLA produces nan - -- similarly, samples in (inf, inf) should be at inf and respectively for -inf - value = select !((pure minval == - inf) && (pure maxval == fill 0)) !(- inf) !value - value = select !((pure minval == inf) && (pure maxval == inf)) !inf !value - value = select !((pure minval == - inf) && (pure maxval == - inf)) !(- inf) !value - pure (!state, !value) +namespace F64 + ||| Generate independent and identically distributed (IID) uniform samples bounded element-wise + ||| between `bound` and `bound'`. + ||| + ||| `bound` and `bound'` need not be ordered, and samples will be generated, elementwise, in + ||| [min !bound !bound', max !bound !bound'). The exception is where the bounds are equal, in which + ||| case: if the bounds are finite, samples are generated at the common bound, else samples are NaN. + ||| + ||| The generated samples are a deterministic function of the input key and state, but may vary + ||| between backends and library versions. + ||| + ||| Example usage, multiplying two uniform samples + ||| ``` + ||| x : Ref $ Tensor [3] F64 + ||| x = do key <- tensor (Scalar 2) + ||| rng <- uniform key !(fill 0.0) !(fill 1.0) + ||| initialState <- tensor [Scalar 0] + ||| evalStateT initialState (do lift $ pure !rng * pure !rng) + ||| ``` + ||| + ||| @key Determines the stream of generated samples. + ||| @bound A bound of the samples. See full docstring for details. + ||| @bound' A bound of the samples. See full docstring for details. + export + uniform : + (key : Tensor [] U64) -> + (bound, bound' : Tensor shape F64) -> + Ref $ Rand $ Tensor shape F64 + uniform (MkTensor iKey envKey) bound bound' = do + minval@(MkTensor {shape = _} iMinval envMinval) <- min bound bound' + maxval@(MkTensor iMaxval envMaxval) <- max bound bound' + let inf = broadcast !inf + let env = mergeLeft (mergeLeft envKey envMinval) envMaxval + pure $ ST $ \(MkTensor iState envState) => do + i <- new + let env = mergeLeft envState env + env = insert i (UniformFloatingPoint iKey iState iMinval iMaxval shape) env + state = env `end` GetTupleElement 1 i + value = env `end` GetTupleElement 0 i + -- workaround for XLA bug https://github.com/tensorflow/tensorflow/issues/56663 + -- samples between -inf and 0 should be at -inf, but XLA produces nan + -- similarly, samples in (inf, inf) should be at inf and respectively for -inf + value = select !((pure minval == - inf) && (pure maxval == fill 0)) !(- inf) !value + value = select !((pure minval == inf) && (pure maxval == inf)) !inf !value + value = select !((pure minval == - inf) && (pure maxval == - inf)) !(- inf) !value + pure (!state, !value) + +namespace U64 + ||| Generate independent and identically distributed (IID) uniform samples bounded element-wise + ||| between `bound` and `bound'`. `bound` and `bound'` need not be ordered, and are both + ||| inclusive bounds. + ||| + ||| The generated samples are a deterministic function of the input key and state, but may vary + ||| between backends and library versions. + ||| + ||| Example usage, multiplying two uniform samples + ||| ``` + ||| x : Tensor [3] U64 + ||| x = let key = fromLiteral 2 + ||| rng = uniform key (fill 0) (fill 100) + ||| initialState = fromLiteral [0] + ||| in evalState initialState [| rng * rng |] + ||| ``` + ||| + ||| @key Determines the stream of generated samples. + ||| @bound A bound of the samples. See full docstring for details. + ||| @bound' A bound of the samples. See full docstring for details. + export + uniform : + (key : Tensor [] U64) -> + (lower, upper : Literal shape Nat) -> + {auto 0 boundsOrdered : Compare LT lower upper} -> + Ref $ Rand $ Tensor shape U64 + uniform (MkTensor iKey envKey) lower upper = do + let (shape' ** shapesEq) = lower.shape + MkTensor iLower envLower <- tensor {dtype = U64} {shape = shape'} (rewrite shapesEq in lower) + MkTensor iUpper envUpper <- tensor {dtype = U64} {shape = shape'} (rewrite shapesEq in upper) + let env = mergeLeft (mergeLeft envKey envLower) envUpper + pure $ ST $ \(MkTensor iState envState) => do + i <- new + let env = mergeLeft envState env + env = insert i (UniformUInt iKey iState iLower iUpper shape') env + state = env `end` GetTupleElement 1 i + value = env `end` GetTupleElement 0 i + pure (!state, rewrite sym shapesEq in !value) ||| Generate independent and identically distributed (IID) samples from the standard normal ||| distribution. diff --git a/test/Unit/TestTensor/Sampling.idr b/test/Unit/TestTensor/Sampling.idr index 300031edf..099ce70d2 100644 --- a/test/Unit/TestTensor/Sampling.idr +++ b/test/Unit/TestTensor/Sampling.idr @@ -32,7 +32,7 @@ product1 x = rewrite plusZeroRightNeutral x in Refl partial iidKolmogorovSmirnov : - {shape : _} -> Tensor shape F64 -> (Tensor shape F64 -> Ref $ Tensor shape F64) -> Ref $ Tensor [] F64 + {shape : _} -> Tensor shape dtype -> (Tensor shape dtype -> Ref $ Tensor shape F64) -> Ref $ Tensor [] F64 iidKolmogorovSmirnov samples cdf = do let n : Nat n = product shape @@ -46,103 +46,189 @@ iidKolmogorovSmirnov samples cdf = do Prelude.Ord a => Prelude.Ord (Literal [] a) where compare (Scalar x) (Scalar y) = compare x y -partial -uniform : Property -uniform = withTests 20 . property $ do - bound <- forAll (literal [5] finiteDoubles) - bound' <- forAll (literal [5] finiteDoubles) - key <- forAll (literal [] nats) - seed <- forAll (literal [1] nats) - - let ksTest = do - let bound = tensor bound - bound' = tensor bound' - bound' = select !(bound' == bound) !(bound' + fill 1.0e-9) !bound' - key <- tensor key - seed <- tensor seed - samples <- evalStateT seed !(uniform key !(broadcast !bound) !(broadcast !bound')) - - let uniformCdf : Tensor [2000, 5] F64 -> Ref $ Tensor [2000, 5] F64 - uniformCdf x = Tensor.(/) (pure x - broadcast !bound) (broadcast !(bound' - bound)) - - iidKolmogorovSmirnov samples uniformCdf - - diff (unsafeEval ksTest) (<) 0.015 - -partial -uniformForNonFiniteBounds : Property -uniformForNonFiniteBounds = property $ do - key <- forAll (literal [] nats) - seed <- forAll (literal [1] nats) - - let samples = do - bound <- tensor [0.0, 0.0, 0.0, -inf, -inf, -inf, inf, inf, nan] - bound' <- tensor [-inf, inf, nan, -inf, inf, nan, inf, nan, nan] - key <- tensor key - seed <- tensor seed - evalStateT seed !(uniform key !(broadcast bound) !(broadcast bound')) - - samples ===# tensor [-inf, inf, nan, -inf, nan, nan, inf, nan, nan] - -partial -uniformForFiniteEqualBounds : Property -uniformForFiniteEqualBounds = withTests 20 . property $ do - key <- forAll (literal [] nats) - seed <- forAll (literal [1] nats) - - let bound = tensor [min @{Finite}, -1.0, -1.0e-308, 0.0, 1.0e-308, 1.0, max @{Finite}] - samples = do evalStateT !(tensor seed) !(uniform !(tensor key) !bound !bound) - - samples ===# bound - -partial -uniformSeedIsUpdated : Property -uniformSeedIsUpdated = withTests 20 . property $ do - bound <- forAll (literal [10] doubles) - bound' <- forAll (literal [10] doubles) - key <- forAll (literal [] nats) - seed <- forAll (literal [1] nats) - - let everything = do - bound <- tensor bound - bound' <- tensor bound' - key <- tensor key - seed <- tensor seed - - rng <- uniform key {shape=[10]} !(broadcast bound) !(broadcast bound') - (seed', sample) <- runStateT seed rng - (seed'', sample') <- runStateT seed' rng - seeds <- concat 0 !(concat 0 seed seed') seed'' - samples <- concat 0 !(expand 0 sample) !(expand 0 sample') - pure (seeds, samples) - - [seed, seed', seed''] = unsafeEval (do (seeds, _) <- everything; pure seeds) - [sample, sample'] = unsafeEval (do (_, samples) <- everything; pure samples) - - diff seed' (/=) seed - diff seed'' (/=) seed' - diff sample' (/=) sample - -partial -uniformIsReproducible : Property -uniformIsReproducible = withTests 20 . property $ do - bound <- forAll (literal [10] doubles) - bound' <- forAll (literal [10] doubles) - key <- forAll (literal [] nats) - seed <- forAll (literal [1] nats) - - let [sample, sample'] = unsafeEval $ do - bound <- tensor bound - bound' <- tensor bound' - key <- tensor key - seed <- tensor seed - - rng <- uniform {shape=[10]} key !(broadcast bound) !(broadcast bound') - sample <- evalStateT seed rng - sample' <- evalStateT seed rng - concat 0 !(expand 0 sample) !(expand 0 sample') - - sample ==~ sample' +namespace F64 + export partial + uniform : Property + uniform = withTests 20 . property $ do + bound <- forAll (literal [5] finiteDoubles) + bound' <- forAll (literal [5] finiteDoubles) + key <- forAll (literal [] nats) + seed <- forAll (literal [1] nats) + + let ksTest = do + let bound = tensor bound + bound' = tensor bound' + bound' = select !(bound' == bound) !(bound' + fill 1.0e-9) !bound' + key <- tensor key + seed <- tensor seed + samples <- evalStateT seed !(uniform key !(broadcast !bound) !(broadcast !bound')) + + let uniformCdf : Tensor [2000, 5] F64 -> Ref $ Tensor [2000, 5] F64 + uniformCdf x = Tensor.(/) (pure x - broadcast !bound) (broadcast !(bound' - bound)) + + iidKolmogorovSmirnov samples uniformCdf + + diff (unsafeEval ksTest) (<) 0.015 + + export partial + uniformForNonFiniteBounds : Property + uniformForNonFiniteBounds = property $ do + key <- forAll (literal [] nats) + seed <- forAll (literal [1] nats) + + let samples = do + bound <- tensor [0.0, 0.0, 0.0, -inf, -inf, -inf, inf, inf, nan] + bound' <- tensor [-inf, inf, nan, -inf, inf, nan, inf, nan, nan] + key <- tensor key + seed <- tensor seed + evalStateT seed !(uniform key !(broadcast bound) !(broadcast bound')) + + samples ===# tensor [-inf, inf, nan, -inf, nan, nan, inf, nan, nan] + + export partial + uniformForFiniteEqualBounds : Property + uniformForFiniteEqualBounds = withTests 20 . property $ do + key <- forAll (literal [] nats) + seed <- forAll (literal [1] nats) + + let bound = tensor [min @{Finite}, -1.0, -1.0e-308, 0.0, 1.0e-308, 1.0, max @{Finite}] + samples = do evalStateT !(tensor seed) !(uniform !(tensor key) !bound !bound) + + samples ===# bound + + export partial + uniformSeedIsUpdated : Property + uniformSeedIsUpdated = withTests 20 . property $ do + bound <- forAll (literal [10] doubles) + bound' <- forAll (literal [10] doubles) + key <- forAll (literal [] nats) + seed <- forAll (literal [1] nats) + + let everything = do + bound <- tensor bound + bound' <- tensor bound' + key <- tensor key + seed <- tensor seed + + rng <- uniform key {shape=[10]} !(broadcast bound) !(broadcast bound') + (seed', sample) <- runStateT seed rng + (seed'', sample') <- runStateT seed' rng + seeds <- concat 0 !(concat 0 seed seed') seed'' + samples <- concat 0 !(expand 0 sample) !(expand 0 sample') + pure (seeds, samples) + + [seed, seed', seed''] = unsafeEval (do (seeds, _) <- everything; pure seeds) + [sample, sample'] = unsafeEval (do (_, samples) <- everything; pure samples) + + diff seed' (/=) seed + diff seed'' (/=) seed' + diff sample' (/=) sample + + export partial + uniformIsReproducible : Property + uniformIsReproducible = withTests 20 . property $ do + bound <- forAll (literal [10] doubles) + bound' <- forAll (literal [10] doubles) + key <- forAll (literal [] nats) + seed <- forAll (literal [1] nats) + + let [sample, sample'] = unsafeEval $ do + bound <- tensor bound + bound' <- tensor bound' + key <- tensor key + seed <- tensor seed + + rng <- uniform {shape=[10]} key !(broadcast bound) !(broadcast bound') + sample <- evalStateT seed rng + sample' <- evalStateT seed rng + concat 0 !(expand 0 sample) !(expand 0 sample') + + sample ==~ sample' + +Show (Compare p xs ys) where + show _ = "" + +orderedPair : (shape : Shape) -> Gen (as : Literal shape Nat ** bs : Literal shape Nat ** Compare LT as bs) + +namespace U64 + export partial + uniform : Property + uniform = withTests 20 . property $ do + (lower ** upper ** ordered) <- forAll (orderedPair [10]) + key <- forAll (literal [] nats) + seed <- forAll (literal [1] nats) + + let ksTest = do + samples <- evalStateT !(tensor seed) !(U64.uniform !(tensor key) [lower] [upper]) + + let uniformCdf : Tensor [1, 10] U64 -> Ref $ Tensor [1, 10] F64 + uniformCdf x = do + lower <- castDtype !(tensor {dtype = U64} [lower]) + let upper = castDtype !(tensor {dtype = U64} [upper]) + (castDtype x - pure lower) / (upper - pure lower) + + iidKolmogorovSmirnov samples uniformCdf + + -- samples ===# fill 0 + diff (unsafeEval ksTest) (<) 0.01 +{- + export covering + uniformBoundsAreInclusive : Property + uniformBoundsAreInclusive = property $ do + bound <- forAll (literal [100] nats) + key <- forAll (literal [] nats) + seed <- forAll (literal [1] nats) + + let bound = fromLiteral bound + bound' = bound + fill 2 + key = fromLiteral key + seed = fromLiteral seed + + samples = evalState seed (U64.uniform key bound bound') + + diff (toLiteral samples) (\x, y => any id [| x == y |]) (toLiteral bound) + diff (toLiteral samples) (\x, y => any id [| x == y |]) (toLiteral bound') + + export covering + uniformSeedIsUpdated : Property + uniformSeedIsUpdated = withTests 20 . property $ do + bound <- forAll (literal [10] nats) + bound' <- forAll (literal [10] nats) + key <- forAll (literal [] nats) + seed <- forAll (literal [1] nats) + + let bound = fromLiteral bound + bound' = fromLiteral bound' + key = fromLiteral key + seed = fromLiteral seed + + rng = U64.uniform key {shape=[10]} (broadcast bound) (broadcast bound') + (seed', sample) = runState seed rng + (seed'', sample') = runState seed' rng + + diff (toLiteral seed') (/=) (toLiteral seed) + diff (toLiteral seed'') (/=) (toLiteral seed') + diff (toLiteral sample') (/=) (toLiteral sample) + + export covering + uniformIsReproducible : Property + uniformIsReproducible = withTests 20 . property $ do + bound <- forAll (literal [10] nats) + bound' <- forAll (literal [10] nats) + key <- forAll (literal [] nats) + seed <- forAll (literal [1] nats) + + let bound = fromLiteral bound + bound' = fromLiteral bound' + key = fromLiteral key + seed = fromLiteral seed + + rng = U64.uniform {shape=[10]} key (broadcast bound) (broadcast bound') + sample = evalState seed rng + sample' = evalState seed rng + + sample ===# sample' +-} partial normal : Property @@ -206,11 +292,15 @@ normalIsReproducible = withTests 20 . property $ do export partial all : List (PropertyName, Property) all = [ - ("uniform", uniform) - , ("uniform for infinite and NaN bounds", uniformForNonFiniteBounds) - , ("uniform is not NaN for finite equal bounds", uniformForFiniteEqualBounds) - , ("uniform updates seed", uniformSeedIsUpdated) - , ("uniform produces same samples for same seed", uniformIsReproducible) + ("uniform F64", F64.uniform) + , ("uniform F64 for infinite and NaN bounds", uniformForNonFiniteBounds) + , ("uniform F64 is not NaN for finite equal bounds", uniformForFiniteEqualBounds) + , ("uniform F64 updates seed", F64.uniformSeedIsUpdated) + , ("uniform F64 produces same samples for same seed", F64.uniformIsReproducible) + , ("uniform U64", U64.uniform) + -- , ("uniform U64 bounds are inclusive", uniformBoundsAreInclusive) + -- , ("uniform U64 updates seed", U64.uniformSeedIsUpdated) + -- , ("uniform U64 produces same samples for same seed", U64.uniformIsReproducible) , ("normal", normal) , ("normal updates seed", normalSeedIsUpdated) , ("normal produces same samples for same seed", normalIsReproducible)