Skip to content

Commit

Permalink
Merge branch 'new-pipeline' of https://www.github.com/ivogabe/accelerate
Browse files Browse the repository at this point in the history
 into new-pipeline-david
  • Loading branch information
dpvanbalen committed May 24, 2024
2 parents 6149f49 + 4155e06 commit 1d2d2b5
Show file tree
Hide file tree
Showing 8 changed files with 477 additions and 762 deletions.
14 changes: 13 additions & 1 deletion src/Data/Array/Accelerate/AST/Environment.hs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ module Data.Array.Accelerate.AST.Environment (
Skip(..), skipIdx, chainSkip, skipWeakenIdx, lhsSkip,

prjUpdate', prjReplace', update', updates', mapEnv,
(:>)(..), weakenId, weakenSucc, weakenSucc', weakenEmpty,
(:>)(..), weakenId, weakenSucc, weakenSucc', weakenEmpty, weakenReplace,
sink, (.>), sinkWithLHS, weakenWithLHS, substituteLHS,
varsGet, varsGetVal, stripWithLhs,weakenKeep) where

Expand Down Expand Up @@ -313,6 +313,11 @@ mapEnv g (Push env f) = Push (mapEnv g env) (g f)
--
newtype env :> env' = Weaken { (>:>) :: forall (t' :: Type). Idx env t' -> Idx env' t' } -- Weak or Weaken

-- Weaken is currently just a function. We could consider to partially defunctionalize this and define
-- it as a data type, which constructors for the following functions.
-- Note that we may then also need to fold a chain of weakenSucc to a SkipIdx, and change the internal
-- definition of SkipIdx to Int (like we also did for Idx).

weakenId :: env :> env
weakenId = Weaken id

Expand All @@ -330,6 +335,13 @@ weakenKeep (Weaken f) = Weaken $ \case
weakenEmpty :: () :> env'
weakenEmpty = Weaken $ \(VoidIdx x) -> x

weakenReplace :: forall env env' t. Idx env' t -> env :> env' -> (env, t) :> env'
weakenReplace other k = Weaken f
where
f :: forall s. Idx (env, t) s -> Idx env' s
f ZeroIdx = other
f (SuccIdx idx) = k >:> idx

