From 4155e06c24d90db5010872b1d8b540acf9837d35 Mon Sep 17 00:00:00 2001 From: Ivo Gabe de Wolff Date: Thu, 23 May 2024 17:53:27 +0200 Subject: [PATCH] Re-enable extra simplfications in Trafo, fix schedule bug, specialize pretty printing of await-after-spawn --- src/Data/Array/Accelerate/AST/Partitioned.hs | 4 +++ .../Accelerate/Pretty/Schedule/Uniform.hs | 5 ++++ src/Data/Array/Accelerate/Trafo.hs | 28 +++++++++---------- .../Trafo/Schedule/Uniform/Simplify.hs | 2 +- 4 files changed, 24 insertions(+), 15 deletions(-) diff --git a/src/Data/Array/Accelerate/AST/Partitioned.hs b/src/Data/Array/Accelerate/AST/Partitioned.hs index 2f79dd3e3..05de7c414 100644 --- a/src/Data/Array/Accelerate/AST/Partitioned.hs +++ b/src/Data/Array/Accelerate/AST/Partitioned.hs @@ -44,6 +44,7 @@ import Data.Array.Accelerate.AST.Operation hiding (OperationAcc, OperationAfun) import Prelude hiding ( take ) import Data.Bifunctor import Data.Array.Accelerate.Trafo.Desugar (ArrayDescriptor(..)) +import Data.Array.Accelerate.Trafo.Operation.Simplify (SimplifyOperation(..)) import Data.Array.Accelerate.Representation.Array (Array, Buffers, ArrayR (..)) import Data.Array.Accelerate.AST.LeftHandSide import Data.Array.Accelerate.Representation.Shape (ShapeR (..), shapeType, typeShape) @@ -535,6 +536,9 @@ instance ShrinkArg (BackendClusterArg op) => SLVOperation (Clustered op) where ShrunkOperation' cluster' args -> ShrunkOperation (Clustered cluster' $ shrinkArgs ff' b) args +instance SimplifyOperation (Clustered op) + -- Default implementation, where detectCopy always returns [] + -- instance SLVOperation (Cluster op) where -- slvOperation cluster = -- Nothing -- Just $ ShrinkOperation $ \sub args' args -> diff --git a/src/Data/Array/Accelerate/Pretty/Schedule/Uniform.hs b/src/Data/Array/Accelerate/Pretty/Schedule/Uniform.hs index 08b5679c5..2f9c259ff 100644 --- a/src/Data/Array/Accelerate/Pretty/Schedule/Uniform.hs +++ b/src/Data/Array/Accelerate/Pretty/Schedule/Uniform.hs @@ -152,6 +152,11 @@ prettyUniformSchedule env = \case <> hardline <> hang 4 (" (" <+> prettyUniformScheduleFun env body) <> hardline <> " )" <> prettyNext next + Spawn (Effect (SignalAwait signals) body) next + -> annotate Statement "spawn await" <+> list (map (prettyIdx env) signals) <+> "{" + <> hardline <> indent 2 (prettyUniformSchedule env body) + <> hardline <> "}" + <> prettyNext next Spawn body next -> annotate Statement "spawn" <+> "{" <> hardline <> indent 2 (prettyUniformSchedule env body) diff --git a/src/Data/Array/Accelerate/Trafo.hs b/src/Data/Array/Accelerate/Trafo.hs index 81f2c76b3..090680838 100644 --- a/src/Data/Array/Accelerate/Trafo.hs +++ b/src/Data/Array/Accelerate/Trafo.hs @@ -119,11 +119,11 @@ testWithObjective obj f $ Sharing.convertAfunWith defaultOptions f partitioned = - -- Operation.simplifyFun $ + Operation.simplifyFun $ NewNewFusion.convertAfun obj operation slvpartitioned = - -- Operation.simplifyFun $ + Operation.simplifyFun $ Operation.stronglyLiveVariablesFun partitioned schedule = convertScheduleFun @sched @kernel slvpartitioned @@ -153,8 +153,8 @@ convertAccWith -> sched kernel () (ScheduleOutput sched (DesugaredArrays (ArraysR arrs)) -> ()) convertAccWith config = phase' "codegen" rnfSchedule convertSchedule - . phase "partition-live-vars" ({-Operation.simplify . -} Operation.stronglyLiveVariables) - . phase "array-fusion" ({-Operation.simplify . -} NewNewFusion.convertAccWith config defaultObjective) + . phase "partition-live-vars" (Operation.simplify . Operation.stronglyLiveVariables) + . phase "array-fusion" (Operation.simplify . NewNewFusion.convertAccWith config defaultObjective) . phase "operation-live-vars" (Operation.simplify . Operation.stronglyLiveVariables) . phase "desugar" (Operation.simplify . desugar) . phase "array-split-lets" LetSplit.convertAcc @@ -169,8 +169,8 @@ convertAccBench -> sched kernel () (ScheduleOutput sched (DesugaredArrays (ArraysR arrs)) -> ()) convertAccBench b = phase' "codegen" rnfSchedule convertSchedule - . phase "partition-live-vars" ({-Operation.simplify . -} Operation.stronglyLiveVariables) - . phase "array-fusion" ({-Operation.simplify . -} NewNewFusion.convertAccBench b) + . phase "partition-live-vars" (Operation.simplify . Operation.stronglyLiveVariables) + . phase "array-fusion" (Operation.simplify . NewNewFusion.convertAccBench b) . phase "operation-live-vars" (Operation.simplify . Operation.stronglyLiveVariables) . phase "desugar" (Operation.simplify . desugar) . phase "array-split-lets" LetSplit.convertAcc @@ -185,8 +185,8 @@ convertAfunBench -> sched kernel () (Scheduled sched (DesugaredAfun (ArraysFunctionR f))) convertAfunBench b = phase' "codegen" rnfSchedule convertScheduleFun - . phase "partition-live-vars" ({- Operation.simplifyFun . -} Operation.stronglyLiveVariablesFun) - . phase "array-fusion" ({- Operation.simplifyFun . -} NewNewFusion.convertAccBenchF b) + . phase "partition-live-vars" (Operation.simplifyFun . Operation.stronglyLiveVariablesFun) + . phase "array-fusion" (Operation.simplifyFun . NewNewFusion.convertAccBenchF b) . phase "operation-live-vars" (Operation.simplifyFun . Operation.stronglyLiveVariablesFun) . phase "desugar" (Operation.simplifyFun . desugarAfun) . phase "array-split-lets" LetSplit.convertAfun @@ -202,8 +202,8 @@ convertAccWithObj -> sched kernel () (ScheduleOutput sched (DesugaredArrays (ArraysR arrs)) -> ()) convertAccWithObj obj = phase' "codegen" rnfSchedule convertSchedule - . phase "partition-live-vars" ({-Operation.simplify . -} Operation.stronglyLiveVariables) - . phase "array-fusion" ({-Operation.simplify . -} NewNewFusion.convertAccWith defaultOptions obj) + . phase "partition-live-vars" (Operation.simplify . Operation.stronglyLiveVariables) + . phase "array-fusion" (Operation.simplify . NewNewFusion.convertAccWith defaultOptions obj) . phase "operation-live-vars" (Operation.simplify . Operation.stronglyLiveVariables) . phase "desugar" (Operation.simplify . desugar) . phase "array-split-lets" LetSplit.convertAcc @@ -230,8 +230,8 @@ convertAfunWith convertAfunWith config = (\s -> Debug.Trace.trace (Pretty.renderForTerminal (Pretty.prettySchedule s)) s) . phase' "codegen" rnfSchedule convertScheduleFun - . phase "partition-live-vars" ({- Operation.simplifyFun . -} Operation.stronglyLiveVariablesFun) - . phase "array-fusion" ({- Operation.simplifyFun . -} NewNewFusion.convertAfunWith config defaultObjective) + . phase "partition-live-vars" (Operation.simplifyFun . Operation.stronglyLiveVariablesFun) + . phase "array-fusion" (Operation.simplifyFun . NewNewFusion.convertAfunWith config defaultObjective) . phase "operation-live-vars" (Operation.simplifyFun . Operation.stronglyLiveVariablesFun) . phase "desugar" (Operation.simplifyFun . desugarAfun) . phase "array-split-lets" LetSplit.convertAfun @@ -246,8 +246,8 @@ convertAfunWithObj -> sched kernel () (Scheduled sched (DesugaredAfun (ArraysFunctionR f))) convertAfunWithObj obj = phase' "codegen" rnfSchedule convertScheduleFun - . phase "partition-live-vars" ({- Operation.simplifyFun . -} Operation.stronglyLiveVariablesFun) - . phase "array-fusion" ({- Operation.simplifyFun . -} NewNewFusion.convertAfunWith defaultOptions obj) + . phase "partition-live-vars" (Operation.simplifyFun . Operation.stronglyLiveVariablesFun) + . phase "array-fusion" (Operation.simplifyFun . NewNewFusion.convertAfunWith defaultOptions obj) . phase "operation-live-vars" (Operation.simplifyFun . Operation.stronglyLiveVariablesFun) . phase "desugar" (Operation.simplifyFun . desugarAfun) . phase "array-split-lets" LetSplit.convertAfun diff --git a/src/Data/Array/Accelerate/Trafo/Schedule/Uniform/Simplify.hs b/src/Data/Array/Accelerate/Trafo/Schedule/Uniform/Simplify.hs index bd06085d6..a08e80988 100644 --- a/src/Data/Array/Accelerate/Trafo/Schedule/Uniform/Simplify.hs +++ b/src/Data/Array/Accelerate/Trafo/Schedule/Uniform/Simplify.hs @@ -204,7 +204,7 @@ findDependingSpawn (Postponed spawns resolvers) nextDirectlyAwaits = case go spa = Just (x, xs) | otherwise = case go xs of Nothing -> Nothing - Just (y, ys) -> Just (y, y:ys) + Just (y, ys) -> Just (y, x:ys) go [] = Nothing data BuildEnv env where