Skip to content

Commit

Permalink
add stateful uniform sampling of U64
Browse files Browse the repository at this point in the history
  • Loading branch information
joelberkeley committed Jun 10, 2023
1 parent 016c4cd commit a0a86ea
Show file tree
Hide file tree
Showing 8 changed files with 364 additions and 144 deletions.
65 changes: 52 additions & 13 deletions backend/src/tensorflow/compiler/xla/client/lib/prng.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,32 @@ xla::BitGeneratorTy BitGenerator(int bit_generator) {
return bit_generator_;
}

// Shared implementation for the extern "C" uniform sampling wrappers below.
//
// Reinterprets the opaque FFI references as their xla:: counterparts, invokes
// the XLA distribution builder `f` (e.g. xla::UniformFloatingPointDistribution
// or xla::UniformIntDistribution), and packages the resulting value/state ops
// in a heap-allocated RngOutput. Ownership of the RngOutput, and of the two
// XlaOps it points to, transfers to the caller.
RngOutput* UniformDistribution(
    // Taken by const reference: copying a std::function on every call is
    // needless overhead.
    const std::function<xla::RngOutput(
        xla::XlaOp, xla::XlaOp, xla::BitGeneratorTy, xla::XlaOp, xla::XlaOp, const xla::Shape&
    )>& f,
    XlaOp& key,
    XlaOp& initial_state,
    int bit_generator,
    XlaOp& minval,
    XlaOp& maxval,
    Shape& shape
) {
    auto& key_ = reinterpret_cast<xla::XlaOp&>(key);
    auto& initial_state_ = reinterpret_cast<xla::XlaOp&>(initial_state);
    xla::BitGeneratorTy bit_generator_ = BitGenerator(bit_generator);
    auto& minval_ = reinterpret_cast<xla::XlaOp&>(minval);
    auto& maxval_ = reinterpret_cast<xla::XlaOp&>(maxval);
    auto& shape_ = reinterpret_cast<xla::Shape&>(shape);

    xla::RngOutput res = f(key_, initial_state_, bit_generator_, minval_, maxval_, shape_);

    // `.value =` is the standard designated-initializer spelling; the `value:`
    // form previously used here is a GNU extension.
    return new RngOutput {
        .value = reinterpret_cast<XlaOp*>(new xla::XlaOp(res.value)),
        .state = reinterpret_cast<XlaOp*>(new xla::XlaOp(res.state))
    };
}