sink :: forall env env' t. env :> env' -> (env, t) :> (env', t)
sink (Weaken f) = Weaken g
where
Expand Down
4 changes: 4 additions & 0 deletions src/Data/Array/Accelerate/AST/Partitioned.hs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ import Data.Array.Accelerate.AST.Operation hiding (OperationAcc, OperationAfun)
import Prelude hiding ( take )
import Data.Bifunctor
import Data.Array.Accelerate.Trafo.Desugar (ArrayDescriptor(..))
import Data.Array.Accelerate.Trafo.Operation.Simplify (SimplifyOperation(..))
import Data.Array.Accelerate.Representation.Array (Array, Buffers, ArrayR (..))
import Data.Array.Accelerate.AST.LeftHandSide
import Data.Array.Accelerate.Representation.Shape (ShapeR (..), shapeType, typeShape)
Expand Down Expand Up @@ -535,6 +536,9 @@ instance ShrinkArg (BackendClusterArg op) => SLVOperation (Clustered op) where
ShrunkOperation' cluster' args ->
ShrunkOperation (Clustered cluster' $ shrinkArgs ff' b) args

instance SimplifyOperation (Clustered op)
-- Default implementation, where detectCopy always returns []

-- instance SLVOperation (Cluster op) where
-- slvOperation cluster = -- Nothing
-- Just $ ShrinkOperation $ \sub args' args ->
Expand Down
3 changes: 1 addition & 2 deletions src/Data/Array/Accelerate/AST/Schedule/Uniform.hs
Original file line number Diff line number Diff line change
Expand Up @@ -402,12 +402,11 @@ reorder = uncurry await . go Just
trivialBinding :: Binding env t -> Bool
trivialBinding (NewSignal _) = True
trivialBinding (NewRef _) = True
trivialBinding (Alloc ShapeRz _ _) = True
trivialBinding (Alloc _ _ _) = True
trivialBinding (Use _ _ _) = True
trivialBinding (Unit _) = True
trivialBinding (RefRead _) = True
trivialBinding (Compute e) = expIsTrivial (const True) e
trivialBinding _ = False

-- If a schedule does not do blocking or slow operations, we say it's trivial
-- and don't need to spawn it as we do not gain much task parallelism from it.
Expand Down
5 changes: 5 additions & 0 deletions src/Data/Array/Accelerate/Pretty/Schedule/Uniform.hs
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,11 @@ prettyUniformSchedule env = \case
<> hardline <> hang 4 (" (" <+> prettyUniformScheduleFun env body)
<> hardline <> " )"
<> prettyNext next
Spawn (Effect (SignalAwait signals) body) next
-> annotate Statement "spawn await" <+> list (map (prettyIdx env) signals) <+> "{"
<> hardline <> indent 2 (prettyUniformSchedule env body)
<> hardline <> "}"
<> prettyNext next
Spawn body next
-> annotate Statement "spawn" <+> "{"
<> hardline <> indent 2 (prettyUniformSchedule env body)
Expand Down
28 changes: 14 additions & 14 deletions src/Data/Array/Accelerate/Trafo.hs
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,11 @@ testWithObjective obj f
$ Sharing.convertAfunWith defaultOptions f

partitioned =
-- Operation.simplifyFun $
Operation.simplifyFun $
NewNewFusion.convertAfun obj operation

slvpartitioned =
-- Operation.simplifyFun $
Operation.simplifyFun $
Operation.stronglyLiveVariablesFun partitioned

schedule = convertScheduleFun @sched @kernel slvpartitioned
Expand Down Expand Up @@ -153,8 +153,8 @@ convertAccWith
-> sched kernel () (ScheduleOutput sched (DesugaredArrays (ArraysR arrs)) -> ())
convertAccWith config
= phase' "codegen" rnfSchedule convertSchedule
. phase "partition-live-vars" ({-Operation.simplify . -} Operation.stronglyLiveVariables)
. phase "array-fusion" ({-Operation.simplify . -} NewNewFusion.convertAccWith config defaultObjective)
. phase "partition-live-vars" (Operation.simplify . Operation.stronglyLiveVariables)
. phase "array-fusion" (Operation.simplify . NewNewFusion.convertAccWith config defaultObjective)
. phase "operation-live-vars" (Operation.simplify . Operation.stronglyLiveVariables)
. phase "desugar" (Operation.simplify . desugar)
. phase "array-split-lets" LetSplit.convertAcc
Expand All @@ -169,8 +169,8 @@ convertAccBench
-> sched kernel () (ScheduleOutput sched (DesugaredArrays (ArraysR arrs)) -> ())
convertAccBench b
= phase' "codegen" rnfSchedule convertSchedule
. phase "partition-live-vars" ({-Operation.simplify . -} Operation.stronglyLiveVariables)
. phase "array-fusion" ({-Operation.simplify . -} NewNewFusion.convertAccBench b)
. phase "partition-live-vars" (Operation.simplify . Operation.stronglyLiveVariables)
. phase "array-fusion" (Operation.simplify . NewNewFusion.convertAccBench b)
. phase "operation-live-vars" (Operation.simplify . Operation.stronglyLiveVariables)
. phase "desugar" (Operation.simplify . desugar)
. phase "array-split-lets" LetSplit.convertAcc
Expand All @@ -185,8 +185,8 @@ convertAfunBench
-> sched kernel () (Scheduled sched (DesugaredAfun (ArraysFunctionR f)))
convertAfunBench b
= phase' "codegen" rnfSchedule convertScheduleFun
. phase "partition-live-vars" ({- Operation.simplifyFun . -} Operation.stronglyLiveVariablesFun)
. phase "array-fusion" ({- Operation.simplifyFun . -} NewNewFusion.convertAccBenchF b)
. phase "partition-live-vars" (Operation.simplifyFun . Operation.stronglyLiveVariablesFun)
. phase "array-fusion" (Operation.simplifyFun . NewNewFusion.convertAccBenchF b)
. phase "operation-live-vars" (Operation.simplifyFun . Operation.stronglyLiveVariablesFun)
. phase "desugar" (Operation.simplifyFun . desugarAfun)
. phase "array-split-lets" LetSplit.convertAfun
Expand All @@ -202,8 +202,8 @@ convertAccWithObj
-> sched kernel () (ScheduleOutput sched (DesugaredArrays (ArraysR arrs)) -> ())
convertAccWithObj obj
= phase' "codegen" rnfSchedule convertSchedule
. phase "partition-live-vars" ({-Operation.simplify . -} Operation.stronglyLiveVariables)
. phase "array-fusion" ({-Operation.simplify . -} NewNewFusion.convertAccWith defaultOptions obj)
. phase "partition-live-vars" (Operation.simplify . Operation.stronglyLiveVariables)
. phase "array-fusion" (Operation.simplify . NewNewFusion.convertAccWith defaultOptions obj)
. phase "operation-live-vars" (Operation.simplify . Operation.stronglyLiveVariables)
. phase "desugar" (Operation.simplify . desugar)
. phase "array-split-lets" LetSplit.convertAcc
Expand All @@ -230,8 +230,8 @@ convertAfunWith
convertAfunWith config
= (\s -> Debug.Trace.trace (Pretty.renderForTerminal (Pretty.prettySchedule s)) s)
. phase' "codegen" rnfSchedule convertScheduleFun
. phase "partition-live-vars" ({- Operation.simplifyFun . -} Operation.stronglyLiveVariablesFun)
. phase "array-fusion" ({- Operation.simplifyFun . -} NewNewFusion.convertAfunWith config defaultObjective)
. phase "partition-live-vars" (Operation.simplifyFun . Operation.stronglyLiveVariablesFun)
. phase "array-fusion" (Operation.simplifyFun . NewNewFusion.convertAfunWith config defaultObjective)
. phase "operation-live-vars" (Operation.simplifyFun . Operation.stronglyLiveVariablesFun)
. phase "desugar" (Operation.simplifyFun . desugarAfun)
. phase "array-split-lets" LetSplit.convertAfun
Expand All @@ -246,8 +246,8 @@ convertAfunWithObj
-> sched kernel () (Scheduled sched (DesugaredAfun (ArraysFunctionR f)))
convertAfunWithObj obj
= phase' "codegen" rnfSchedule convertScheduleFun
. phase "partition-live-vars" ({- Operation.simplifyFun . -} Operation.stronglyLiveVariablesFun)
. phase "array-fusion" ({- Operation.simplifyFun . -} NewNewFusion.convertAfunWith defaultOptions obj)
. phase "partition-live-vars" (Operation.simplifyFun . Operation.stronglyLiveVariablesFun)
. phase "array-fusion" (Operation.simplifyFun . NewNewFusion.convertAfunWith defaultOptions obj)
. phase "operation-live-vars" (Operation.simplifyFun . Operation.stronglyLiveVariablesFun)
. phase "desugar" (Operation.simplifyFun . desugarAfun)
. phase "array-split-lets" LetSplit.convertAfun
Expand Down
17 changes: 13 additions & 4 deletions src/Data/Array/Accelerate/Trafo/Schedule/Partial.hs
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,12 @@ weakenSyncEnv (LeftHandSidePair l1 l2) env = weakenSyncEnv l1 $ weaken

-- Intermediate representation, used during the transformation

data Parallelism = Parallel | Sequential deriving Eq
data Parallelism
-- This let-binding is parallel: the binding is spawned in a separate task, and the body is directly executed.
= Parallel
-- This let-binding is sequential. The binding is executed before the body.
| Sequential
deriving Eq

data PrePartialScheduleFun (schedule :: (Type -> Type) -> Type -> Type -> Type) kernel env t where
Plam :: GLeftHandSide s env env'
Expand Down Expand Up @@ -571,17 +576,21 @@ buildLet parallelismHint lhs us bnd' body' available =
-- If the binding is trivial and doesn't wait on variables.
|| trivial bnd && IdxSet.null (directlyAwaits bnd)

thisLhsIndices = lhsIndices lhs

-- Make it sequential if:
sequential =
-- the early test already made it sequential,
sequentialEarlyTest
-- the body directly needs all declared variables,
|| IdxSet.fromVarList (lhsVars lhs) `IdxSet.isSubsetOf` directlyAwaits body
-- or the binding is trivial and doesn't use other variables than the
-- the body directly needs all declared variables (and there is at least one such variable),
|| not (IdxSet.isEmpty thisLhsIndices) && thisLhsIndices `IdxSet.isSubsetOf` directlyAwaits body
-- the binding is trivial and doesn't use other variables than the
-- body already waits on.
-- Note that the awaits set contains all free variables (minus the available variables)
-- if the program is trivial.
|| trivial bnd && directlyAwaits bnd `IdxSet.isSubsetOf` bodyAwaitsDropped
-- or the binding is trivial and the body directly awaits on the result of this expression.
|| trivial bnd && thisLhsIndices `IdxSet.overlaps` directlyAwaits body

parallelism = if sequential then Sequential else Parallel

Expand Down
Loading

0 comments on commit 1d2d2b5

Please sign in to comment.