Skip to content

Commit

Permalink
Re-enable extra simplfications in Trafo, fix schedule bug, specialize…
Browse files Browse the repository at this point in the history
… pretty printing of await-after-spawn
  • Loading branch information
ivogabe committed May 23, 2024
1 parent ba6bb95 commit 4155e06
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 15 deletions.
4 changes: 4 additions & 0 deletions src/Data/Array/Accelerate/AST/Partitioned.hs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ import Data.Array.Accelerate.AST.Operation hiding (OperationAcc, OperationAfun)
import Prelude hiding ( take )
import Data.Bifunctor
import Data.Array.Accelerate.Trafo.Desugar (ArrayDescriptor(..))
import Data.Array.Accelerate.Trafo.Operation.Simplify (SimplifyOperation(..))
import Data.Array.Accelerate.Representation.Array (Array, Buffers, ArrayR (..))
import Data.Array.Accelerate.AST.LeftHandSide
import Data.Array.Accelerate.Representation.Shape (ShapeR (..), shapeType, typeShape)
Expand Down Expand Up @@ -535,6 +536,9 @@ instance ShrinkArg (BackendClusterArg op) => SLVOperation (Clustered op) where
ShrunkOperation' cluster' args ->
ShrunkOperation (Clustered cluster' $ shrinkArgs ff' b) args

instance SimplifyOperation (Clustered op)
-- Default implementation, where detectCopy always returns []

-- instance SLVOperation (Cluster op) where
-- slvOperation cluster = -- Nothing
-- Just $ ShrinkOperation $ \sub args' args ->
Expand Down
5 changes: 5 additions & 0 deletions src/Data/Array/Accelerate/Pretty/Schedule/Uniform.hs
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,11 @@ prettyUniformSchedule env = \case
<> hardline <> hang 4 (" (" <+> prettyUniformScheduleFun env body)
<> hardline <> " )"
<> prettyNext next
Spawn (Effect (SignalAwait signals) body) next
-> annotate Statement "spawn await" <+> list (map (prettyIdx env) signals) <+> "{"
<> hardline <> indent 2 (prettyUniformSchedule env body)
<> hardline <> "}"
<> prettyNext next
Spawn body next
-> annotate Statement "spawn" <+> "{"
<> hardline <> indent 2 (prettyUniformSchedule env body)
Expand Down
28 changes: 14 additions & 14 deletions src/Data/Array/Accelerate/Trafo.hs
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,11 @@ testWithObjective obj f
$ Sharing.convertAfunWith defaultOptions f

partitioned =
-- Operation.simplifyFun $
Operation.simplifyFun $
NewNewFusion.convertAfun obj operation

slvpartitioned =
-- Operation.simplifyFun $
Operation.simplifyFun $
Operation.stronglyLiveVariablesFun partitioned

