Skip to content

Commit

Permalink
Implement Awhile in schedule
Browse files Browse the repository at this point in the history
  • Loading branch information
ivogabe authored and dpvanbalen committed Oct 23, 2023
1 parent 0c3970b commit 7d9a742
Show file tree
Hide file tree
Showing 8 changed files with 661 additions and 111 deletions.
17 changes: 15 additions & 2 deletions src/Data/Array/Accelerate/AST/Schedule/Uniform.hs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
module Data.Array.Accelerate.AST.Schedule.Uniform (
UniformSchedule(..), UniformScheduleFun(..),
SArg(..), SArgs, sargVars, sargOutputVars, sargBufferVars,
Input, Output, inputSingle, outputSingle, inputR, outputR, InputOutputR(..),
Input, Output, inputSingle, outputSingle, inputR, outputR,
InputOutputR(..), inputOutputInputR, inputOutputOutputR,
Binding(..), Effect(..),
BaseR(..), BasesR, BaseVar, BaseVars, BLeftHandSide,
Signal(..), SignalResolver(..), Ref(..), OutputRef(..),
Expand Down Expand Up @@ -184,12 +185,24 @@ outputSingle (GroundRscalar (SingleScalarType tp)) = case tp of
-- Relation between input and output
data InputOutputR input output where
InputOutputRsignal :: InputOutputR Signal SignalResolver
InputOutputRref :: InputOutputR (Ref t) (OutputRef t)
InputOutputRref :: GroundR t -> InputOutputR (Ref t) (OutputRef t)
InputOutputRpair :: InputOutputR i1 o1
-> InputOutputR i2 o2
-> InputOutputR (i1, i2) (o1, o2)
InputOutputRunit :: InputOutputR () ()

inputOutputInputR :: InputOutputR input output -> BasesR input
inputOutputInputR InputOutputRsignal = TupRsingle BaseRsignal
inputOutputInputR (InputOutputRref tp) = TupRsingle $ BaseRref tp
inputOutputInputR (InputOutputRpair io1 io2) = inputOutputInputR io1 `TupRpair` inputOutputInputR io2
inputOutputInputR InputOutputRunit = TupRunit

inputOutputOutputR :: InputOutputR input output -> BasesR output
inputOutputOutputR InputOutputRsignal = TupRsingle BaseRsignalResolver
inputOutputOutputR (InputOutputRref tp) = TupRsingle $ BaseRrefWrite tp
inputOutputOutputR (InputOutputRpair io1 io2) = inputOutputOutputR io1 `TupRpair` inputOutputOutputR io2
inputOutputOutputR InputOutputRunit = TupRunit

-- Bindings of instructions which have some return value.
-- They cannot perform side effects.
--
Expand Down
2 changes: 1 addition & 1 deletion src/Data/Array/Accelerate/Interpreter.hs
Original file line number Diff line number Diff line change
Expand Up @@ -595,7 +595,7 @@ bindAwhileIO :: S.InputOutputR input output -> IO (output, input)
bindAwhileIO S.InputOutputRsignal = do
mvar <- newEmptyMVar
return (S.SignalResolver mvar, S.Signal mvar)
bindAwhileIO S.InputOutputRref = do
bindAwhileIO (S.InputOutputRref _) = do
ioref <- newIORef $ internalError "Illegal schedule: Read from ref without value. Some synchronization might be missing."
return (S.OutputRef ioref, S.Ref ioref)
bindAwhileIO (S.InputOutputRpair io1 io2) = do
Expand Down
9 changes: 7 additions & 2 deletions src/Data/Array/Accelerate/Trafo/Operation/LiveVars.hs
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,8 @@ stronglyLiveVariables' liveness returns us = \case
BindLivenessSub subTup' lhsFull lhsSub re' -> case (bnd' re subTup', body' re' s) of
(Left bnd'', Left body'') -> Left $ mkAlet lhsFull us' bnd'' body''
(Left bnd'', Right body'') -> Right $ mkAlet lhsFull us' bnd'' body''
(Right bnd'', Left body'') -> Left $ mkAlet lhsSub (subTupR subTup' us') bnd'' body''
(Right bnd'', Right body'') -> Right $ mkAlet lhsSub (subTupR subTup' us') bnd'' body''
(Right bnd'', Left body'') -> Left $ mkAlet lhsSub (subTupUniqueness subTup' us') bnd'' body''
(Right bnd'', Right body'') -> Right $ mkAlet lhsSub (subTupUniqueness subTup' us') bnd'' body''
Alloc shr tp sh
| free <- IdxSet.fromVars sh
, liveness1 <- returnIndices returns free liveness ->
Expand Down Expand Up @@ -366,3 +366,8 @@ composeSubArgs (SubArgsLive SubArgKeep s1) (SubArgsLive s s2) =
composeSubArgs (SubArgsLive (SubArgOut t) s1) (SubArgsLive SubArgKeep s2) = SubArgsLive (SubArgOut t) $ composeSubArgs s1 s2
composeSubArgs (SubArgsLive (SubArgOut t1) s1) (SubArgsLive (SubArgOut t2) s2) = SubArgsLive (SubArgOut $ composeSubTupR t2 t1) $ composeSubArgs s1 s2

subTupUniqueness :: SubTupR t t' -> Uniquenesses t -> Uniquenesses t'
subTupUniqueness SubTupRskip _ = TupRunit
subTupUniqueness SubTupRkeep t = t
subTupUniqueness (SubTupRpair s1 s2) (TupRpair t1 t2) = subTupUniqueness s1 t1 `TupRpair` subTupUniqueness s2 t2
subTupUniqueness (SubTupRpair s1 s2) (TupRsingle Shared) = TupRsingle $ Shared
2 changes: 2 additions & 0 deletions src/Data/Array/Accelerate/Trafo/Operation/Simplify.hs
Original file line number Diff line number Diff line change
Expand Up @@ -535,4 +535,6 @@ subTupSubstitution (SubTupRpair s1 s2) (LeftHandSidePair l1 l2) (TupRpair v1 v2)
, Exists l2'' <- rebuildLHS l2
, SubTupSubstitution l2' k2 <- subTupSubstitution s2 l2'' (mapTupR (weaken $ weakenWithLHS l1') v2)
= SubTupSubstitution (LeftHandSidePair l1' l2') (k2 .> sinkWithLHS l2 l2'' k1)
subTupSubstitution s (LeftHandSideWildcard t) _
= SubTupSubstitution (LeftHandSideWildcard $ subTupR s t) weakenId
subTupSubstitution _ _ _ = internalError "Tuple mismatch"
145 changes: 99 additions & 46 deletions src/Data/Array/Accelerate/Trafo/Schedule/Partial.hs
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,10 @@ module Data.Array.Accelerate.Trafo.Schedule.Partial (
Parallelism(..),
Sync(..), SyncEnv,
SyncSchedule(..), SyncScheduleFun,
Loop,
variablesToSyncEnv,

UnitTuple, UpdateTuple(..),
UnitTuple, UpdateTuple(..), toPartialReturn,

compileKernel', CompiledKernel(..),
) where
Expand All @@ -50,6 +52,8 @@ import Data.Array.Accelerate.Error
import Data.Array.Accelerate.Representation.Array
import Data.Array.Accelerate.Representation.Shape
import Data.Array.Accelerate.Representation.Type
import Data.Array.Accelerate.Trafo.Var
import Data.Array.Accelerate.Trafo.Substitution
import Data.Array.Accelerate.Trafo.Exp.Substitution
import Data.Array.Accelerate.Type
import qualified Data.Array.Accelerate.AST.IdxSet as IdxSet
Expand Down Expand Up @@ -164,12 +168,25 @@ data PrePartialSchedule schedule kernel env t where

