Skip to content

Commit

Permalink
jit
Browse files Browse the repository at this point in the history
  • Loading branch information
joelberkeley committed Dec 19, 2023
1 parent ab2f968 commit 15046c4
Show file tree
Hide file tree
Showing 11 changed files with 190 additions and 56 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ spidr benefits from much of what XLA has to offer, namely the performance benefi

#### Graph generation

This is a high-priority feature but is not yet implemented. spidr can generate new tensor graphs from existing ones. We plan to use this to implement vectorization, just-in-time compilation, and automatic differentiation like JAX's [`vmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html#jax.vmap), [`jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit) and [`grad`](https://jax.readthedocs.io/en/latest/debugging/checkify_guide.html#grad).
This is a high-priority feature but is not yet implemented. spidr can generate new tensor graphs from existing ones. We plan to use this to implement vectorization and automatic differentiation like JAX's [`vmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html#jax.vmap) and [`grad`](https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html#jax.grad).

### Acknowledgements

Expand Down
1 change: 1 addition & 0 deletions backend/.bazelversion
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
6.4.0
2 changes: 1 addition & 1 deletion backend/VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.0.7
0.0.8
15 changes: 15 additions & 0 deletions backend/src/tensorflow/compiler/xla/client/xla_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,21 @@ extern "C" {
return reinterpret_cast<XlaOp*>(new xla::XlaOp(res));
}

// C-linkage wrapper around xla::Call.
//
// Applies `computation` to the first `operands_len` ops of the `operands`
// array, using `builder`. The opaque C handle types are reinterpreted as their
// xla:: counterparts.
//
// Returns a pointer to a heap-allocated copy of the resulting op; it is
// created with `new`, so it is not released by this function.
XlaOp* Call(
    XlaBuilder* builder,
    XlaComputation& computation,
    XlaOp* operands,
    int operands_len
) {
    auto* xla_builder = reinterpret_cast<xla::XlaBuilder*>(builder);
    auto& xla_computation = reinterpret_cast<xla::XlaComputation&>(computation);

    // Non-owning view over the caller's operand array.
    absl::Span<const xla::XlaOp> args(
        reinterpret_cast<xla::XlaOp*>(operands), operands_len
    );

    xla::XlaOp called = xla::Call(xla_builder, xla_computation, args);
    return reinterpret_cast<XlaOp*>(new xla::XlaOp(called));
}

XlaOp* Add(XlaOp& lhs, XlaOp& rhs) { return binOp(xla::Add, lhs, rhs); }
XlaOp* Sub(XlaOp& lhs, XlaOp& rhs) { return binOp(xla::Sub, lhs, rhs); }
XlaOp* Mul(XlaOp& lhs, XlaOp& rhs) { return binOp(xla::Mul, lhs, rhs); }
Expand Down
7 changes: 7 additions & 0 deletions backend/src/tensorflow/compiler/xla/client/xla_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,13 @@ extern "C" {
);
XlaOp* Cholesky(XlaOp& a, int lower);

// Applies `computation` to `operands_len` ops read from the `operands` array,
// using `builder`, and returns a pointer to the resulting op.
//
// The implementation allocates the returned XlaOp with `new`; it is not freed
// by this function.
XlaOp* Call(
XlaBuilder* builder,
XlaComputation& computation,
XlaOp* operands,
int operands_len
);

XlaOp* Add(XlaOp& lhs, XlaOp& rhs);
XlaOp* Sub(XlaOp& lhs, XlaOp& rhs);
XlaOp* Mul(XlaOp& lhs, XlaOp& rhs);
Expand Down
107 changes: 66 additions & 41 deletions src/Compiler/Eval.idr
Original file line number Diff line number Diff line change
Expand Up @@ -58,17 +58,27 @@ Show Err where
show (IndexErr msg) = "IndexErr: \{msg}"

0 Computation : Type -> Type
Computation = StateT (SortedMap Nat XlaOp) (EitherT Err IO)
Computation = StateT (SortedMap Nat XlaComputation, SortedMap Nat XlaOp) (EitherT Err IO)

||| Look up the `XlaOp` at `position` in the graph.
lookup : (position : Nat) -> Computation XlaOp
lookup n = do
cache <- get
(_, cache) <- get
case lookup n cache of
Nothing =>
lift $ left (IndexErr "Tried to look up value at index \{show n} but found keys \{show $ toList (keys cache)}")
lift $ left (IndexErr "Tried to look up XlaOp at index \{show n} but found keys \{show $ toList (keys cache)}")
Just op => pure op

