Skip to content

Commit

Permalink
Keep track of already resolved signals
Browse files Browse the repository at this point in the history
  • Loading branch information
ivogabe committed Oct 25, 2023
1 parent 92a4e9f commit 9d5d024
Showing 1 changed file with 47 additions and 32 deletions.
79 changes: 47 additions & 32 deletions src/Data/Array/Accelerate/Trafo/Schedule/Uniform/Simplify.hs
Original file line number Diff line number Diff line change
Expand Up @@ -681,14 +681,20 @@ data BuildSchedule kernel env =
directlyAwaits :: [Idx env Signal],
trivial :: Bool,
-- Constructs a schedule, but doesn't wait on the directlyAwaits signals.
-- construct' adds that.
-- constructFull adds that.
construct
:: forall env'.
env :> env'
-> BuildEnv env'
-> Continuation kernel env'
-> UniformSchedule kernel env'
}

data BuildEnv env =
BuildEnv{
buildEnvResolved :: IdxSet env -- Set of resolved signals
}

instance Sink' (BuildSchedule kernel) where
weaken' k schedule =
BuildSchedule{
Expand All @@ -706,17 +712,22 @@ newtype BuildScheduleFun kernel env t =
}

-- Constructs a schedule, and waits on the directlyAwaits signals.
construct'
constructFull
:: BuildSchedule kernel env
-> env :> env'
-> BuildEnv env'
-> Continuation kernel env'
-> UniformSchedule kernel env'
construct' schedule k cont
| null $ directlyAwaits schedule = term
| otherwise =
Effect (SignalAwait $ map (weaken k) $ directlyAwaits schedule) term
where
term = construct schedule k cont
constructFull schedule k env cont
| null $ directlyAwaits schedule = construct schedule k env cont
| signals' <-
-- Don't wait on already resolved signals
filter (\idx -> not (idx `IdxSet.member` buildEnvResolved env))
$ map (weaken k)
$ directlyAwaits schedule
, env' <- BuildEnv $ IdxSet.union (IdxSet.fromList' signals') $ buildEnvResolved env
= (if null signals' then id else Effect $ SignalAwait signals')
$ construct schedule k env' cont

data Continuation kernel env where
ContinuationEnd
Expand All @@ -737,9 +748,9 @@ buildReturn :: BuildSchedule kernel env
buildReturn = BuildSchedule{
directlyAwaits = [],
trivial = True,
construct = \_ -> \case
construct = \_ env -> \case
ContinuationEnd -> Return
ContinuationDo k2 build k3 cont -> construct' build k2 $ weaken' k3 cont
ContinuationDo k2 build k3 cont -> constructFull build k2 env $ weaken' k3 cont
}