PAwhile
:: Uniquenesses t
-> PrePartialScheduleFun schedule kernel env (t -> PrimBool)
-> PrePartialScheduleFun schedule kernel env (t -> t)
-> PrePartialScheduleFun schedule kernel env (t -> Loop t)
-> GroundVars env t
-> PrePartialSchedule schedule kernel env t

-- Signals that this while loop should do another iteration.
-- The state for the next iteration is computed by the subterm.
PContinue
:: schedule kernel env t
-> PrePartialSchedule schedule kernel env (Loop t)

-- Stops this while loop. The result value of the loop is
-- given by the GroundVars.
PBreak
:: Uniquenesses t
-> GroundVars env t
-> PrePartialSchedule schedule kernel env (Loop t)

data Void' a where
data Loop a where

data UpdateTuple env t1 t2 where
UpdateKeep :: UpdateTuple env t t
Expand Down Expand Up @@ -396,11 +413,18 @@ toPartial' us = \case
, (falseFree, false') <- toPartial' us false ->
( trueFree `IdxSet.union` falseFree
, PartialSchedule $ PAcond var true' false' )
C.Awhile us' cond step initial
| (condFree, cond') <- toPartialUnary (TupRsingle Shared) cond
, (stepFree, step') <- toPartialUnary us' step ->
( condFree `IdxSet.union` stepFree `IdxSet.union` IdxSet.fromList (groundBufferVars initial)
, PartialSchedule $ PAwhile us' cond' step' initial )
C.Awhile us' (C.Alam lhsC (C.Abody cond)) (C.Alam lhsS (C.Abody step)) initial
| DeclareVars lhs _ vars <- declareVars $ lhsToTupR lhsC
, (condFree, cond') <- toPartial' (TupRsingle Shared) $ weaken (weakenToFullLHS lhs lhsC) cond
, (stepFree, step') <- toPartial' (TupRsingle Shared) $ weaken (weakenSucc' $ weakenToFullLHS lhs lhsS) step
, fn <- PartialSchedule $ PLet Sequential (LeftHandSideSingle (GroundRscalar scalarType)) (TupRsingle Shared) cond'
$ PartialSchedule $ PAcond (Var scalarType ZeroIdx)
(PartialSchedule $ PContinue step')
(PartialSchedule $ PBreak us' $ vars (weakenSucc weakenId))
->
( IdxSet.drop' lhs condFree `IdxSet.union` IdxSet.drop' lhs (IdxSet.drop stepFree) `IdxSet.union` IdxSet.fromList (groundBufferVars initial)
, PartialSchedule $ PAwhile us' (Plam lhs $ Pbody fn) initial )
C.Awhile{} -> internalError "Unary function impossible"
where
-- For all simple cases, with no free buffer variables.
simple :: PrePartialSchedule PartialSchedule kernel env t -> (IdxSet env, PartialSchedule kernel env t)
Expand All @@ -414,6 +438,21 @@ toPartial' us = \case
instrToSync (Index v) = Just $ Exists $ varIdx v
instrToSync (Parameter _) = Nothing -- Parameter is a scalar variable

weakenToFullLHS
-- Full LHS
:: GLeftHandSide t env envFull
-- Sub LHS
-> GLeftHandSide t env envSub
-> envSub :> envFull
weakenToFullLHS = \lhs1 lhs2 -> go lhs1 lhs2 weakenId
where
go :: GLeftHandSide t env1 env1' -> GLeftHandSide t env2 env2' -> env2 :> env1 -> env2' :> env1'
go lhs (LeftHandSideWildcard _) k = weakenWithLHS lhs .> k
go (LeftHandSidePair lhs1 lhs1') (LeftHandSidePair lhs2 lhs2') k = go lhs1' lhs2' $ go lhs1 lhs2 k
go (LeftHandSideSingle _) (LeftHandSideSingle _) k = weakenKeep k
go (LeftHandSideWildcard _) _ _ = internalError "Expected second LHS to be contained in first LHS"
go _ _ _ = internalError "LHS mismatch"

returnValues
:: UpdateTuple env t1 t2
-> PartialSchedule kernel env t1
Expand All @@ -431,20 +470,6 @@ toPartialReturn (TupRsingle u) (TupRsingle var)
= UpdateSet u var
toPartialReturn _ _ = internalError "Tuple mismatch"

toPartialUnary
:: IsKernel kernel
=> Uniquenesses t
-> C.PartitionedAfun (KernelOperation kernel) env (s -> t)
-> (IdxSet env, PartialScheduleFun kernel env (s -> t))
toPartialUnary us (C.Alam lhs (C.Abody body)) =
( IdxSet.drop' lhs free
, Plam lhs $ Pbody body'
)
where
(free, body') = toPartial' us body

toPartialUnary _ _ = internalError "Expected unary function"

groundBufferVars :: GroundVars env t -> [Exists (Idx env)]
groundBufferVars = (`go` [])
where
Expand Down Expand Up @@ -474,7 +499,9 @@ rebuild' (PartialSchedule schedule) = case schedule of
PReturnEnd tup -> buildReturnEnd tup
PReturnValues updateTup next -> buildReturnValues updateTup (rebuild' next)
PAcond var true false -> buildAcond var (rebuild' true) (rebuild' false)
PAwhile us cond step initial -> buildAwhile us (rebuildUnary cond) (rebuildUnary step) initial
PAwhile us fn initial -> buildAwhile us (rebuildUnary fn) initial
PContinue next -> buildContinue $ rebuild' next
PBreak us vars -> buildBreak us vars

rebuildUnary :: PartialScheduleFun kernel env f -> BuildUnary kernel env f
rebuildUnary (Plam lhs (Pbody body)) = BuildUnary lhs $ rebuild' body
Expand Down Expand Up @@ -682,26 +709,52 @@ data BuildUnary kernel env f where

buildAwhile
:: Uniquenesses t
-> BuildUnary kernel env (t -> PrimBool)
-> BuildUnary kernel env (t -> t)
-> BuildUnary kernel env (t -> Loop t)
-> GroundVars env t
-> Build PartialSchedule kernel env t
buildAwhile us (BuildUnary condLhs cond') (BuildUnary stepLhs step') initial available =
buildAwhile us (BuildUnary lhs fn') initial available =
Built{
didChange = didChange cond || didChange step,
directlyAwaits = IdxSet.fromVarList (lhsTake condLhs (directlyAwaits cond) initial) `IdxSet.union` IdxSet.drop' condLhs (directlyAwaits cond),
writes = IdxSet.drop' condLhs (writes cond) `IdxSet.union` IdxSet.drop' stepLhs (writes step),
didChange = didChange fn,
directlyAwaits = IdxSet.fromVarList (lhsTake lhs (directlyAwaits fn) initial) `IdxSet.union` IdxSet.drop' lhs (directlyAwaits fn),
writes = IdxSet.drop' lhs (writes fn),
finallyReleases = IdxSet.empty,
trivial = False,
term = PartialSchedule $ PAwhile
us
(Plam condLhs $ Pbody $ term cond)
(Plam stepLhs $ Pbody $ term step)
(Plam lhs $ Pbody $ term fn)
initial
}
where
cond = cond' (IdxSet.skip' condLhs available)
step = step' (IdxSet.skip' stepLhs available)
fn = fn' (IdxSet.skip' lhs available)

buildContinue
:: Build PartialSchedule kernel env t
-> Build PartialSchedule kernel env (Loop t)
buildContinue next' available =
Built{
didChange = didChange next,
directlyAwaits = IdxSet.empty,
writes = writes next,
finallyReleases = finallyReleases next,
trivial = False,
term = PartialSchedule $ PContinue $ term next
}
where
next = next' available

buildBreak
:: Uniquenesses t
-> GroundVars env t
-> Build PartialSchedule kernel env (Loop t)
buildBreak us vars _ =
Built{
didChange = False,
directlyAwaits = IdxSet.fromVars vars,
writes = IdxSet.empty,
finallyReleases = IdxSet.empty,
trivial = False,
term = PartialSchedule $ PBreak us vars
}

updateCount :: UpdateTuple env t1 t2 -> Count
updateCount UpdateKeep = Zero
Expand Down Expand Up @@ -787,28 +840,28 @@ analyseSyncEnv' (PartialSchedule sched) = case sched of
(unionPartialEnv max (syncEnv true') (syncEnv false'))
False
(PAcond var true' false')
PAwhile us cond step initial ->
PAwhile us (Plam lhs (Pbody body)) initial ->
let
(condEnv, cond') = fun cond
(stepEnv, step') = fun step
body' = analyseSyncEnv body
in
ToSyncSchedule UpdateKeep $
SyncSchedule
(unionPartialEnv max condEnv $ unionPartialEnv max stepEnv $ variablesToSyncEnv us initial)
(unionPartialEnv max (weakenSyncEnv lhs $ syncEnv body') $ variablesToSyncEnv us initial)
False
(PAwhile us cond' step' initial)
(PAwhile us (Plam lhs $ Pbody body') initial)
PAwhile{} -> internalError "Function impossible"
PContinue next ->
let
next' = analyseSyncEnv next
in
ToSyncSchedule UpdateKeep $
SyncSchedule (syncEnv next') False $ PContinue next'
PBreak us vars ->
ToSyncSchedule UpdateKeep $ SyncSchedule (variablesToSyncEnv us vars) False $ PBreak us vars
where
noBuffers :: PrePartialSchedule SyncSchedule kernel env t -> ToSyncSchedule kernel env t
noBuffers = ToSyncSchedule UpdateKeep . SyncSchedule PEnd True

fun :: PartialScheduleFun kernel env f -> (SyncEnv env, SyncScheduleFun kernel env f)
fun (Pbody body) = (syncEnv body', Pbody body')
where
body' = analyseSyncEnv body
fun (Plam lhs f) = (weakenSyncEnv lhs env, Plam lhs f')
where
(env, f') = fun f

variablesToSyncEnv :: Uniquenesses t -> GroundVars genv t -> SyncEnv genv
variablesToSyncEnv uniquenesses vars = partialEnvFromList noCombine $ go uniquenesses vars []
where
Expand Down
Loading

0 comments on commit 7d9a742

Please sign in to comment.