namespace XlaComputation
||| Look up the `XlaComputation` at `position` in the graph.
|||
||| Reads the computation cache (the first component of the `Computation` state). If there is
||| no entry at `position`, fails with an `IndexErr` that lists the keys currently in the cache.
lookup : (position : Nat) -> Computation XlaComputation
lookup n = do
(cache, _) <- get
case lookup n cache of
Nothing =>
lift $ left (IndexErr "Tried to look up XlaComputation at index \{show n} but found keys \{show $ toList (keys cache)}")
Just comp => pure comp

interpret : XlaBuilder -> Nat -> Env -> Computation XlaOp

||| Build a computation from an inner function
Expand All @@ -94,47 +104,60 @@ buildSub builder name (MkFn params result env) = do
interpretParameter builder (positionInFnParams, positionInGraph, MkShapeAndType shape dtype) = do
xlaShape <- mkShape {dtype} shape
param <- parameter builder positionInFnParams xlaShape name
put $ insert positionInGraph param !get
(comps, ops) <- get
put (comps, insert positionInGraph param ops)

covering
enqueue : XlaBuilder -> Expr -> Computation XlaOp
enqueue builder (FromLiteral {dtype} lit) = constantLiteral builder !(write {dtype} lit)
enqueue _ (Arg x) = lookup x
enqueue builder (Tuple xs) = tuple builder !(traverse lookup xs)
enqueue builder (GetTupleElement idx x) = getTupleElement !(lookup x) idx
enqueue builder (MinValue {dtype}) = minValue {dtype} builder
enqueue builder (MaxValue {dtype}) = maxValue {dtype} builder
enqueue builder (MinFiniteValue {dtype}) = minFiniteValue {dtype} builder
enqueue builder (MaxFiniteValue {dtype}) = maxFiniteValue {dtype} builder
enqueue _ (ConvertElementType x) = convertElementType {dtype = F64} !(lookup x)
enqueue _ (Reshape from to x) = reshape !(lookup x) (range $ length from) to
enqueue _ (Slice starts stops strides x) = slice !(lookup x) starts stops strides
enqueue _ (DynamicSlice starts sizes x) =
enqueue : XlaBuilder -> Env -> Expr -> Computation XlaOp
enqueue builder _ (FromLiteral {dtype} lit) = constantLiteral builder !(write {dtype} lit)
enqueue _ _ (Arg x) = lookup x
enqueue builder _ (Tuple xs) = tuple builder !(traverse lookup xs)
enqueue builder _ (GetTupleElement idx x) = getTupleElement !(lookup x) idx
enqueue builder env (Call f xs) = do
(cachedComps, _) <- get
builtComp <- case lookup f cachedComps of
Just comp => pure comp
Nothing => case findChild env f of
Nothing => lift $ left (IndexErr "Tried to look up child env at index \{show f} but key not found")
Just (_ ** comp) => do
comp <- buildSub builder "name" comp
(comps, ops) <- get
put (insert f comp comps, ops)
pure comp
call builder builtComp !(traverse lookup xs)
enqueue builder _ (MinValue {dtype}) = minValue {dtype} builder
enqueue builder _ (MaxValue {dtype}) = maxValue {dtype} builder
enqueue builder _ (MinFiniteValue {dtype}) = minFiniteValue {dtype} builder
enqueue builder _ (MaxFiniteValue {dtype}) = maxFiniteValue {dtype} builder
enqueue _ _ (ConvertElementType x) = convertElementType {dtype = F64} !(lookup x)
enqueue _ _ (Reshape from to x) = reshape !(lookup x) (range $ length from) to
enqueue _ _ (Slice starts stops strides x) = slice !(lookup x) starts stops strides
enqueue _ _ (DynamicSlice starts sizes x) =
dynamicSlice !(lookup x) !(traverse lookup starts) sizes
enqueue builder (Concat axis x y) = concatInDim builder [!(lookup x), !(lookup y)] (cast axis)
enqueue _ (Diag x) = getMatrixDiagonal !(lookup x)
enqueue _ (Triangle tri x) = triangle !(lookup x) tri
enqueue _ (Transpose ordering x) = transpose !(lookup x) ordering
enqueue builder (Identity {dtype} n) = let n = cast n in identityMatrix {dtype} builder n n
enqueue builder (Broadcast {dtype} from to x) =
enqueue builder _ (Concat axis x y) = concatInDim builder [!(lookup x), !(lookup y)] (cast axis)
enqueue _ _ (Diag x) = getMatrixDiagonal !(lookup x)
enqueue _ _ (Triangle tri x) = triangle !(lookup x) tri
enqueue _ _ (Transpose ordering x) = transpose !(lookup x) ordering
enqueue builder _ (Identity {dtype} n) = let n = cast n in identityMatrix {dtype} builder n n
enqueue builder _ (Broadcast {dtype} from to x) =
if elem 0 to && from /= to
then do
literal <- allocLiteral {dtype} to
constantLiteral builder literal
else
let broadcastDims = map (+ length to `minus` length from) $ range $ length from
in broadcastInDim !(lookup x) to broadcastDims
enqueue builder (Map f xs dims) = do
enqueue builder _ (Map f xs dims) = do
computation <- buildSub builder "computation" f
map builder (toList !(traverse lookup xs)) computation dims
enqueue builder (Reduce f neutral axes x) = do
enqueue builder _ (Reduce f neutral axes x) = do
computation <- buildSub builder "computation" f
reduce !(lookup x) !(lookup neutral) computation axes
enqueue builder (Sort f axis isStable xs) = do
enqueue builder _ (Sort f axis isStable xs) = do
comparator <- buildSub builder "comparator" f
sort !(traverse lookup xs) comparator axis isStable
enqueue _ (Reverse axes x) = rev !(lookup x) axes
enqueue _ (BinaryElementwise f x y) = toXla f !(lookup x) !(lookup y)
enqueue _ _ (Reverse axes x) = rev !(lookup x) axes
enqueue _ _ (BinaryElementwise f x y) = toXla f !(lookup x) !(lookup y)
where
toXla : BinaryOp -> HasIO io => XlaOp -> XlaOp -> io XlaOp
toXla = \case
Expand All @@ -154,7 +177,7 @@ enqueue _ (BinaryElementwise f x y) = toXla f !(lookup x) !(lookup y)
Or => or
Min => min
Max => max
enqueue _ (UnaryElementwise f x) = toXla f !(lookup x)
enqueue _ _ (UnaryElementwise f x) = toXla f !(lookup x)
where
toXla : UnaryOp -> HasIO io => XlaOp -> io XlaOp
toXla = \case
Expand Down Expand Up @@ -182,18 +205,18 @@ enqueue _ (UnaryElementwise f x) = toXla f !(lookup x)
Asinh => asinh
Acosh => acosh
Atanh => atanh
enqueue _ (Argmin {out} axis x) = argMin {outputType=out} !(lookup x) axis
enqueue _ (Argmax {out} axis x) = argMax {outputType=out} !(lookup x) axis
enqueue _ (Select pred true false) = select !(lookup pred) !(lookup true) !(lookup false)
enqueue builder (Cond pred fTrue true fFalse false) = do
enqueue _ _ (Argmin {out} axis x) = argMin {outputType=out} !(lookup x) axis
enqueue _ _ (Argmax {out} axis x) = argMax {outputType=out} !(lookup x) axis
enqueue _ _ (Select pred true false) = select !(lookup pred) !(lookup true) !(lookup false)
enqueue builder _ (Cond pred fTrue true fFalse false) = do
trueComp <- buildSub builder "truthy computation" fTrue
falseComp <- buildSub builder "falsy computation" fFalse
conditional !(lookup pred) !(lookup true) trueComp !(lookup false) falseComp
enqueue _ (Dot l r) = dot !(lookup l) !(lookup r)
enqueue _ (Cholesky x) = cholesky !(lookup x) True
enqueue _ (TriangularSolve a b lower) =
enqueue _ _ (Dot l r) = dot !(lookup l) !(lookup r)
enqueue _ _ (Cholesky x) = cholesky !(lookup x) True
enqueue _ _ (TriangularSolve a b lower) =
triangularSolve !(lookup a) !(lookup b) True lower False NoTranspose
enqueue builder (UniformFloatingPoint key initialState minval maxval shape) = do
enqueue builder _ (UniformFloatingPoint key initialState minval maxval shape) = do
rngOutput <- uniformFloatingPointDistribution
!(lookup key)
!(lookup initialState)
Expand All @@ -202,7 +225,7 @@ enqueue builder (UniformFloatingPoint key initialState minval maxval shape) = do
!(lookup maxval)
!(mkShape {dtype=F64} shape)
tuple builder [value rngOutput, state rngOutput]
enqueue builder (NormalFloatingPoint key initialState shape) = do
enqueue builder _ (NormalFloatingPoint key initialState shape) = do
rngOutput <- normalFloatingPointDistribution
!(lookup key) !(lookup initialState) ThreeFry !(mkShape {dtype=F64} shape)
tuple builder [value rngOutput, state rngOutput]
Expand All @@ -213,20 +236,22 @@ interpret builder root env = do

