Skip to content

Commit

Permalink
add stateful uniform sampling of U64
Browse files Browse the repository at this point in the history
  • Loading branch information
joelberkeley committed Jun 10, 2023
1 parent 5d31e9c commit ba5cee1
Show file tree
Hide file tree
Showing 9 changed files with 384 additions and 174 deletions.
65 changes: 52 additions & 13 deletions backend/src/tensorflow/compiler/xla/client/lib/prng.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,32 @@ xla::BitGeneratorTy BitGenerator(int bit_generator) {
return bit_generator_;
}

// Shared implementation behind the uniform-distribution entry points.
//
// Reinterprets each opaque handle argument as its xla:: counterpart, maps the
// integer `bit_generator` through BitGenerator(), and invokes the sampler `f`
// with the converted arguments.
//
// Returns a heap-allocated RngOutput whose `value` and `state` members each
// point at a freshly heap-allocated copy of the corresponding xla::XlaOp from
// the sampler's result. Nothing allocated here is freed by this function.
RngOutput* UniformDistribution(
    std::function<xla::RngOutput(
        xla::XlaOp, xla::XlaOp, xla::BitGeneratorTy, xla::XlaOp, xla::XlaOp, const xla::Shape&
    )> f,
    XlaOp& key,
    XlaOp& initial_state,
    int bit_generator,
    XlaOp& minval,
    XlaOp& maxval,
    Shape& shape
) {
    // Resolve the generator selector first, exactly once, before sampling.
    const xla::BitGeneratorTy generator = BitGenerator(bit_generator);

    xla::RngOutput sampled = f(
        reinterpret_cast<xla::XlaOp&>(key),
        reinterpret_cast<xla::XlaOp&>(initial_state),
        generator,
        reinterpret_cast<xla::XlaOp&>(minval),
        reinterpret_cast<xla::XlaOp&>(maxval),
        reinterpret_cast<xla::Shape&>(shape)
    );

    // Copy both ops onto the heap so they outlive `sampled`.
    return new RngOutput {
        value: reinterpret_cast<XlaOp*>(new xla::XlaOp(sampled.value)),
        state: reinterpret_cast<XlaOp*>(new xla::XlaOp(sampled.state))
    };
}