extern "C" {
RngOutput* UniformFloatingPointDistribution(
XlaOp& key,
Expand All @@ -45,21 +71,34 @@ extern "C" {
XlaOp& maxval,
Shape& shape
) {
auto& key_ = reinterpret_cast<xla::XlaOp&>(key);
auto& initial_state_ = reinterpret_cast<xla::XlaOp&>(initial_state);
xla::BitGeneratorTy bit_generator_ = BitGenerator(bit_generator);
auto& minval_ = reinterpret_cast<xla::XlaOp&>(minval);
auto& maxval_ = reinterpret_cast<xla::XlaOp&>(maxval);
auto& shape_ = reinterpret_cast<xla::Shape&>(shape);

xla::RngOutput res = xla::UniformFloatingPointDistribution(
key_, initial_state_, bit_generator_, minval_, maxval_, shape_
return UniformDistribution(
xla::UniformFloatingPointDistribution,
key,
initial_state,
bit_generator,
minval,
maxval,
shape
);
}

return new RngOutput {
value: reinterpret_cast<XlaOp*>(new xla::XlaOp(res.value)),
state: reinterpret_cast<XlaOp*>(new xla::XlaOp(res.state))
};
RngOutput* UniformIntDistribution(
XlaOp& key,
XlaOp& initial_state,
int bit_generator,
XlaOp& minval,
XlaOp& maxval,
Shape& shape
) {
return UniformDistribution(
xla::UniformIntDistribution,
key,
initial_state,
bit_generator,
minval,
maxval,
shape
);
}

RngOutput* NormalFloatingPointDistribution(
Expand Down
9 changes: 9 additions & 0 deletions backend/src/tensorflow/compiler/xla/client/lib/prng.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,15 @@ extern "C" {
Shape& shape
);

    // Stateful uniform sampling of integral values between minval and maxval
    // (bound semantics follow xla::UniformIntDistribution). Mirrors
    // UniformFloatingPointDistribution above. The returned RngOutput, and the
    // XlaOps it points to, are owned by the caller.
    RngOutput* UniformIntDistribution(
        XlaOp& key,
        XlaOp& initial_state,
        int bit_generator,
        XlaOp& minval,
        XlaOp& maxval,
        Shape& shape
    );

    // Stateful sampling from the standard normal distribution. Returns the
    // sampled value op and the updated RNG state op; caller owns the result.
    RngOutput* NormalFloatingPointDistribution(
        XlaOp& key, XlaOp& initial_state, int bit_generator, Shape& shape
    );
Expand Down
10 changes: 10 additions & 0 deletions src/Compiler/Eval.idr
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,16 @@ enqueue e@(UniformFloatingPoint key initialState minval maxval shape) = cached e
!(mkShape {dtype=F64} shape)
(builder, _) <- get
tuple builder [value rngOutput, state rngOutput]
-- Uniform U64 sampling: lower `UniformUInt` to XLA via the stateful
-- `uniformIntDistribution` FFI wrapper with the ThreeFry bit generator, then
-- pack the result as a tuple: sampled value at index 0, updated RNG state at
-- index 1.
enqueue e@(UniformUInt key initialState minval maxval shape) = cached e $ do
  rngOutput <- uniformIntDistribution
    !(enqueue key)
    !(enqueue initialState)
    ThreeFry
    !(enqueue minval)
    !(enqueue maxval)
    !(mkShape {dtype=U64} shape)
  (builder, _) <- get
  tuple builder [value rngOutput, state rngOutput]
enqueue e@(NormalFloatingPoint key initialState shape) = cached e $ do
rngOutput <- normalFloatingPointDistribution
!(enqueue key) !(enqueue initialState) ThreeFry !(mkShape {dtype=F64} shape)
Expand Down
11 changes: 11 additions & 0 deletions src/Compiler/Expr.idr
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ data Expr : Type where
Cholesky : Expr -> Expr
TriangularSolve : Expr -> Expr -> Bool -> Expr
UniformFloatingPoint : Expr -> Expr -> Expr -> Expr -> Shape -> Expr
UniformUInt : Expr -> Expr -> Expr -> Expr -> Shape -> Expr
NormalFloatingPoint : Expr -> Expr -> Shape -> Expr

export
Expand Down Expand Up @@ -200,6 +201,9 @@ Prelude.Eq Expr where
(UniformFloatingPoint key initialState minval maxval shape) ==
(UniformFloatingPoint key' initialState' minval' maxval' shape') =
key == key' && initialState == initialState' && minval == minval' && maxval == maxval'
  -- NOTE(review): `shape` is deliberately not compared, matching the other
  -- PRNG clauses in this instance; this presumes equal operand expressions
  -- imply equal shapes. The Hashable instance *does* fold in `shape`, so
  -- confirm that presumption holds — otherwise x == y would not guarantee
  -- hash x == hash y.
  (UniformUInt key initialState minval maxval shape) ==
    (UniformUInt key' initialState' minval' maxval' shape') =
      key == key' && initialState == initialState' && minval == minval' && maxval == maxval'
(NormalFloatingPoint key initialState shape) == (NormalFloatingPoint key' initialState' shape') =
key == key' && initialState == initialState'
_ == _ = False
Expand Down Expand Up @@ -318,6 +322,13 @@ Hashable Expr where
`hashWithSalt` minval
`hashWithSalt` maxval
`hashWithSalt` shape
  -- Fold in a constructor tag plus every operand; unlike the Eq instance,
  -- `shape` participates in the hash.
  hashWithSalt salt (UniformUInt key initialState minval maxval shape) = salt
    `hashWithSalt` "UniformUInt"
    `hashWithSalt` key
    `hashWithSalt` initialState
    `hashWithSalt` minval
    `hashWithSalt` maxval
    `hashWithSalt` shape
hashWithSalt salt (NormalFloatingPoint key initialState shape) = salt
`hashWithSalt` "NormalFloatingPoint"
`hashWithSalt` key
Expand Down
15 changes: 10 additions & 5 deletions src/Compiler/Xla/Prim/TensorFlow/Compiler/Xla/Client/Lib/PRNG.idr
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,23 @@ import System.FFI
import Compiler.Xla.Prim.Util

-- Raw FFI view of the C `RngOutput` struct: a pair of unmanaged pointers to
-- the sampled value op and the updated RNG state op.
--
-- NOTE: this region previously contained stale pre-rename lines (the old
-- `RngOutput` spellings) interleaved with the `PrimRngOutput` versions; this
-- is the intended post-rename text.
public export
PrimRngOutput : Type
PrimRngOutput = Struct "RngOutput" [("value", AnyPtr), ("state", AnyPtr)]

-- Free a C-allocated `RngOutput` struct. The two ops it points to are
-- managed separately (see the wrapper module).
export
%foreign (libxla "delete_RngOutput")
prim__delete : PrimRngOutput -> PrimIO ()

export
%foreign (libxla "UniformFloatingPointDistribution")
prim__uniformFloatingPointDistribution:
  GCAnyPtr -> GCAnyPtr -> Int -> GCAnyPtr -> GCAnyPtr -> GCAnyPtr -> PrimIO PrimRngOutput

export
%foreign (libxla "UniformIntDistribution")
prim__uniformIntDistribution:
  GCAnyPtr -> GCAnyPtr -> Int -> GCAnyPtr -> GCAnyPtr -> GCAnyPtr -> PrimIO PrimRngOutput

export
%foreign (libxla "NormalFloatingPointDistribution")
prim__normalFloatingPointDistribution: GCAnyPtr -> GCAnyPtr -> Int -> GCAnyPtr -> PrimIO PrimRngOutput
22 changes: 15 additions & 7 deletions src/Compiler/Xla/TensorFlow/Compiler/Xla/Client/Lib/PRNG.idr
Original file line number Diff line number Diff line change
Expand Up @@ -28,33 +28,41 @@ Cast BitGenerator Int where
cast ThreeFry = 0
cast Philox = 1

%hide Compiler.Xla.Prim.TensorFlow.Compiler.Xla.Client.Lib.PRNG.RngOutput

-- High-level pair of XLA ops produced by a stateful sampling primitive: the
-- sampled value, and the RNG state to thread into the next sampling call.
public export
record RngOutput where
  constructor MkRngOutput
  value : XlaOp
  state : XlaOp

-- Shared wrapper for the uniform sampling primitives: calls the raw FFI
-- function `f`, frees the C `RngOutput` struct, and attaches garbage
-- collection (XlaOp.delete) to the two returned ops.
--
-- NOTE: this region previously contained stale pre-refactor lines (the old
-- `uniformFloatingPointDistribution` spellings and direct prim call)
-- interleaved with the generic version; this is the intended text.
export
uniformDistribution :
  (GCAnyPtr -> GCAnyPtr -> Int -> GCAnyPtr -> GCAnyPtr -> GCAnyPtr -> PrimIO PrimRngOutput) ->
  HasIO io => XlaOp -> XlaOp -> BitGenerator -> XlaOp -> XlaOp -> Shape -> io RngOutput
uniformDistribution
  f
  (MkXlaOp key)
  (MkXlaOp initialState)
  bitGenerator
  (MkXlaOp minval)
  (MkXlaOp maxval)
  (MkShape shape) = do
    rngOutput <- primIO $ f key initialState (cast bitGenerator) minval maxval shape
    let value = getField rngOutput "value"
        state = getField rngOutput "state"
    primIO $ prim__delete rngOutput
    value <- onCollectAny value XlaOp.delete
    state <- onCollectAny state XlaOp.delete
    pure (MkRngOutput {value = MkXlaOp value} {state = MkXlaOp state})

-- Stateful uniform sampling of floating point values.
export
uniformFloatingPointDistribution :
  HasIO io => XlaOp -> XlaOp -> BitGenerator -> XlaOp -> XlaOp -> Shape -> io RngOutput
uniformFloatingPointDistribution = uniformDistribution prim__uniformFloatingPointDistribution

-- Stateful uniform sampling of integral values.
export
uniformIntDistribution:
  HasIO io => XlaOp -> XlaOp -> BitGenerator -> XlaOp -> XlaOp -> Shape -> io RngOutput
uniformIntDistribution = uniformDistribution prim__uniformIntDistribution

export
normalFloatingPointDistribution :
HasIO io => XlaOp -> XlaOp -> BitGenerator -> Shape -> io RngOutput
Expand Down
115 changes: 73 additions & 42 deletions src/Tensor.idr
Original file line number Diff line number Diff line change
Expand Up @@ -1161,48 +1161,79 @@ Rand = State (Tensor [1] U64)
-- Positive infinity, produced via IEEE 754 division by zero.
inf : Tensor [] F64
inf = fromDouble (1.0 / 0.0)

||| Generate independent and identically distributed (IID) uniform samples bounded element-wise
||| between `bound` and `bound'`.
|||
||| `bound` and `bound'` need not be ordered, and samples will be generated, elementwise, in
||| [min bound bound', max bound bound'). The exception is where the bounds are equal, in which
||| case: if the bounds are finite, samples are generated at the common bound, else samples are NaN.
|||
||| The generated samples are a deterministic function of the input key and state, but may vary
||| between backends and library versions.
|||
||| Example usage, multiplying two uniform samples
||| ```
||| x : Tensor [3] F64
||| x = let key = fromLiteral 2
||| rng = uniform key (fill 0.0) (fill 1.0)
||| initialState = fromLiteral [0]
||| in evalState initialState [| rng * rng |]
||| ```
|||
||| @key Determines the stream of generated samples.
||| @bound A bound of the samples. See full docstring for details.
||| @bound' A bound of the samples. See full docstring for details.
export
uniform :
{shape : _} ->
(key : Tensor [] U64) ->
(bound, bound' : Tensor shape F64) ->
Rand (Tensor shape F64)
uniform (MkTensor key) bound bound' =
let minval@(MkTensor minvalExpr) = min bound bound'
maxval@(MkTensor maxvalExpr) = max bound bound'
in ST $ \(MkTensor initialState) =>
let valueState = UniformFloatingPoint key initialState minvalExpr maxvalExpr shape
value = MkTensor $ GetTupleElement 0 valueState
-- workaround for XLA bug https://github.com/tensorflow/tensorflow/issues/56663
-- samples between -inf and 0 should be at -inf, but XLA produces nan
-- similarly, samples in (inf, inf) should be at inf and respectively for -inf
inf = broadcast inf
value = select (minval == - inf && maxval == fill 0) (- inf) value
value = select (minval == inf && maxval == inf) inf value
value = select (minval == - inf && maxval == - inf) (- inf) value
in Id (MkTensor $ GetTupleElement 1 valueState, value)
namespace F64
  ||| Generate independent and identically distributed (IID) uniform samples bounded element-wise
  ||| between `bound` and `bound'`. `bound` and `bound'` need not be ordered, and are both
  ||| inclusive bounds.
  |||
  ||| The generated samples are a deterministic function of the input key and state, but may vary
  ||| between backends and library versions.
  |||
  ||| Example usage, multiplying two uniform samples
  ||| ```
  ||| x : Tensor [3] F64
  ||| x = let key = fromLiteral 2
  |||         rng = uniform key (fill 0.0) (fill 1.0)
  |||         initialState = fromLiteral [0]
  |||      in evalState initialState [| rng * rng |]
  ||| ```
  |||
  ||| @key Determines the stream of generated samples.
  ||| @bound A bound of the samples. See full docstring for details.
  ||| @bound' A bound of the samples. See full docstring for details.
  export
  uniform :
    {shape : _} ->
    (key : Tensor [] U64) ->
    (bound, bound' : Tensor shape F64) ->
    Rand (Tensor shape F64)
  -- NOTE(review): the docstring claims both bounds are inclusive — confirm
  -- against the XLA uniform float semantics (historically half-open [min, max)).
  uniform (MkTensor key) bound bound' =
    let minval@(MkTensor minvalExpr) = min bound bound'
        maxval@(MkTensor maxvalExpr) = max bound bound'
     in ST $ \(MkTensor initialState) =>
          -- tuple element 0 is the sampled value, element 1 the updated state
          let valueState = UniformFloatingPoint key initialState minvalExpr maxvalExpr shape
              value = MkTensor $ GetTupleElement 0 valueState
              -- workaround for XLA bug https://github.com/tensorflow/tensorflow/issues/56663
              -- samples between -inf and 0 should be at -inf, but XLA produces nan
              -- similarly, samples in (inf, inf) should be at inf and respectively for -inf
              inf = broadcast inf
              value = select (minval == - inf && maxval == fill 0) (- inf) value
              value = select (minval == inf && maxval == inf) inf value
              value = select (minval == - inf && maxval == - inf) (- inf) value
           in Id (MkTensor $ GetTupleElement 1 valueState, value)

namespace U64
  ||| Generate independent and identically distributed (IID) uniform samples bounded element-wise
  ||| between `bound` and `bound'`. `bound` and `bound'` need not be ordered, and are both
  ||| inclusive bounds.
  |||
  ||| The generated samples are a deterministic function of the input key and state, but may vary
  ||| between backends and library versions.
  |||
  ||| Example usage, multiplying two uniform samples
  ||| ```
  ||| x : Tensor [3] U64
  ||| x = let key = fromLiteral 2
  |||         rng = uniform key (fill 0) (fill 100)
  |||         initialState = fromLiteral [0]
  |||      in evalState initialState [| rng * rng |]
  ||| ```
  |||
  ||| @key Determines the stream of generated samples.
  ||| @bound A bound of the samples. See full docstring for details.
  ||| @bound' A bound of the samples. See full docstring for details.
  export
  uniform :
    {shape : _} ->
    (key : Tensor [] U64) ->
    (bound, bound' : Tensor shape U64) ->
    Rand (Tensor shape U64)
  uniform (MkTensor key) bound bound' =
    let MkTensor minval = min bound bound'
        MkTensor maxval = max bound bound'
     in ST $ \(MkTensor initialState) =>
          -- No NaN/inf workaround needed here, unlike the F64 version: the
          -- bounds are unsigned integers. Tuple element 0 is the sampled
          -- value, element 1 the updated state (state is returned first in
          -- the pair, value second).
          let valueState = UniformUInt key initialState minval maxval shape
           in Id (MkTensor $ GetTupleElement 1 valueState, MkTensor $ GetTupleElement 0 valueState)

||| Generate independent and identically distributed (IID) samples from the standard normal
||| distribution.
Expand Down
Loading

0 comments on commit a0a86ea

Please sign in to comment.