where
interpretExpr : (Nat, Expr) -> Computation ()
interpretExpr (n, expr) = put (insert n !(enqueue builder expr) !get)
interpretExpr (n, expr) = do
(comps, ops) <- get
put (comps, insert n !(enqueue builder env expr) ops)

export
toString : Nat -> Env -> EitherT Err IO String
toString root env = do
builder <- mkXlaBuilder "toString"
xlaOp <- evalStateT empty (interpret builder root env)
xlaOp <- evalStateT (empty, empty) (interpret builder root env)
pure $ opToString builder xlaOp

export
run : PrimitiveRW dtype a => Nat -> Env -> {shape : _} -> EitherT Err IO (Literal shape a)
run root env = do
builder <- mkXlaBuilder "root"
root <- evalStateT empty (interpret builder root env)
root <- evalStateT (empty, empty) (interpret builder root env)
computation <- XlaBuilder.build builder root
gpuStatus <- validateGPUMachineManager
platform <- if ok gpuStatus then gpuMachineManager else getPlatform "Host"
Expand Down
52 changes: 39 additions & 13 deletions src/Compiler/Expr.idr
Original file line number Diff line number Diff line change
Expand Up @@ -26,21 +26,24 @@ import Types
import Util

public export
data ShapeAndType : Type where
MkShapeAndType : Shape -> (0 dtype : Type) -> Primitive dtype => ShapeAndType
data Expr : Type where

