From 1b8c6c096a1903405820b0b9a0b6b188e28fbaa6 Mon Sep 17 00:00:00 2001 From: Ivo Gabe de Wolff Date: Wed, 29 May 2024 15:23:23 +0200 Subject: [PATCH 1/3] Simplify default desugaring for fold to be sequential The old desugaring had two bugs: - The <= in the condition of the awhile loop should be > - The body of the awhile loop took unique ownership of the output buffer(s) via shared variable(s). This would break in the UniformSchedule Instead of fixing these bugs, I decided to change the desugaring to be sequential. Any backend that really cares about performance will have its own fold implementation, and this sequential implementation is still very useful during the development of a new backend. --- src/Data/Array/Accelerate/Trafo/Desugar.hs | 166 ++++----------------- 1 file changed, 33 insertions(+), 133 deletions(-) diff --git a/src/Data/Array/Accelerate/Trafo/Desugar.hs b/src/Data/Array/Accelerate/Trafo/Desugar.hs index 0b648d959..63490c863 100644 --- a/src/Data/Array/Accelerate/Trafo/Desugar.hs +++ b/src/Data/Array/Accelerate/Trafo/Desugar.hs @@ -207,23 +207,7 @@ class NFData' op => DesugarAcc (op :: Type -> Type) where -> Arg env (In (sh, Int) e) -> Arg env (Out sh e) -> OperationAcc op env () - mkFold f def input@(ArgArray _ repr@(ArrayR shr tp) _ _) output - -- Binding of the result of the first step - | DeclareVars lhs1 k1 value1 <- declareVars $ desugarArrayR repr - -- Binding in the iteration and condition function for the first step - , DeclareVars lhsSh kSh valueSh <- declareVars $ mapTupR GroundRscalar $ shapeType shr - , DeclareVars lhsBf kBf valueBf <- declareVars $ buffersR tp - = let - lhs2 = LeftHandSidePair lhsSh lhsBf - argTmp = ArgArray In (ArrayR shr tp) (valueSh kBf) (valueBf weakenId) - c = Alam lhs2 $ Abody $ case valueSh kBf of - TupRpair _ ix -> Compute $ mkBinary (PrimLtEq singleType) (paramsIn (TupRsingle scalarTypeInt) ix) (mkConstant (TupRsingle scalarTypeInt) 1) - _ -> error "Impossible pair" - g = Alam lhs2 $ Abody $ mkDefaultFoldStep2 (weaken (kBf .> kSh .> k1) f) argTmp (weaken (kBf .> kSh .> k1) output) - in - alet lhs1 (mkDefaultFoldStep1 f def input output) - $ alet (LeftHandSideWildcard $ desugarArrayR repr) (Awhile (shared $ desugarArrayR repr) c g $ value1 weakenId) - $ Return TupRunit + mkFold f def input@(ArgArray _ repr@(ArrayR shr tp) _ _) output = mkDefaultFoldSequential f def input output mkFoldSeg :: IntegralType i -> Arg env (Fun' (e -> e -> e)) @@ -1178,125 +1162,41 @@ mkIntersect shr x y mkIntersect' (ShapeRsnoc shr) (TupRpair s1 x1) (TupRpair s2 x2) = mkIntersect' shr s1 s2 `Pair` mkBinary (PrimMin singleType) (paramsIn' x1) (paramsIn' x2) mkIntersect' (ShapeRsnoc _) _ _ = error "Impossible pair" --- Default implementation for the first step of a fold. --- The output of the inner dimension is guaranteed to --- be a power of two. -mkDefaultFoldStep1 :: forall benv op sh e. DesugarAcc op => Arg benv (Fun' (e -> e -> e)) -> Maybe (Arg benv (Exp' e)) -> Arg benv (In (sh, Int) e) -> Arg benv (Out sh e) -> OperationAcc op benv (DesugaredArrays (Array (sh, Int) e)) -mkDefaultFoldStep1 (ArgFun f) def argIn@(ArgArray _ (ArrayR shr tp) (sh `TupRpair` n) _) argOut - | DeclareVars lhsTmp kTmp valueTmp <- declareVars $ buffersR tp - = let - lhsN1 = LeftHandSideSingle $ GroundRscalar scalarTypeInt - kN1 = weakenSucc weakenId - - -- n-1 in case of fold1. - -- For fold, we have one additional element (the default value), and thus have (n+1)-1 - nMinus1 - | Just _ <- def = paramsIn (TupRsingle scalarType) n - | otherwise = mkBinary (PrimSub numType) (paramsIn (TupRsingle scalarType) n) (mkConstant (TupRsingle scalarTypeInt) 1) - - shBase' = weakenVars (weakenSucc weakenId) sh - shTmp' = shBase' `TupRpair` TupRsingle (Var (GroundRscalar scalarTypeInt) ZeroIdx) - - shTmp = weakenVars kTmp shTmp' - tmp = valueTmp weakenId - k = weakenSucc kTmp - argG = ArgFun $ mkDefaultFoldStep1Function (weakenArrayInstr k f) (weaken k <$> def) (weaken kTmp $ Var (GroundRscalar scalarTypeInt) ZeroIdx) (weaken k argIn) - argTmp = ArgArray Out (ArrayR shr tp) shTmp tmp - in - alet lhsN1 (Compute $ mkBinary (PrimBShiftL integralType) (mkConstant (TupRsingle scalarTypeInt) 1) $ mkLog2 nMinus1) - $ aletUnique lhsTmp (mkDefaultFoldAllocOrOutput (weaken kN1 argOut) $ groundToExpVar (shapeType shr) shTmp') - $ alet (LeftHandSideWildcard TupRunit) (mkGenerate argG argTmp) - $ Return (shTmp `TupRpair` tmp) - --- log_2(x) = 63 − clz(x) (for 64-bit integers) -mkLog2 :: OpenExp env benv Int -> OpenExp env benv Int -mkLog2 = mkBinary (PrimSub numType) (mkConstant (TupRsingle scalarTypeInt) 63) . PrimApp (PrimCountLeadingZeros integralType) - -mkDefaultFoldStep1Function :: forall benv sh e. HasCallStack => Fun benv (e -> e -> e) -> Maybe (Arg benv (Exp' e)) -> GroundVar benv Int -> Arg benv (In (sh, Int) e) -> Fun benv ((sh, Int) -> e) -mkDefaultFoldStep1Function f def n1' argIn@(ArgArray _ (ArrayR (ShapeRsnoc shr) tp) (sh' `TupRpair` n') buffers) - | DeclareVars lhsX kX valueX <- declareVars $ shapeType shr - = let - x = expVars $ valueX (weakenSucc weakenId) - y = Evar $ Var scalarTypeInt ZeroIdx - y2 = mkBinary (PrimMul numType) y (mkConstant (TupRsingle scalarTypeInt) 2) - n1 = paramIn scalarTypeInt n1' - - -- When a default or initial value is given, we just pretend that the input array is one - -- element larger, i.e., prefixed with that default value. - n - | Just _ <- def = mkBinary (PrimAdd numType) (paramsIn (TupRsingle scalarTypeInt) n') (mkConstant (TupRsingle scalarTypeInt) 1) - | otherwise = paramsIn (TupRsingle scalarTypeInt) n' - - index' :: OpenExp env' benv sh -> OpenExp env' benv Int -> OpenExp env' benv e - index' x' y' - | Just (ArgExp d) <- def = - Cond (mkBinary (PrimEq singleType) y' $ mkConstant (TupRsingle scalarTypeInt) 0) - (weakenE weakenEmpty d) - (index (ArrayR (ShapeRsnoc shr) tp) (sh' `TupRpair` n') buffers $ Pair x' $ mkBinary (PrimSub numType) y' $ mkConstant (TupRsingle scalarTypeInt) 1) - index' x' y' = - (index (ArrayR (ShapeRsnoc shr) tp) (sh' `TupRpair` n') buffers $ Pair x' y') - in - -- \(x, y) -> - Lam (lhsX `LeftHandSidePair` LeftHandSideSingle scalarTypeInt) - $ Body - -- if (y < n-n1) - $ Cond (mkBinary (PrimLt singleType) y $ mkBinary (PrimSub numType) n n1) - -- then {reduce y*2 and y*2+1} - (apply2 tp f (index' x y2) (index' x $ mkBinary (PrimAdd numType) y2 $ mkConstant (TupRsingle scalarTypeInt) 1)) - -- else {just copy the value from index (y+n-n1)} - (index' x $ mkBinary (PrimAdd numType) y $ mkBinary (PrimSub numType) n n1) - --- Halves the inner dimension of the array -mkDefaultFoldStep2 :: DesugarAcc op => Arg benv (Fun' (e -> e -> e)) -> Arg benv (In (sh, Int) e) -> Arg benv (Out sh e) -> OperationAcc op benv (DesugaredArrays (Array (sh, Int) e)) -mkDefaultFoldStep2 (ArgFun f) argIn@(ArgArray _ (ArrayR shr@(ShapeRsnoc shr') tp) (TupRpair sh n) input) argOut - | DeclareVars lhsSh kSh valueSh <- declareVars $ TupRsingle $ GroundRscalar scalarTypeInt - , DeclareVars lhsTmp kTmp valueTmp <- declareVars $ buffersR tp - = let - shBase' = weakenVars kSh sh - shTmp' = shBase' `TupRpair` valueSh weakenId +mkDefaultFoldSequential :: forall benv op sh e. DesugarAcc op => Arg benv (Fun' (e -> e -> e)) -> Maybe (Arg benv (Exp' e)) -> Arg benv (In (sh, Int) e) -> Arg benv (Out sh e) -> OperationAcc op benv () +mkDefaultFoldSequential op def argIn argOut = mkGenerate (mkDefaultFoldFunction op def argIn) argOut - shBase = weakenVars kTmp shBase' - shIn = shBase `TupRpair` weakenVars (kTmp .> kSh) n - shTmp = shBase `TupRpair` valueSh kTmp +mkDefaultFoldFunction :: Arg benv (Fun' (e -> e -> e)) -> Maybe (Arg benv (Exp' e)) -> Arg benv (In (sh, Int) e) -> Arg benv (Fun' (sh -> e)) +mkDefaultFoldFunction (ArgFun op) def (ArgArray _ (ArrayR (ShapeRsnoc shr) tp) (sh `TupRpair` n) buffers) + | DeclareVars lhsIdx k1 valueIdx <- declareVars $ shapeType shr + , DeclareVars lhsVal k2 valueVal <- declareVars tp = + let + initial = case def of + Nothing -> + Pair + (mkConstant (TupRsingle scalarTypeInt) 1) + (index (ArrayR (ShapeRsnoc shr) tp) (sh `TupRpair` n) buffers (expVars (valueIdx weakenId) `Pair` mkConstant (TupRsingle scalarTypeInt) 0)) + Just (ArgExp d) -> + Pair + (mkConstant (TupRsingle scalarTypeInt) 0) + (weakenE weakenEmpty d) - temp = valueTmp weakenId - argG = weaken (kTmp .> kSh) $ ArgFun $ mkDefaultFoldStep2Function f argIn - argTmp = ArgArray Out (ArrayR shr tp) shTmp temp + lhs = LeftHandSidePair (LeftHandSideSingle scalarTypeInt) lhsVal + -- \(idx, accum) + step = Lam lhs $ Body $ Pair + -- (idx + 1 + (mkBinary (PrimAdd numType) (Evar $ Var scalarTypeInt $ k2 >:> ZeroIdx) (mkConstant (TupRsingle scalarTypeInt) 1)) + -- , op accum (input !! idx) + $ apply2 tp op (expVars $ valueVal weakenId) + $ index (ArrayR (ShapeRsnoc shr) tp) (sh `TupRpair` n) buffers + $ expVars (valueIdx $ weakenSucc k2) `Pair` Evar (Var scalarTypeInt $ k2 >:> ZeroIdx) + + condition = + Lam (LeftHandSidePair (LeftHandSideSingle scalarTypeInt) (LeftHandSideWildcard tp)) + $ Body $ mkBinary (PrimLt singleType) (Evar $ Var scalarTypeInt ZeroIdx) (paramsIn (TupRsingle scalarType) n) in - alet lhsSh (Compute $ mkBinary (PrimRem integralType) (paramsIn (TupRsingle scalarTypeInt) n) (Const scalarTypeInt 2)) - $ aletUnique lhsTmp (mkDefaultFoldAllocOrOutput (weaken kSh argOut) $ groundToExpVar (shapeType shr) shTmp') - $ alet (LeftHandSideWildcard TupRunit) (mkGenerate argG argTmp) - $ Return (shTmp `TupRpair` temp) - --- Allocates a new intermediate array or returns the output array. --- If the inner dimension is 1, returns the output array, as we are in the last iteration. --- Otherwise, allocates a new intermediate array. --- -mkDefaultFoldAllocOrOutput :: Arg benv (Out sh e) -> ExpVars benv (sh, Int) -> OperationAcc op benv (Buffers e) -mkDefaultFoldAllocOrOutput (ArgArray _ (ArrayR shr e) _ output) sh@(TupRpair _ y) - = Alet (LeftHandSideSingle $ GroundRscalar scalarType) (TupRsingle Shared) (Compute $ mkBinary (PrimEq singleType) (paramsIn' y) $ mkConstant (TupRsingle scalarTypeInt) 1) - $ Acond (Var scalarTypeWord8 ZeroIdx) - (Return $ weakenVars (weakenSucc weakenId) output) - (desugarAlloc (ArrayR (ShapeRsnoc shr) e) $ weakenVars (weakenSucc weakenId) sh) - --- \(x, y) -> f (a !! (x, 2*y)) (a !! (x, 2*y+1)) --- ==> \(x, y) -> let z = toIndex (x, 2*y) in f (a ! z) (a ! z+1) -mkDefaultFoldStep2Function :: forall benv sh e. HasCallStack => Fun benv (e -> e -> e) -> Arg benv (In (sh, Int) e) -> Fun benv ((sh, Int) -> e) -mkDefaultFoldStep2Function f (ArgArray _ (ArrayR (ShapeRsnoc shr) tp) sh input) - | DeclareVars lhsX kX valueX <- declareVars $ shapeType shr - , DeclareVars lhsY kY valueY <- declareVars $ TupRsingle scalarTypeInt - -- \(x, y) -> - = Lam (lhsX `LeftHandSidePair` lhsY) - $ Body - -- let z = toIndex (x, 2*y) - $ Let (LeftHandSideSingle scalarTypeInt) (ToIndex (ShapeRsnoc shr) (paramsIn (shapeType $ ShapeRsnoc shr) sh) (expVars (valueX kY) `Pair` mkBinary (PrimMul numType) (expVars $ valueY weakenId) (mkConstant (TupRsingle scalarTypeInt) 2))) - -- f - $ apply2 tp f - -- (a ! z) - (linearIndex tp input $ {- z -} Var scalarTypeInt ZeroIdx) - -- (let w = z + 1 in a ! w) - (Let (LeftHandSideSingle scalarTypeInt) (mkBinary (PrimAdd numType) (Evar $ Var scalarTypeInt ZeroIdx) (mkConstant (TupRsingle scalarTypeInt) 1)) - $ linearIndex tp input $ Var scalarTypeInt ZeroIdx) + ArgFun $ Lam lhsIdx $ Body + $ Let lhs (While condition step initial) + $ expVars $ valueVal weakenId -- In case of a scan with a default value, prepends the initial value before the other elements -- The default value is placed as the first value in case of a left-to-right scan, or as the From f14571eb4e167ec0ed5580526a181a013a899e28 Mon Sep 17 00:00:00 2001 From: Ivo Gabe de Wolff Date: Wed, 29 May 2024 15:24:03 +0200 Subject: [PATCH 2/3] Fixes in schedule construction and exp pretty printer --- src/Data/Array/Accelerate/Pretty/Exp.hs | 3 ++- src/Data/Array/Accelerate/Trafo/Schedule/Uniform/Future.hs | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/Data/Array/Accelerate/Pretty/Exp.hs b/src/Data/Array/Accelerate/Pretty/Exp.hs index 4f2db1667..bbc552f0b 100644 --- a/src/Data/Array/Accelerate/Pretty/Exp.hs +++ b/src/Data/Array/Accelerate/Pretty/Exp.hs @@ -135,7 +135,8 @@ prettyPreOpenExp ctx prettyArrayInstr env exp = -- single = parensIf (needsParens ctx (Operator "?:" Infix N 0)) $ sep [ p', pretty '?', t', pretty ':', e' ] - multi = hang 3 + multi = parensIf (ctxPrecedence ctx > 0) + $ hang 3 $ vsep [ if_ <+> p' , hang shiftwidth (sep [ then_, t' ]) , hang shiftwidth (sep [ else_, e' ]) ] diff --git a/src/Data/Array/Accelerate/Trafo/Schedule/Uniform/Future.hs b/src/Data/Array/Accelerate/Trafo/Schedule/Uniform/Future.hs index 1f664a039..2cb65ac83 100644 --- a/src/Data/Array/Accelerate/Trafo/Schedule/Uniform/Future.hs +++ b/src/Data/Array/Accelerate/Trafo/Schedule/Uniform/Future.hs @@ -707,7 +707,7 @@ loopFuture resolved (FutureBuffer tp ref (Move readLockSignal) (Just (Move write $ Just $ Borrow (Just signalW) resolverR | otherwise -> internalError "input or output impossible" } -loopFuture resolved (FutureBuffer tp ref (Lock readLockSignal readLockResolver) (Just (Lock writeLockSignal writeLockResolver))) = undefined +loopFuture resolved (FutureBuffer tp ref (Lock readLockSignal readLockResolver) (Just (Lock writeLockSignal writeLockResolver))) = -- A borrowed writable buffer -- We must add two signals (and accompanying signal resolvers) to the state -- to synchronize read and write access. Furthermore we need to declare two From 00c91e0d92094cdd1eefc7d0f6462c5e95b3b303 Mon Sep 17 00:00:00 2001 From: Ivo Gabe de Wolff Date: Wed, 29 May 2024 20:26:51 +0200 Subject: [PATCH 3/3] Detect signals that are resolved at the same time Eliminate those signals by only waiting on the signal with the higher index. --- .../Trafo/Schedule/Uniform/Simplify.hs | 176 ++++++++++++++++-- 1 file changed, 156 insertions(+), 20 deletions(-) diff --git a/src/Data/Array/Accelerate/Trafo/Schedule/Uniform/Simplify.hs b/src/Data/Array/Accelerate/Trafo/Schedule/Uniform/Simplify.hs index db74bdae8..665240968 100644 --- a/src/Data/Array/Accelerate/Trafo/Schedule/Uniform/Simplify.hs +++ b/src/Data/Array/Accelerate/Trafo/Schedule/Uniform/Simplify.hs @@ -129,7 +129,9 @@ constructFull schedule k env postponed cont | null $ directlyAwaits schedule = construct schedule k env postponed cont | signals' <- -- Don't wait on already resolved signals - filter (\idx -> not (isResolved idx env)) + sortedDedup + $ sort + $ filter (\idx -> not (isResolved idx env)) $ map (weaken k) $ directlyAwaits schedule , env' <- markResolved signals' env = @@ -278,13 +280,13 @@ markResolved :: [Idx env Signal] -> BuildEnv env -> BuildEnv env markResolved [] env = env markResolved signals (BPush env info) | ZeroIdx : signals' <- signals - = BPush (markResolved (map forceWeaken signals') env) IResolved + = BPush (markResolved (map unSucc signals') env) IResolved | otherwise - = BPush (markResolved (map forceWeaken signals) env) info + = BPush (markResolved (map unSucc signals) env) info where - forceWeaken :: Idx (env, t) s -> Idx env s - forceWeaken ZeroIdx = internalError "markResolved: input was not sorted or contains duplicates" - forceWeaken (SuccIdx idx) = idx + unSucc :: Idx (env, t) s -> Idx env s + unSucc ZeroIdx = internalError "markResolved: input was not sorted or contains duplicates" + unSucc (SuccIdx idx) = idx markResolved (s:_) BEmpty = case s of {} isResolved :: Idx env Signal -> BuildEnv env -> Bool @@ -386,6 +388,35 @@ buildLet lhs binding body (if shouldAwait then nothingPostponed else weaken' (weakenWithLHS lhs') postponed) $ weaken' (weakenWithLHS lhs') cont +buildLetNewSignal :: String -> [Idx env SignalResolver] -> BuildSchedule kernel ((env, Signal), SignalResolver) -> BuildSchedule kernel env +buildLetNewSignal comment resolvers body = + -- NewSignal is trivial + BuildSchedule{ + directlyAwaits = map (fromMaybe (internalError "Illegal schedule: deadlock") . strengthenWithLHS lhs) $ directlyAwaits body, + finallyResolves = mapMaybe (strengthenWithLHS lhs) $ finallyResolves body, + trivial = trivial body, + construct = \k env postponed cont -> if + | otherSignal : _ <- mapMaybe (\idx -> k >:> idx `findSignal` env) resolvers + , k' <- sink $ weakenReplace otherSignal k -> + -- Remove the index for the signal. + -- Replace all occurrences of that signal with 'otherSignal', + -- as their resolvers are resolved at the same time. + Alet lhsResolver (NewSignal comment) + $ construct body k' + (buildEnvExtend lhsResolver (NewSignal comment) env) + (weaken' (weakenSucc weakenId) postponed) + $ weaken' (weakenSucc weakenId) cont + | k' <- sink $ sink k -> + Alet lhs (NewSignal comment) + $ construct body k' + (buildEnvExtend lhs (NewSignal comment) env) + (weaken' (weakenSucc $ weakenSucc weakenId) postponed) + $ weaken' (weakenSucc $ weakenSucc weakenId) cont + } + where + lhs = LeftHandSideSingle BaseRsignal `LeftHandSidePair` LeftHandSideSingle BaseRsignalResolver + lhsResolver = LeftHandSideWildcard (TupRsingle BaseRsignal) `LeftHandSidePair` LeftHandSideSingle BaseRsignalResolver + buildEnvExtend :: BLeftHandSide t env1 env2 -> Binding env1 t -> BuildEnv env1 -> BuildEnv env2 buildEnvExtend (LeftHandSidePair (LeftHandSideSingle _) (LeftHandSideSingle _)) (NewSignal _) env = env `BPush` INone `BPush` IResolvesNext @@ -420,7 +451,7 @@ buildEffect (SignalResolve resolvers) next = construct = \k env postponed cont -> let resolvers'' = map (weaken k) resolvers' - signals = mapMaybe (\r -> findSignal r env) resolvers'' + signals = sort $ mapMaybe (\r -> findSignal r env) resolvers'' env' = markResolved signals env in constructFull next k env' (resolveSignalsInPostponed signals resolvers'' postponed) cont @@ -578,6 +609,16 @@ mergeDedup as@(a:as') bs@(b:bs') mergeDedup as [] = as mergeDedup [] bs = bs +sortedDedup :: Eq a => [a] -> [a] +sortedDedup = \case + [] -> [] + a : as -> go a as + where + go x (y:ys) + | x == y = go x ys + | otherwise = x : go y ys + go x [] = [x] + -- Constructs the intersection of two lists, -- assuming they are sorted and have no duplicates. sortedIntersection :: Ord a => [a] -> [a] -> [a] @@ -605,18 +646,113 @@ simplify f = funConstruct (rebuildFun f) weakenId BEmpty rebuildFun :: UniformScheduleFun kernel env t -> BuildScheduleFun kernel env t rebuildFun (Slam lhs f) = buildFunLam lhs $ rebuildFun f -rebuildFun (Sbody body) = buildFunBody $ rebuild body +rebuildFun (Sbody body) = buildFunBody $ snd $ rebuild body -rebuild :: UniformSchedule kernel env -> BuildSchedule kernel env +rebuild :: UniformSchedule kernel env -> (SignalAnalysis env, BuildSchedule kernel env) rebuild = \case - Return -> buildReturn - Alet lhs bnd body -> - buildLet lhs bnd $ rebuild body - Effect eff next -> - buildEffect eff $ rebuild next - Acond var true false next -> - buildAcond var (rebuild true) (rebuild false) (rebuild next) - Awhile io f input next -> - buildAwhile io (rebuildFun f) input (rebuild next) - Spawn a b -> - buildSpawn (rebuild a) (rebuild b) + Return -> (SEmpty, buildReturn) + Alet lhs bnd body + | (analysis, body') <- rebuild body -> + ( analysisDrop lhs analysis + , rebuildLet lhs bnd analysis body' + ) + Effect eff next + | (analysis, next') <- rebuild next -> + ( analyseEffect eff `analysisJoin` analysis + , buildEffect eff next' + ) + Acond var true false next + | (aTrue, true') <- rebuild true + , (aFalse, false') <- rebuild false + , (aNext, next') <- rebuild next -> + ( analysisMeet aTrue aFalse `analysisJoin` aNext + , buildAcond var true' false' next' + ) + Awhile io f input next + | (analysis, next') <- rebuild next -> + ( analysis + , buildAwhile io (rebuildFun f) input next' + ) + Spawn term1 term2 + | (analysis1, term1') <- rebuild term1 + , (analysis2, term2') <- rebuild term2 -> + ( analysisJoin analysis1 analysis2 + , buildSpawn term1' term2' + ) + +rebuildLet + :: BLeftHandSide t env env' + -> Binding env t + -> SignalAnalysis env' + -> BuildSchedule kernel env' + -> BuildSchedule kernel env +rebuildLet (LeftHandSidePair LeftHandSideSingle{} LeftHandSideSingle{}) (NewSignal comment) (SPush _ (SIResolvedWith resolvers)) body = buildLetNewSignal comment (map unSucc resolvers) body + where + unSucc :: Idx (env, Signal) SignalResolver -> Idx env SignalResolver + unSucc (SuccIdx idx) = idx +rebuildLet lhs bnd _ body = buildLet lhs bnd body + +-- Signal analysis +data SignalAnalysis env where + SEmpty :: SignalAnalysis env + SPush :: SignalAnalysis env -> SignalInfo env t -> SignalAnalysis (env, t) + +spush :: SignalAnalysis env -> SignalInfo env t -> SignalAnalysis (env, t) +spush SEmpty SINone = SEmpty +spush env info = SPush env info + +data SignalInfo env t where + -- This SignalResolver is resolved at the same time as the given list of SignalResolvers. + SIResolvedWith + :: [Idx env SignalResolver] + -> SignalInfo env SignalResolver + + SINone + :: SignalInfo env t + +analysisDrop :: LeftHandSide s t env env' -> SignalAnalysis env' -> SignalAnalysis env +analysisDrop _ SEmpty = SEmpty +analysisDrop LeftHandSideWildcard{} env = env +analysisDrop LeftHandSideSingle{} (SPush env _) = env +analysisDrop (LeftHandSidePair lhs1 lhs2) env = analysisDrop lhs1 $ analysisDrop lhs2 env + +-- Use this when two terms are both executed, for instance in a spawn +analysisJoin :: SignalAnalysis env -> SignalAnalysis env -> SignalAnalysis env +analysisJoin SEmpty env = env +analysisJoin env SEmpty = env +analysisJoin (SPush as a) (SPush bs b) = analysisJoin as bs `SPush` signalInfoJoin a b + +signalInfoJoin :: SignalInfo env t -> SignalInfo env t -> SignalInfo env t +signalInfoJoin SINone info = info +signalInfoJoin info SINone = info +signalInfoJoin (SIResolvedWith as) (SIResolvedWith bs) = SIResolvedWith $ as `mergeDedup` bs + +-- Use this when only one of the two terms is executed, for instance in an if-then-else +analysisMeet :: SignalAnalysis env -> SignalAnalysis env -> SignalAnalysis env +analysisMeet SEmpty _ = SEmpty +analysisMeet _ SEmpty = SEmpty +analysisMeet (SPush as a) (SPush bs b) = analysisMeet as bs `SPush` signalInfoMeet a b + +signalInfoMeet :: SignalInfo env t -> SignalInfo env t -> SignalInfo env t +signalInfoMeet SINone _ = SINone +signalInfoMeet _ SINone = SINone +signalInfoMeet (SIResolvedWith as) (SIResolvedWith bs) = SIResolvedWith $ as `sortedIntersection` bs + +analyseEffect :: Effect kernel env -> SignalAnalysis env +analyseEffect (SignalResolve resolvers) = analyseSignalResolve resolvers +analyseEffect _ = SEmpty + +analyseSignalResolve :: [Idx env SignalResolver] -> SignalAnalysis env +analyseSignalResolve = const SEmpty -- go . sort + where + -- input is sorted from low indices to high indices + go :: [Idx env SignalResolver] -> SignalAnalysis env + go [] = SEmpty + go [_] = SEmpty + go (ZeroIdx : ids) = go ids' `SPush` SIResolvedWith ids' + where ids' = map unSucc ids + go ids@(SuccIdx _ : _) = go (map unSucc ids) `spush` SINone + + unSucc :: Idx (env, s) t -> Idx env t + unSucc (SuccIdx idx) = idx + unSucc ZeroIdx = internalError "Expected non-zero index. Is the list of indices sorted and unique?"