schedule = convertScheduleFun @sched @kernel slvpartitioned
Expand Down Expand Up @@ -153,8 +153,8 @@ convertAccWith
-> sched kernel () (ScheduleOutput sched (DesugaredArrays (ArraysR arrs)) -> ())
convertAccWith config
= phase' "codegen" rnfSchedule convertSchedule
. phase "partition-live-vars" ({-Operation.simplify . -} Operation.stronglyLiveVariables)
. phase "array-fusion" ({-Operation.simplify . -} NewNewFusion.convertAccWith config defaultObjective)
. phase "partition-live-vars" (Operation.simplify . Operation.stronglyLiveVariables)
. phase "array-fusion" (Operation.simplify . NewNewFusion.convertAccWith config defaultObjective)
. phase "operation-live-vars" (Operation.simplify . Operation.stronglyLiveVariables)
. phase "desugar" (Operation.simplify . desugar)
. phase "array-split-lets" LetSplit.convertAcc
Expand All @@ -169,8 +169,8 @@ convertAccBench
-> sched kernel () (ScheduleOutput sched (DesugaredArrays (ArraysR arrs)) -> ())
convertAccBench b
= phase' "codegen" rnfSchedule convertSchedule
. phase "partition-live-vars" ({-Operation.simplify . -} Operation.stronglyLiveVariables)
. phase "array-fusion" ({-Operation.simplify . -} NewNewFusion.convertAccBench b)
. phase "partition-live-vars" (Operation.simplify . Operation.stronglyLiveVariables)
. phase "array-fusion" (Operation.simplify . NewNewFusion.convertAccBench b)
. phase "operation-live-vars" (Operation.simplify . Operation.stronglyLiveVariables)
. phase "desugar" (Operation.simplify . desugar)
. phase "array-split-lets" LetSplit.convertAcc
Expand All @@ -185,8 +185,8 @@ convertAfunBench
-> sched kernel () (Scheduled sched (DesugaredAfun (ArraysFunctionR f)))
convertAfunBench b
= phase' "codegen" rnfSchedule convertScheduleFun
. phase "partition-live-vars" ({- Operation.simplifyFun . -} Operation.stronglyLiveVariablesFun)
. phase "array-fusion" ({- Operation.simplifyFun . -} NewNewFusion.convertAccBenchF b)
. phase "partition-live-vars" (Operation.simplifyFun . Operation.stronglyLiveVariablesFun)
. phase "array-fusion" (Operation.simplifyFun . NewNewFusion.convertAccBenchF b)
. phase "operation-live-vars" (Operation.simplifyFun . Operation.stronglyLiveVariablesFun)
. phase "desugar" (Operation.simplifyFun . desugarAfun)
. phase "array-split-lets" LetSplit.convertAfun
Expand All @@ -202,8 +202,8 @@ convertAccWithObj
-> sched kernel () (ScheduleOutput sched (DesugaredArrays (ArraysR arrs)) -> ())
convertAccWithObj obj
= phase' "codegen" rnfSchedule convertSchedule
. phase "partition-live-vars" ({-Operation.simplify . -} Operation.stronglyLiveVariables)
. phase "array-fusion" ({-Operation.simplify . -} NewNewFusion.convertAccWith defaultOptions obj)
. phase "partition-live-vars" (Operation.simplify . Operation.stronglyLiveVariables)
. phase "array-fusion" (Operation.simplify . NewNewFusion.convertAccWith defaultOptions obj)
. phase "operation-live-vars" (Operation.simplify . Operation.stronglyLiveVariables)
. phase "desugar" (Operation.simplify . desugar)
. phase "array-split-lets" LetSplit.convertAcc
Expand All @@ -230,8 +230,8 @@ convertAfunWith
convertAfunWith config
= (\s -> Debug.Trace.trace (Pretty.renderForTerminal (Pretty.prettySchedule s)) s)
. phase' "codegen" rnfSchedule convertScheduleFun
. phase "partition-live-vars" ({- Operation.simplifyFun . -} Operation.stronglyLiveVariablesFun)
. phase "array-fusion" ({- Operation.simplifyFun . -} NewNewFusion.convertAfunWith config defaultObjective)
. phase "partition-live-vars" (Operation.simplifyFun . Operation.stronglyLiveVariablesFun)
. phase "array-fusion" (Operation.simplifyFun . NewNewFusion.convertAfunWith config defaultObjective)
. phase "operation-live-vars" (Operation.simplifyFun . Operation.stronglyLiveVariablesFun)
. phase "desugar" (Operation.simplifyFun . desugarAfun)
. phase "array-split-lets" LetSplit.convertAfun
Expand All @@ -246,8 +246,8 @@ convertAfunWithObj
-> sched kernel () (Scheduled sched (DesugaredAfun (ArraysFunctionR f)))
convertAfunWithObj obj
= phase' "codegen" rnfSchedule convertScheduleFun
. phase "partition-live-vars" ({- Operation.simplifyFun . -} Operation.stronglyLiveVariablesFun)
. phase "array-fusion" ({- Operation.simplifyFun . -} NewNewFusion.convertAfunWith defaultOptions obj)
. phase "partition-live-vars" (Operation.simplifyFun . Operation.stronglyLiveVariablesFun)
. phase "array-fusion" (Operation.simplifyFun . NewNewFusion.convertAfunWith defaultOptions obj)
. phase "operation-live-vars" (Operation.simplifyFun . Operation.stronglyLiveVariablesFun)
. phase "desugar" (Operation.simplifyFun . desugarAfun)
. phase "array-split-lets" LetSplit.convertAfun
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ findDependingSpawn (Postponed spawns resolvers) nextDirectlyAwaits = case go spa
= Just (x, xs)
| otherwise = case go xs of
Nothing -> Nothing
Just (y, ys) -> Just (y, y:ys)
Just (y, ys) -> Just (y, x:ys)
go [] = Nothing

data BuildEnv env where
Expand Down

0 comments on commit 4155e06

Please sign in to comment.