public export
data Expr : Type where
data Fn : Nat -> Type

-- we use `List (Nat, Expr)` for O(1) append (all we do when building the graph is append)
-- we can't use `(Nat, List Expr)`, or even better `(n ** Vect n Expr)`, because we don't handle
-- we use `List (Nat, a)` for O(1) append (all we do when building the graph is append)
-- we can't use `(Nat, List a)`, or even better `(n ** Vect n a)`, because we don't handle
-- scoping properly so node pointers aren't contiguous and don't match list indices
public export 0
TopSort : Type -> Type
TopSort a = (Nat, List (Nat, a))

export
data Env = MkEnv Nat (List (Nat, Expr))
data Env = MkEnv (TopSort (arity ** Fn arity)) (TopSort Expr)

export
empty : Env
empty = MkEnv 0 []
empty = MkEnv (0, []) (0, [])

export
addNode : Expr -> State Env Nat
Expand All @@ -51,7 +54,15 @@ addNode expr = do

export
toList : Env -> List (Nat, Expr)
toList (MkEnv _ env) = reverse env
toList (MkEnv _ (_, env)) = reverse env

||| Look up the function stored at index `idx` among the child functions of an `Env`.
|||
||| Returns the function paired with its arity, or `Nothing` when `lookup` finds no entry
||| at `idx`.
export
findChild : Env -> Nat -> Maybe (a ** Fn a)
findChild (MkEnv (_, fns) _) idx = lookup idx fns

public export
data ShapeAndType : Type where
MkShapeAndType : Shape -> (0 dtype : Type) -> Primitive dtype => ShapeAndType

public export
data Fn : Nat -> Type where
Expand Down Expand Up @@ -118,6 +129,13 @@ data Expr : Type where
Arg : Nat -> Expr
Tuple : List Nat -> Expr
GetTupleElement : Nat -> Nat -> Expr

||| Apply a cached function to arguments.
|||
||| @f The function pointer.
||| @xs The function arguments.
Call : (f : Nat) -> (xs : List Nat) -> Expr

MinValue : Primitive dtype => Expr
MaxValue : Primitive dtype => Expr
MinFiniteValue : Primitive dtype => Expr
Expand Down Expand Up @@ -166,17 +184,25 @@ applyN f (x :: xs) = applyN (f x) xs
export
addFn : {arity : _} -> Vect arity ShapeAndType -> FnExpr arity -> State Env (Fn arity)
addFn params f = do
MkEnv next env <- get
let (subEnv@(MkEnv next _), params, result) = runState (MkEnv next []) $ do
MkEnv (nc, children) (next, env) <- get
let (subEnv@(MkEnv (nc, _) (next, _)), params, result) = runState (MkEnv (nc, []) (next, [])) $ do
xs <- traverse addArg params
result <- applyN f xs
pure (zip xs params, result)
put (MkEnv next env)
put (MkEnv (nc, children) (next, env))
pure (MkFn params result subEnv)