extern "C" {
RngOutput* UniformFloatingPointDistribution(
XlaOp& key,
Expand All @@ -45,21 +71,34 @@ extern "C" {
XlaOp& maxval,
Shape& shape
) {
auto& key_ = reinterpret_cast<xla::XlaOp&>(key);
auto& initial_state_ = reinterpret_cast<xla::XlaOp&>(initial_state);
xla::BitGeneratorTy bit_generator_ = BitGenerator(bit_generator);
auto& minval_ = reinterpret_cast<xla::XlaOp&>(minval);
auto& maxval_ = reinterpret_cast<xla::XlaOp&>(maxval);
auto& shape_ = reinterpret_cast<xla::Shape&>(shape);

xla::RngOutput res = xla::UniformFloatingPointDistribution(
key_, initial_state_, bit_generator_, minval_, maxval_, shape_
return UniformDistribution(
xla::UniformFloatingPointDistribution,
key,
initial_state,
bit_generator,
minval,
maxval,
shape
);
}

return new RngOutput {
value: reinterpret_cast<XlaOp*>(new xla::XlaOp(res.value)),
state: reinterpret_cast<XlaOp*>(new xla::XlaOp(res.state))
};
// Entry point for uniform integer sampling.
//
// Delegates to UniformDistribution with xla::UniformIntDistribution as the
// sampler; every argument is forwarded unchanged and the helper's result is
// returned as-is.
RngOutput* UniformIntDistribution(
    XlaOp& key, XlaOp& initial_state, int bit_generator, XlaOp& minval, XlaOp& maxval, Shape& shape
) {
    return UniformDistribution(
        xla::UniformIntDistribution, key, initial_state, bit_generator, minval, maxval, shape
    );
}

RngOutput* NormalFloatingPointDistribution(
Expand Down
9 changes: 9 additions & 0 deletions backend/src/tensorflow/compiler/xla/client/lib/prng.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,15 @@ extern "C" {
Shape& shape
);

// Uniform integer sampling entry point.
//
// Parameters, in order: the key, the initial generator state, an integer
// selecting the bit generator, the minimum and maximum values, and the shape.
// Returns a pointer to an RngOutput; the definition in prng.cpp allocates it
// with `new`.
RngOutput* UniformIntDistribution(
XlaOp& key,
XlaOp& initial_state,
int bit_generator,
XlaOp& minval,
XlaOp& maxval,
Shape& shape
);

// Normal floating-point sampling entry point.
//
// Parameters, in order: the key, the initial generator state, an integer
// selecting the bit generator, and the shape. Returns a pointer to an
// RngOutput.
RngOutput* NormalFloatingPointDistribution(
XlaOp& key, XlaOp& initial_state, int bit_generator, Shape& shape
);
Expand Down
9 changes: 9 additions & 0 deletions src/Compiler/Eval.idr
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,15 @@ enqueue builder (UniformFloatingPoint key initialState minval maxval shape) = do
!(lookup maxval)
!(mkShape {dtype=F64} shape)
tuple builder [value rngOutput, state rngOutput]
enqueue builder (UniformUInt key initialState minval maxval shape) = do
rngOutput <- uniformIntDistribution
!(lookup key)
!(lookup initialState)
ThreeFry
!(lookup minval)
!(lookup maxval)
!(mkShape {dtype=U64} shape)
tuple builder [value rngOutput, state rngOutput]
enqueue builder (NormalFloatingPoint key initialState shape) = do
rngOutput <- normalFloatingPointDistribution
!(lookup key) !(lookup initialState) ThreeFry !(mkShape {dtype=F64} shape)
Expand Down
1 change: 1 addition & 0 deletions src/Compiler/Expr.idr
Original file line number Diff line number Diff line change
Expand Up @@ -119,4 +119,5 @@ data Expr : Type where
Cholesky : Nat -> Expr
TriangularSolve : Nat -> Nat -> Bool -> Expr
UniformFloatingPoint : Nat -> Nat -> Nat -> Nat -> Shape -> Expr
UniformUInt : Nat -> Nat -> Nat -> Nat -> Shape -> Expr
NormalFloatingPoint : Nat -> Nat -> Shape -> Expr
15 changes: 10 additions & 5 deletions src/Compiler/Xla/Prim/TensorFlow/Compiler/Xla/Client/Lib/PRNG.idr
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,23 @@ import System.FFI
import Compiler.Xla.Prim.Util

public export
RngOutput : Type
RngOutput = Struct "RngOutput" [("value", AnyPtr), ("state", AnyPtr)]
PrimRngOutput : Type
PrimRngOutput = Struct "RngOutput" [("value", AnyPtr), ("state", AnyPtr)]

export
%foreign (libxla "delete_RngOutput")
prim__delete : RngOutput -> PrimIO ()
prim__delete : PrimRngOutput -> PrimIO ()

export
%foreign (libxla "UniformFloatingPointDistribution")
prim__uniformFloatingPointDistribution:
GCAnyPtr -> GCAnyPtr -> Int -> GCAnyPtr -> GCAnyPtr -> GCAnyPtr -> PrimIO RngOutput
GCAnyPtr -> GCAnyPtr -> Int -> GCAnyPtr -> GCAnyPtr -> GCAnyPtr -> PrimIO PrimRngOutput

||| FFI binding to the `UniformIntDistribution` symbol in libxla.
|||
||| The arguments, in order, correspond to the C declaration's key, initial state, bit generator
||| (passed as an `Int`), minimum value, maximum value and shape. Produces a `PrimRngOutput`.
export
%foreign (libxla "UniformIntDistribution")
prim__uniformIntDistribution:
GCAnyPtr -> GCAnyPtr -> Int -> GCAnyPtr -> GCAnyPtr -> GCAnyPtr -> PrimIO PrimRngOutput

export
%foreign (libxla "NormalFloatingPointDistribution")
prim__normalFloatingPointDistribution: GCAnyPtr -> GCAnyPtr -> Int -> GCAnyPtr -> PrimIO RngOutput
prim__normalFloatingPointDistribution: GCAnyPtr -> GCAnyPtr -> Int -> GCAnyPtr -> PrimIO PrimRngOutput
22 changes: 15 additions & 7 deletions src/Compiler/Xla/TensorFlow/Compiler/Xla/Client/Lib/PRNG.idr
Original file line number Diff line number Diff line change
Expand Up @@ -28,33 +28,41 @@ Cast BitGenerator Int where
cast ThreeFry = 0
cast Philox = 1

%hide Compiler.Xla.Prim.TensorFlow.Compiler.Xla.Client.Lib.PRNG.RngOutput

||| The pair of `XlaOp`s produced by a random sampling operation: the sampled `value` and the
||| generator `state` returned alongside it.
public export
record RngOutput where
constructor MkRngOutput
value : XlaOp
state : XlaOp

export
uniformFloatingPointDistribution :
uniformDistribution :
(GCAnyPtr -> GCAnyPtr -> Int -> GCAnyPtr -> GCAnyPtr -> GCAnyPtr -> PrimIO PrimRngOutput) ->
HasIO io => XlaOp -> XlaOp -> BitGenerator -> XlaOp -> XlaOp -> Shape -> io RngOutput
uniformFloatingPointDistribution
uniformDistribution
f
(MkXlaOp key)
(MkXlaOp initialState)
bitGenerator
(MkXlaOp minval)
(MkXlaOp maxval)
(MkShape shape) = do
rngOutput <- primIO $ prim__uniformFloatingPointDistribution
key initialState (cast bitGenerator) minval maxval shape
rngOutput <- primIO $ f key initialState (cast bitGenerator) minval maxval shape
let value = getField rngOutput "value"
state = getField rngOutput "state"
primIO $ prim__delete rngOutput
value <- onCollectAny value XlaOp.delete
state <- onCollectAny state XlaOp.delete
pure (MkRngOutput {value = MkXlaOp value} {state = MkXlaOp state})

||| Uniform sampling of floating point values: `uniformDistribution` applied to the
||| `prim__uniformFloatingPointDistribution` binding.
|||
||| Takes, in order, the key, the initial state, the bit generator, the minimum value, the maximum
||| value and the shape.
export
uniformFloatingPointDistribution :
HasIO io => XlaOp -> XlaOp -> BitGenerator -> XlaOp -> XlaOp -> Shape -> io RngOutput
uniformFloatingPointDistribution = uniformDistribution prim__uniformFloatingPointDistribution

||| Uniform sampling of integer values: `uniformDistribution` applied to the
||| `prim__uniformIntDistribution` binding.
|||
||| Takes, in order, the key, the initial state, the bit generator, the minimum value, the maximum
||| value and the shape.
export
uniformIntDistribution:
HasIO io => XlaOp -> XlaOp -> BitGenerator -> XlaOp -> XlaOp -> Shape -> io RngOutput
uniformIntDistribution = uniformDistribution prim__uniformIntDistribution

export
normalFloatingPointDistribution :
HasIO io => XlaOp -> XlaOp -> BitGenerator -> Shape -> io RngOutput
Expand Down
10 changes: 10 additions & 0 deletions src/Literal.idr
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ data Literal : Shape -> Type -> Type where
Nil : Literal (0 :: ds) a
(::) : Literal ds a -> Literal (d :: ds) a -> Literal (S d :: ds) a

||| The shape of a `Literal` as a value, paired with a proof that it equals the type-level
||| `shape`.
|||
||| NOTE(review): only this type signature is visible here. Confirm that a definition exists,
||| since the `U64` `uniform` sampler in `Tensor` pattern matches on the result of `.shape`.
export
(.shape) : (xs : Literal shape a) -> (s ** s = shape)

export
fromInteger : Integer -> Literal [] Int32
fromInteger = Scalar . cast {to=Int32}
Expand Down Expand Up @@ -210,3 +213,10 @@ namespace All
Scalar : forall x . p x -> All p (Scalar x)
Nil : All p []
(::) : All p x -> All p xs -> All p (x :: xs)

namespace Compare
||| A proof that the relation `p` holds between each pair of corresponding elements of two
||| literals of the same shape.
|||
||| The constructors follow the structure of `Literal`: `Scalar` relates two scalars, `Nil`
||| relates two empty literals, and `(::)` combines a proof for the heads with a proof for the
||| tails.
public export
data Compare : (0 p : a -> b -> Type) -> Literal shape a -> Literal shape b -> Type where
Scalar : forall a, b. p a b -> Compare p (Scalar a) (Scalar b)
Nil : Compare p [] []
(::) : Compare p as bs -> Compare p ass bss -> Compare p (as :: ass) (bs :: bss)
131 changes: 85 additions & 46 deletions src/Tensor.idr
Original file line number Diff line number Diff line change
Expand Up @@ -1393,52 +1393,91 @@ Rand = StateT (Tensor [1] U64) Ref
inf : Ref $ Tensor [] F64
inf = fromDouble (1.0 / 0.0)

||| Generate independent and identically distributed (IID) uniform samples bounded element-wise
||| between `bound` and `bound'`.
|||
||| `bound` and `bound'` need not be ordered, and samples will be generated, elementwise, in
||| [min !bound !bound', max !bound !bound'). The exception is where the bounds are equal, in which
||| case: if the bounds are finite, samples are generated at the common bound, else samples are NaN.
|||
||| The generated samples are a deterministic function of the input key and state, but may vary
||| between backends and library versions.
|||
||| Example usage, multiplying two uniform samples
||| ```
||| x : Ref $ Tensor [3] F64
||| x = do key <- tensor (Scalar 2)
||| rng <- uniform key !(fill 0.0) !(fill 1.0)
||| initialState <- tensor [Scalar 0]
||| evalStateT initialState (do lift $ pure !rng * pure !rng)
||| ```
|||
||| @key Determines the stream of generated samples.
||| @bound A bound of the samples. See full docstring for details.
||| @bound' A bound of the samples. See full docstring for details.
export
uniform :
{shape : _} ->
(key : Tensor [] U64) ->
(bound, bound' : Tensor shape F64) ->
Ref $ Rand $ Tensor shape F64
uniform (MkTensor iKey envKey) bound bound' = do
minval@(MkTensor iMinval envMinval) <- min bound bound'
maxval@(MkTensor iMaxval envMaxval) <- max bound bound'
let inf = broadcast !inf
let env = mergeLeft (mergeLeft envKey envMinval) envMaxval
pure $ ST $ \(MkTensor iState envState) => do
i <- new
let env = mergeLeft envState env
env = insert i (UniformFloatingPoint iKey iState iMinval iMaxval shape) env
state = env `end` GetTupleElement 1 i
value = env `end` GetTupleElement 0 i
-- workaround for XLA bug https://github.com/tensorflow/tensorflow/issues/56663
-- samples between -inf and 0 should be at -inf, but XLA produces nan
-- similarly, samples in (inf, inf) should be at inf and respectively for -inf
value = select !((pure minval == - inf) && (pure maxval == fill 0)) !(- inf) !value
value = select !((pure minval == inf) && (pure maxval == inf)) !inf !value
value = select !((pure minval == - inf) && (pure maxval == - inf)) !(- inf) !value
pure (!state, !value)
namespace F64
||| Generate independent and identically distributed (IID) uniform samples bounded element-wise
||| between `bound` and `bound'`.
|||
||| `bound` and `bound'` need not be ordered, and samples will be generated, elementwise, in
||| [min !bound !bound', max !bound !bound'). The exception is where the bounds are equal, in which
||| case: if the bounds are finite, samples are generated at the common bound, else samples are NaN.
|||
||| The generated samples are a deterministic function of the input key and state, but may vary
||| between backends and library versions.
|||
||| Example usage, multiplying two uniform samples
||| ```
||| x : Ref $ Tensor [3] F64
||| x = do key <- tensor (Scalar 2)
|||        rng <- uniform key !(fill 0.0) !(fill 1.0)
|||        initialState <- tensor [Scalar 0]
|||        evalStateT initialState (do lift $ pure !rng * pure !rng)
||| ```
|||
||| @key Determines the stream of generated samples.
||| @bound A bound of the samples. See full docstring for details.
||| @bound' A bound of the samples. See full docstring for details.
export
uniform :
(key : Tensor [] U64) ->
(bound, bound' : Tensor shape F64) ->
Ref $ Rand $ Tensor shape F64
uniform (MkTensor iKey envKey) bound bound' = do
-- NOTE(review): `shape` is used as a runtime value below but is not an explicit argument;
-- presumably the `{shape = _}` match on `MkTensor` is what brings it into scope - confirm.
minval@(MkTensor {shape = _} iMinval envMinval) <- min bound bound'
maxval@(MkTensor iMaxval envMaxval) <- max bound bound'
-- shadows the top-level scalar `inf` with a broadcast copy
let inf = broadcast !inf
let env = mergeLeft (mergeLeft envKey envMinval) envMaxval
-- the returned state transition takes the current generator state and yields the new state
-- together with the samples
pure $ ST $ \(MkTensor iState envState) => do
i <- new
let env = mergeLeft envState env
env = insert i (UniformFloatingPoint iKey iState iMinval iMaxval shape) env
-- the node at `i` is a tuple: element 0 holds the samples, element 1 the new state
state = env `end` GetTupleElement 1 i
value = env `end` GetTupleElement 0 i
-- workaround for XLA bug https://github.com/tensorflow/tensorflow/issues/56663
-- samples between -inf and 0 should be at -inf, but XLA produces nan
-- similarly, samples in (inf, inf) should be at inf and respectively for -inf
-- NOTE(review): the docstring says equal non-finite bounds give NaN, but the last two
-- selects below substitute inf / -inf when both bounds are inf / -inf - confirm which
-- behaviour is intended and align the docstring.
value = select !((pure minval == - inf) && (pure maxval == fill 0)) !(- inf) !value
value = select !((pure minval == inf) && (pure maxval == inf)) !inf !value
value = select !((pure minval == - inf) && (pure maxval == - inf)) !(- inf) !value
pure (!state, !value)

namespace U64
||| Generate independent and identically distributed (IID) uniform samples bounded element-wise
||| between `lower` and `upper`.
|||
||| The bounds must be ordered: the signature requires a proof `Compare LT lower upper`, i.e.
||| that each element of `lower` is less than the corresponding element of `upper`.
|||
||| NOTE(review): this docstring previously described both bounds as inclusive. XLA documents
||| its uniform distributions as sampling from [minval, maxval), so `upper` is probably
||| exclusive - confirm.
|||
||| The generated samples are a deterministic function of the input key and state, but may vary
||| between backends and library versions.
|||
||| Example usage, multiplying two uniform samples
||| ```
||| x : Tensor [3] U64
||| x = let key = fromLiteral 2
|||         rng = uniform key (fill 0) (fill 100)
|||         initialState = fromLiteral [0]
|||      in evalState initialState [| rng * rng |]
||| ```
||| NOTE(review): the example gives `x` the type `Tensor [3] U64` and uses `fromLiteral` and
||| `evalState`, whereas this function returns `Ref $ Rand $ Tensor shape U64` and takes
||| `Literal` bounds - confirm the example still type checks.
|||
||| @key Determines the stream of generated samples.
||| @lower The lower bound of the samples. See full docstring for details.
||| @upper The upper bound of the samples. See full docstring for details.
export
uniform :
(key : Tensor [] U64) ->
(lower, upper : Literal shape Nat) ->
{auto 0 boundsOrdered : Compare LT lower upper} ->
Ref $ Rand $ Tensor shape U64
uniform (MkTensor iKey envKey) lower upper = do
-- obtain a shape value `shape'` together with a proof `shapesEq` that it equals `shape`
let (shape' ** shapesEq) = lower.shape
-- build U64 tensors of shape `shape'` from the literal bounds
MkTensor iLower envLower <- tensor {dtype = U64} {shape = shape'} (rewrite shapesEq in lower)
MkTensor iUpper envUpper <- tensor {dtype = U64} {shape = shape'} (rewrite shapesEq in upper)
let env = mergeLeft (mergeLeft envKey envLower) envUpper
-- the returned state transition takes the current generator state and yields the new state
-- together with the samples
pure $ ST $ \(MkTensor iState envState) => do
i <- new
let env = mergeLeft envState env
env = insert i (UniformUInt iKey iState iLower iUpper shape') env
-- the node at `i` is a tuple: element 0 holds the samples, element 1 the new state
state = env `end` GetTupleElement 1 i
value = env `end` GetTupleElement 0 i
-- convert the result from shape `shape'` back to the declared `shape`
pure (!state, rewrite sym shapesEq in !value)

||| Generate independent and identically distributed (IID) samples from the standard normal
||| distribution.
Expand Down
Loading

0 comments on commit ba5cee1

Please sign in to comment.