buildLet
Expand All @@ -765,16 +776,21 @@ buildLet lhs binding body
constructLet
:: Bool
-> env1 :> env1'
-> BuildEnv env1'
-> Continuation kernel env1'
-> UniformSchedule kernel env1'
constructLet shouldAwait k cont
constructLet shouldAwait k env cont
| Exists lhs' <- rebuildLHS lhs
, k' <- sinkWithLHS lhs lhs' k =
Alet lhs' (weaken k binding)
$ (if shouldAwait then Effect $ SignalAwait $ map (weaken k') $ directlyAwaits body else id)
$ construct body k'
, k' <- sinkWithLHS lhs lhs' k
, binding' <- weaken k binding =
Alet lhs' binding'
$ constructFull (if shouldAwait then body else body{ directlyAwaits = [] }) k'
(buildEnvExtend lhs' binding' env)
$ weaken' (weakenWithLHS lhs') cont

buildEnvExtend :: BLeftHandSide t env1 env2 -> Binding env1 t -> BuildEnv env1 -> BuildEnv env2
buildEnvExtend lhs _ (BuildEnv resolved) = BuildEnv $ IdxSet.skip' lhs resolved

buildEffect
:: Effect kernel env
-> BuildSchedule kernel env
Expand All @@ -791,18 +807,17 @@ buildEffect effect next
BuildSchedule{
directlyAwaits = directlyAwaits next,
trivial = trivialEffect effect && trivial next,
construct = \k cont ->
construct = \k env cont ->
Effect (weaken' k effect)
$ construct next k cont
$ construct next k env cont
}
| otherwise =
BuildSchedule{
directlyAwaits = [],
trivial = False,
construct = \k cont ->
construct = \k env cont ->
Effect (weaken' k effect)
$ Effect (SignalAwait $ map (weaken k) $ directlyAwaits next)
$ construct next k cont
$ constructFull next k env cont
}
where
-- Write may be postponed: a write doesn't do synchronisation,
Expand All @@ -816,8 +831,8 @@ buildSeq a b =
BuildSchedule {
directlyAwaits = directlyAwaits a,
trivial = trivial a && trivial b && directlyAwaits b `isSubsequenceOf` directlyAwaits a,
construct = \k cont ->
construct a k $ ContinuationDo k b weakenId cont
construct = \k env cont ->
construct a k env $ ContinuationDo k b weakenId cont
}

buildSpawn :: BuildSchedule kernel env -> BuildSchedule kernel env -> BuildSchedule kernel env
Expand All @@ -830,10 +845,10 @@ buildSpawn a b
BuildSchedule{
directlyAwaits = directlyAwaits a `sortedIntersection` directlyAwaits b,
trivial = False,
construct = \k cont ->
construct = \k env cont ->
Spawn
(construct' a{directlyAwaits = directlyAwaits a `sortedMinus` directlyAwaits b} k cont)
(construct' b{directlyAwaits = directlyAwaits b `sortedMinus` directlyAwaits a} k ContinuationEnd)
(constructFull a{directlyAwaits = directlyAwaits a `sortedMinus` directlyAwaits b} k env cont)
(constructFull b{directlyAwaits = directlyAwaits b `sortedMinus` directlyAwaits a} k env ContinuationEnd)
}

buildAcond
Expand All @@ -846,11 +861,11 @@ buildAcond var true false next =
BuildSchedule{
directlyAwaits = directlyAwaits true `sortedIntersection` directlyAwaits false,
trivial = False,
construct = \k cont -> Acond
construct = \k env cont -> Acond
(weaken k var)
(construct' true{directlyAwaits = directlyAwaits true `sortedMinus` directlyAwaits false} k ContinuationEnd)
(construct' false{directlyAwaits = directlyAwaits false `sortedMinus` directlyAwaits true} k ContinuationEnd)
(construct' next k cont)
(constructFull true{directlyAwaits = directlyAwaits true `sortedMinus` directlyAwaits false} k env ContinuationEnd)
(constructFull false{directlyAwaits = directlyAwaits false `sortedMinus` directlyAwaits true} k env ContinuationEnd)
(constructFull next k env cont)
}

buildAwhile
Expand All @@ -863,11 +878,11 @@ buildAwhile io step initial next =
BuildSchedule{
directlyAwaits = [], -- TODO: Compute this based on the use of the initial state and free variables of step.
trivial = False,
construct = \k cont -> Awhile
construct = \k env cont -> Awhile
io
(funConstruct step k)
(mapTupR (weaken k) initial)
(construct' next k cont)
(constructFull next k env cont)
}

buildFunLam
Expand All @@ -883,7 +898,7 @@ buildFunLam lhs body =
buildFunBody :: BuildSchedule kernel env -> BuildScheduleFun kernel env ()
buildFunBody body =
BuildScheduleFun{
funConstruct = \k -> Sbody $ construct' body k ContinuationEnd
funConstruct = \k -> Sbody $ constructFull body k (BuildEnv IdxSet.empty) ContinuationEnd
}

-- Assumes that the input arrays are sorted,
Expand Down

0 comments on commit 9d5d024

Please sign in to comment.