where
addArg : ShapeAndType -> State Env Nat
addArg st = do
MkEnv next env <- get
put (MkEnv (S next) ((next, Arg next) :: env))
MkEnv children (next, env) <- get
put (MkEnv children (S next, (next, Arg next) :: env))
pure next

||| Build a function with `addFn` and store it among the child functions of the current `Env`.
|||
||| The function is inserted at the current child counter, the counter is incremented, and
||| the index the function was stored at is returned.
export
shareFn : {arity : _} -> Vect arity ShapeAndType -> FnExpr arity -> State Env Nat
shareFn shapes body = do
  built <- addFn shapes body
  -- read the state only after `addFn` has run, as it updates the environment itself
  MkEnv (fnIndex, fns) exprs <- get
  put $ MkEnv (S fnIndex, insert fnIndex (_ ** built) fns) exprs
  pure fnIndex
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,10 @@ export
%foreign (libxla "Cholesky")
prim__cholesky : GCAnyPtr -> Int -> PrimIO AnyPtr

||| Foreign binding to the `Call` symbol in libxla.
|||
||| As used by `call`, the arguments are, in order: the builder, the computation, the array of
||| operands and the number of operands. The result is a raw pointer to the new op.
export
%foreign (libxla "Call")
prim__call : GCAnyPtr -> GCAnyPtr -> GCAnyPtr -> Int -> PrimIO AnyPtr

export
%foreign (libxla "Add")
prim__add : GCAnyPtr -> GCAnyPtr -> PrimIO AnyPtr
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,14 @@ cholesky (MkXlaOp a) lower = do
opPtr <- onCollectAny opPtr XlaOp.delete
pure (MkXlaOp opPtr)

||| Enqueue a call to `computation` with the given operands on the builder.
|||
||| The operands are copied into an `XlaOpArray` for the foreign call, and the returned op is
||| registered with `XlaOp.delete` as its finaliser.
export
call : HasIO io => XlaBuilder -> XlaComputation -> List XlaOp -> io XlaOp
call (MkXlaBuilder builderPtr) (MkXlaComputation computationPtr) args = do
  MkXlaOpArray argsPtr <- mkXlaOpArray args
  rawOp <- primIO $ prim__call builderPtr computationPtr argsPtr (cast $ length args)
  managedOp <- onCollectAny rawOp XlaOp.delete
  pure (MkXlaOp managedOp)

export
add : HasIO io => XlaOp -> XlaOp -> io XlaOp
add = binaryOp prim__add
Expand Down
28 changes: 28 additions & 0 deletions src/Tensor.idr
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,34 @@ eval $ MkGraph x = do
Right lit => lit
Left err => idris_crash (show err)

||| `jit` lets you reuse a function without retracing the graph on every call.
|||
||| For example, you could use this to reduce the time it takes to compile a large model loss
||| function, as well as the size in memory of the compiled graph:
||| ```
||| while : (condition : Tensor [] F64 -> Tensor [] PRED) ->
|||         (f : Tensor shape F64 -> Graph $ Tensor shape F64) ->
|||         (start : Tensor shape F64) ->
|||         Tensor shape F64
|||
||| sgd : (Tensor shape F64 -> Graph $ Tensor [] F64) -> Graph $ Tensor shape F64
||| sgd loss = do
|||   loss <- jit loss
|||   -- TODO: to use `while` we already pre-compile `f`, so there's no point in `jit` for
|||   -- `while`. Find a better example.
||| ```
|||
||| This uses the same mechanism as when sharing tensors, but spidr doesn't do this automatically
||| for functions.
|||
||| The function is registered once with `shareFn`. Each application of the returned function
||| then adds a single `Call` node that refers to the registered function.
export
jit : Primitive ta =>
-- can we erase sa and sb?
{sa, sb : _} ->
(Tensor sa ta -> Graph $ Tensor sb tb) ->
Graph (Tensor sa ta -> Graph $ Tensor sb tb)
jit {sa, sb} f = do
f <- MkGraph $ shareFn [MkShapeAndType sa ta] (\x => unwrap $ f (MkTensor x))
pure $ the (_ -> Graph _) $ \(MkTensor x) => addTensor (Call f [x])

||| A string representation of the graph used to define a `Tensor`, detailing all enqueued XLA
||| operations.
|||
Expand Down
Loading

0 comments on commit 15046c4

Please sign in to comment.